LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
TritonData.cc
Go to the documentation of this file.
2 #include "cetlib_except/exception.h"
5 
6 #include <cstring>
7 #include <sstream>
8 
9 namespace ni = triton::common;
10 namespace nic = triton::client;
11 
12 namespace lartriton {
13 
14  //dims: kept constant, represents config.pbtxt parameters of model (converted from google::protobuf::RepeatedField to vector)
15  //fullShape: if batching is enabled, first entry is batch size; values can be modified
16  //shape: view into fullShape, excluding batch size entry
17  template <typename IO>
18  TritonData<IO>::TritonData(const std::string& name,
19  const TritonData<IO>::TensorMetadata& model_info,
20  bool noBatch)
21  : name_(name)
22  , dims_(model_info.shape().begin(), model_info.shape().end())
23  , noBatch_(noBatch)
24  , batchSize_(0)
25  , fullShape_(dims_)
26  , shape_(fullShape_.begin() + (noBatch_ ? 0 : 1), fullShape_.end())
27  , variableDims_(anyNeg(shape_))
28  , productDims_(variableDims_ ? -1 : dimProduct(shape_))
29  , dname_(model_info.datatype())
30  , dtype_(ni::ProtocolStringToDataType(dname_))
31  , byteSize_(ni::GetDataTypeByteSize(dtype_))
32  {
33  //create input or output object
34  IO* iotmp;
35  createObject(&iotmp);
36  data_.reset(iotmp);
37  }
38 
39  template <>
40  void TritonInputData::createObject(nic::InferInput** ioptr) const
41  {
42  nic::InferInput::Create(ioptr, name_, fullShape_, dname_);
43  }
44 
45  template <>
46  void TritonOutputData::createObject(nic::InferRequestedOutput** ioptr) const
47  {
48  nic::InferRequestedOutput::Create(ioptr, name_);
49  }
50 
51  //setters
52  template <typename IO>
53  bool TritonData<IO>::setShape(const TritonData<IO>::ShapeType& newShape, bool canThrow)
54  {
55  bool result = true;
56  for (unsigned i = 0; i < newShape.size(); ++i) {
57  result &= setShape(i, newShape[i], canThrow);
58  }
59  return result;
60  }
61 
62  template <typename IO>
63  bool TritonData<IO>::setShape(unsigned loc, int64_t val, bool canThrow)
64  {
65  std::stringstream msg;
66  unsigned full_loc = loc + (noBatch_ ? 0 : 1);
67 
68  //check boundary
69  if (full_loc >= fullShape_.size()) {
70  msg << name_ << " setShape(): dimension " << full_loc << " out of bounds ("
71  << fullShape_.size() << ")";
72  if (canThrow)
73  throw cet::exception("TritonDataError") << msg.str();
74  else {
75  MF_LOG_WARNING("TritonDataWarning") << msg.str();
76  return false;
77  }
78  }
79 
80  if (val != fullShape_[full_loc]) {
81  if (dims_[full_loc] == -1) {
82  fullShape_[full_loc] = val;
83  return true;
84  }
85  else {
86  msg << name_ << " setShape(): attempt to change value of non-variable shape dimension "
87  << loc;
88  if (canThrow)
89  throw cet::exception("TritonDataError") << msg.str();
90  else {
91  MF_LOG_WARNING("TritonDataError") << msg.str();
92  return false;
93  }
94  }
95  }
96 
97  return true;
98  }
99 
100  template <typename IO>
101  void TritonData<IO>::setBatchSize(unsigned bsize)
102  {
103  batchSize_ = bsize;
104  if (!noBatch_) fullShape_[0] = batchSize_;
105  }
106 
107  //io accessors
108  template <>
109  template <typename DT>
110  void TritonInputData::toServer(std::shared_ptr<TritonInput<DT>> ptr)
111  {
112  const auto& data_in = *ptr;
113 
114  //check batch size
115  if (data_in.size() != batchSize_) {
116  throw cet::exception("TritonDataError")
117  << name_ << " input(): input vector has size " << data_in.size()
118  << " but specified batch size is " << batchSize_;
119  }
120 
121  //shape must be specified for variable dims or if batch size changes
122  data_->SetShape(fullShape_);
123 
124  if (byteSize_ != sizeof(DT))
125  throw cet::exception("TritonDataError")
126  << name_ << " input(): inconsistent byte size " << sizeof(DT) << " (should be " << byteSize_
127  << " for " << dname_ << ")";
128 
129  int64_t nInput = sizeShape();
130  for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
131  const DT* arr = data_in[i0].data();
133  data_->AppendRaw(reinterpret_cast<const uint8_t*>(arr), nInput * byteSize_),
134  name_ + " input(): unable to set data for batch entry " + std::to_string(i0));
135  }
136 
137  //keep input data in scope
138  holder_ = std::move(ptr);
139  }
140 
141  template <>
142  template <typename DT>
144  {
145  if (!result_) {
146  throw cet::exception("TritonDataError") << name_ << " output(): missing result";
147  }
148 
149  if (byteSize_ != sizeof(DT)) {
150  throw cet::exception("TritonDataError")
151  << name_ << " output(): inconsistent byte size " << sizeof(DT) << " (should be "
152  << byteSize_ << " for " << dname_ << ")";
153  }
154 
155  uint64_t nOutput = sizeShape();
156  TritonOutput<DT> dataOut;
157  const uint8_t* r0;
158  size_t contentByteSize;
159  size_t expectedContentByteSize = nOutput * byteSize_ * batchSize_;
160  triton_utils::throwIfError(result_->RawData(name_, &r0, &contentByteSize),
161  "output(): unable to get raw");
162  if (contentByteSize != expectedContentByteSize) {
163  throw cet::exception("TritonDataError")
164  << name_ << " output(): unexpected content byte size " << contentByteSize << " (expected "
165  << expectedContentByteSize << ")";
166  }
167 
168  const DT* r1 = reinterpret_cast<const DT*>(r0);
169  dataOut.reserve(batchSize_);
170  for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
171  auto offset = i0 * nOutput;
172  dataOut.emplace_back(r1 + offset, r1 + offset + nOutput);
173  }
174 
175  return dataOut;
176  }
177 
178  template <>
180  {
181  data_->Reset();
182  holder_.reset();
183  }
184 
185  template <>
187  {
188  result_.reset();
189  }
190 
191  //explicit template instantiation declarations
192  template class TritonData<nic::InferInput>;
194 
195  template void TritonInputData::toServer(std::shared_ptr<TritonInput<float>> data_in);
196  template void TritonInputData::toServer(std::shared_ptr<TritonInput<int64_t>> data_in);
197 
199 
200 }
std::string name_
Definition: TritonData.h:87
void setBatchSize(unsigned bsize)
Definition: TritonData.cc:101
bool setShape(const ShapeType &newShape)
Definition: TritonData.h:43
TritonData(const std::string &name, const TensorMetadata &model_info, bool noBatch)
Definition: TritonData.cc:18
const ShapeType dims_
Definition: TritonData.h:89
ShapeType fullShape_
Definition: TritonData.h:92
std::vector< int64_t > ShapeType
Definition: TritonData.h:36
void throwIfError(const Error &err, std::string_view msg)
Definition: triton_utils.cc:26
void toServer(std::shared_ptr< TritonInput< DT >> ptr)
Definition: TritonData.cc:110
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
decltype(auto) constexpr to_string(T &&obj)
ADL-aware version of std::to_string.
std::vector< std::vector< DT >> TritonInput
Definition: TritonData.h:26
int64_t sizeShape() const
Definition: TritonData.h:62
void createObject(IO **ioptr) const
std::shared_ptr< Result > result_
Definition: TritonData.h:100
std::vector< triton_span::Span< const DT * >> TritonOutput
Definition: TritonData.h:28
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
std::shared_ptr< IO > data_
Definition: TritonData.h:88
inference::ModelMetadataResponse_TensorMetadata TensorMetadata
Definition: TritonData.h:35
#define MF_LOG_WARNING(category)
std::string dname_
Definition: TritonData.h:96
TritonOutput< DT > fromServer() const
Definition: TritonData.cc:143
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33