LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
TritonData.cc
Go to the documentation of this file.
4 
5 #include <cstring>
6 #include <sstream>
7 
8 namespace ni = triton::common;
9 namespace nic = triton::client;
10 
11 namespace lartriton {
12 
13  //dims: kept constant, represents config.pbtxt parameters of model (converted from google::protobuf::RepeatedField to vector)
14  //fullShape: if batching is enabled, first entry is batch size; values can be modified
15  //shape: view into fullShape, excluding batch size entry
16  template <typename IO>
17  TritonData<IO>::TritonData(const std::string& name,
18  const TritonData<IO>::TensorMetadata& model_info,
19  bool noBatch)
20  : name_(name)
21  , dims_(model_info.shape().begin(), model_info.shape().end())
22  , noBatch_(noBatch)
23  , batchSize_(0)
24  , fullShape_(dims_)
25  , shape_(fullShape_.begin() + (noBatch_ ? 0 : 1), fullShape_.end())
26  , variableDims_(anyNeg(shape_))
27  , productDims_(variableDims_ ? -1 : dimProduct(shape_))
28  , dname_(model_info.datatype())
29  , dtype_(ni::ProtocolStringToDataType(dname_))
30  , byteSize_(ni::GetDataTypeByteSize(dtype_))
31  {
32  //create input or output object
33  IO* iotmp;
34  createObject(&iotmp);
35  data_.reset(iotmp);
36  }
37 
38  template <>
39  void TritonInputData::createObject(nic::InferInput** ioptr) const
40  {
41  nic::InferInput::Create(ioptr, name_, fullShape_, dname_);
42  }
43 
44  template <>
45  void TritonOutputData::createObject(nic::InferRequestedOutput** ioptr) const
46  {
47  nic::InferRequestedOutput::Create(ioptr, name_);
48  }
49 
50  //setters
51  template <typename IO>
52  bool TritonData<IO>::setShape(const TritonData<IO>::ShapeType& newShape, bool canThrow)
53  {
54  bool result = true;
55  for (unsigned i = 0; i < newShape.size(); ++i) {
56  result &= setShape(i, newShape[i], canThrow);
57  }
58  return result;
59  }
60 
61  template <typename IO>
62  bool TritonData<IO>::setShape(unsigned loc, int64_t val, bool canThrow)
63  {
64  std::stringstream msg;
65  unsigned full_loc = loc + (noBatch_ ? 0 : 1);
66 
67  //check boundary
68  if (full_loc >= fullShape_.size()) {
69  msg << name_ << " setShape(): dimension " << full_loc << " out of bounds ("
70  << fullShape_.size() << ")";
71  if (canThrow)
72  throw cet::exception("TritonDataError") << msg.str();
73  else {
74  MF_LOG_WARNING("TritonDataWarning") << msg.str();
75  return false;
76  }
77  }
78 
79  if (val != fullShape_[full_loc]) {
80  if (dims_[full_loc] == -1) {
81  fullShape_[full_loc] = val;
82  return true;
83  }
84  else {
85  msg << name_ << " setShape(): attempt to change value of non-variable shape dimension "
86  << loc;
87  if (canThrow)
88  throw cet::exception("TritonDataError") << msg.str();
89  else {
90  MF_LOG_WARNING("TritonDataError") << msg.str();
91  return false;
92  }
93  }
94  }
95 
96  return true;
97  }
98 
99  template <typename IO>
100  void TritonData<IO>::setBatchSize(unsigned bsize)
101  {
102  batchSize_ = bsize;
103  if (!noBatch_) fullShape_[0] = batchSize_;
104  }
105 
106  //io accessors
107 
108  template <>
109  template <typename DT>
111  {
112  if (!result_) {
113  throw cet::exception("TritonDataError") << name_ << " output(): missing result";
114  }
115 
116  if (byteSize_ != sizeof(DT)) {
117  throw cet::exception("TritonDataError")
118  << name_ << " output(): inconsistent byte size " << sizeof(DT) << " (should be "
119  << byteSize_ << " for " << dname_ << ")";
120  }
121 
122  uint64_t nOutput = sizeShape();
123  TritonOutput<DT> dataOut;
124  const uint8_t* r0;
125  size_t contentByteSize;
126  size_t expectedContentByteSize = nOutput * byteSize_ * batchSize_;
127  triton_utils::throwIfError(result_->RawData(name_, &r0, &contentByteSize),
128  "output(): unable to get raw");
129  if (contentByteSize != expectedContentByteSize) {
130  throw cet::exception("TritonDataError")
131  << name_ << " output(): unexpected content byte size " << contentByteSize << " (expected "
132  << expectedContentByteSize << ")";
133  }
134 
135  const DT* r1 = reinterpret_cast<const DT*>(r0);
136  dataOut.reserve(batchSize_);
137  for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
138  auto offset = i0 * nOutput;
139  dataOut.emplace_back(r1 + offset, r1 + offset + nOutput);
140  }
141 
142  return dataOut;
143  }
144 
// Full specialization whose body calls data_->Reset() and then holder_.reset().
// NOTE(review): the declarator that belongs between `template <>` and `{`
// (listing line 146) is absent from this copy; restore it from the original
// file. Presumably this is the input-side specialization, since it touches
// data_ and holder_ - confirm.
145  template <>
147  {
148  data_->Reset();
149  holder_.reset();
150  }
151 
// Full specialization whose body drops this object's reference to the result
// by calling result_.reset().
// NOTE(review): the declarator that belongs between `template <>` and `{`
// (listing line 153) is absent from this copy; restore it from the original
// file. Presumably this is the output-side counterpart of the specialization
// that resets data_ and holder_ - confirm.
152  template <>
154  {
155  result_.reset();
156  }
157 
158  //explicit template instantiation definitions (no `extern`, so the code is emitted in this translation unit)
// NOTE(review): listing line 160 is absent from this copy; presumably it
// instantiates the class for the output object type as well - confirm against
// the original file.
159  template class TritonData<nic::InferInput>;
161 
// Explicit instantiations of TritonInputData::toServer for the element types
// float and int64_t.
// NOTE(review): listing line 165 is absent from this copy; check the original
// file for a further instantiation that belongs after these two.
162  template void TritonInputData::toServer(std::shared_ptr<TritonInput<float>> data_in);
163  template void TritonInputData::toServer(std::shared_ptr<TritonInput<int64_t>> data_in);
164 
166 
167 }
void setBatchSize(unsigned bsize)
Definition: TritonData.cc:100
bool setShape(const ShapeType &newShape)
Definition: TritonData.h:45
TritonData(const std::string &name, const TensorMetadata &model_info, bool noBatch)
Definition: TritonData.cc:17
const ShapeType dims_
Definition: TritonData.h:120
std::vector< int64_t > ShapeType
Definition: TritonData.h:38
void throwIfError(const Error &err, std::string_view msg)
Definition: triton_utils.cc:26
void toServer(std::shared_ptr< TritonInput< DT >> ptr)
Definition: TritonData.h:50
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
std::vector< std::vector< DT >> TritonInput
Definition: TritonData.h:28
int64_t sizeShape() const
Definition: TritonData.h:93
void createObject(IO **ioptr) const
std::shared_ptr< Result > result_
Definition: TritonData.h:131
std::vector< triton_span::Span< const DT * >> TritonOutput
Definition: TritonData.h:30
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
std::shared_ptr< IO > data_
Definition: TritonData.h:119
inference::ModelMetadataResponse_TensorMetadata TensorMetadata
Definition: TritonData.h:37
#define MF_LOG_WARNING(category)
std::string dname_
Definition: TritonData.h:127
TritonOutput< DT > fromServer() const
Definition: TritonData.cc:110
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33