LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
TritonData.h
Go to the documentation of this file.
1 #ifndef NuSonic_Triton_TritonData
2 #define NuSonic_Triton_TritonData
3 
4 #include "cetlib_except/exception.h"
7 
8 #include <algorithm>
9 #include <any>
10 #include <memory>
11 #include <numeric>
12 #include <string>
13 #include <unordered_map>
14 #include <vector>
15 
16 #include "grpc_client.h"
17 #include "triton/common/model_config.h"
18 
19 namespace nic = triton::client;
20 
21 namespace lartriton {
22 
23  //forward declaration
24  class TritonClient;
25 
26  //aliases for local input and output types
27  template <typename DT>
28  using TritonInput = std::vector<std::vector<DT>>;
29  template <typename DT>
30  using TritonOutput = std::vector<triton_span::Span<const DT*>>;
31 
32  //store all the info needed for triton input and output
33  template <typename IO>
34  class TritonData {
35  public:
36  using Result = nic::InferResult;
37  using TensorMetadata = inference::ModelMetadataResponse_TensorMetadata;
38  using ShapeType = std::vector<int64_t>;
40 
41  //constructor
42  TritonData(const std::string& name, const TensorMetadata& model_info, bool noBatch);
43 
44  //some members can be modified
45  bool setShape(const ShapeType& newShape) { return setShape(newShape, true); }
46  bool setShape(unsigned loc, int64_t val) { return setShape(loc, val, true); }
47 
48  //io accessors
49  template <typename DT>
50  void toServer(std::shared_ptr<TritonInput<DT>> ptr)
51  {
52  const auto& data_in = *ptr;
53 
54  //check batch size
55  if (data_in.size() != batchSize_) {
56  throw cet::exception("TritonDataError")
57  << name_ << " input(): input vector has size " << data_in.size()
58  << " but specified batch size is " << batchSize_;
59  }
60 
61  //shape must be specified for variable dims or if batch size changes
62  data_->SetShape(fullShape_);
63 
64  if (byteSize_ != sizeof(DT))
65  throw cet::exception("TritonDataError")
66  << name_ << " input(): inconsistent byte size " << sizeof(DT) << " (should be "
67  << byteSize_ << " for " << dname_ << ")";
68 
69  for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
70  const DT* arr = data_in[i0].data();
72  data_->AppendRaw(reinterpret_cast<const uint8_t*>(arr), data_in[i0].size() * byteSize_),
73  name_ + " input(): unable to set data for batch entry " + std::to_string(i0));
74  }
75 
76  //keep input data in scope
77  holder_ = std::move(ptr);
78  }
79 
80  template <typename DT>
82 
83  //const accessors
84  const ShapeView& shape() const { return shape_; }
85  int64_t byteSize() const { return byteSize_; }
86  const std::string& dname() const { return dname_; }
87  unsigned batchSize() const { return batchSize_; }
88 
89  //utilities
90  bool variableDims() const { return variableDims_; }
91  int64_t sizeDims() const { return productDims_; }
92  //default to dims if shape isn't filled
93  int64_t sizeShape() const { return variableDims_ ? dimProduct(shape_) : sizeDims(); }
94 
95  private:
96  friend class TritonClient;
97 
98  //private accessors only used by client
99  bool setShape(const ShapeType& newShape, bool canThrow);
100  bool setShape(unsigned loc, int64_t val, bool canThrow);
101  void setBatchSize(unsigned bsize);
102  void reset();
103  void setResult(std::shared_ptr<Result> result) { result_ = result; }
104  IO* data() { return data_.get(); }
105 
106  //helpers
107  bool anyNeg(const ShapeView& vec) const
108  {
109  return std::any_of(vec.begin(), vec.end(), [](int64_t i) { return i < 0; });
110  }
111  int64_t dimProduct(const ShapeView& vec) const
112  {
113  return std::accumulate(vec.begin(), vec.end(), 1, std::multiplies<int64_t>());
114  }
115  void createObject(IO** ioptr) const;
116 
117  //members
118  std::string name_;
119  std::shared_ptr<IO> data_;
121  bool noBatch_;
122  unsigned batchSize_;
126  int64_t productDims_;
127  std::string dname_;
129  int64_t byteSize_;
130  std::any holder_;
131  std::shared_ptr<Result> result_;
132  };
133 
135  using TritonInputMap = std::unordered_map<std::string, TritonInputData>;
137  using TritonOutputMap = std::unordered_map<std::string, TritonOutputData>;
138 
139  template <>
140  void TritonInputData::reset();
141  template <>
143  template <>
144  void TritonInputData::createObject(nic::InferInput** ioptr) const;
145  template <>
146  void TritonOutputData::createObject(nic::InferRequestedOutput** ioptr) const;
147 
148  //explicit template instantiation declarations
149  extern template class TritonData<nic::InferInput>;
150  extern template class TritonData<nic::InferRequestedOutput>;
151 
152 }
153 #endif
std::unordered_map< std::string, TritonOutputData > TritonOutputMap
Definition: TritonData.h:137
void setBatchSize(unsigned bsize)
Definition: TritonData.cc:100
bool setShape(const ShapeType &newShape)
Definition: TritonData.h:45
TritonData(const std::string &name, const TensorMetadata &model_info, bool noBatch)
Definition: TritonData.cc:17
const ShapeType dims_
Definition: TritonData.h:120
bool anyNeg(const ShapeView &vec) const
Definition: TritonData.h:107
int64_t sizeDims() const
Definition: TritonData.h:91
bool variableDims() const
Definition: TritonData.h:90
const ShapeView & shape() const
Definition: TritonData.h:84
std::vector< int64_t > ShapeType
Definition: TritonData.h:38
T begin() const
Definition: Span.h:20
inference::DataType dtype_
Definition: TritonData.h:128
void throwIfError(const Error &err, std::string_view msg)
Definition: triton_utils.cc:26
void setResult(std::shared_ptr< Result > result)
Definition: TritonData.h:103
void toServer(std::shared_ptr< TritonInput< DT >> ptr)
Definition: TritonData.h:50
T end() const
Definition: Span.h:21
decltype(auto) constexpr to_string(T &&obj)
ADL-aware version of std::to_string.
std::vector< std::vector< DT >> TritonInput
Definition: TritonData.h:28
unsigned batchSize() const
Definition: TritonData.h:87
int64_t sizeShape() const
Definition: TritonData.h:93
void createObject(IO **ioptr) const
bool setShape(unsigned loc, int64_t val)
Definition: TritonData.h:46
std::shared_ptr< Result > result_
Definition: TritonData.h:131
std::vector< triton_span::Span< const DT * >> TritonOutput
Definition: TritonData.h:30
std::shared_ptr< IO > data_
Definition: TritonData.h:119
int64_t dimProduct(const ShapeView &vec) const
Definition: TritonData.h:111
const std::string & dname() const
Definition: TritonData.h:86
std::unordered_map< std::string, TritonInputData > TritonInputMap
Definition: TritonData.h:135
int64_t byteSize() const
Definition: TritonData.h:85
inference::ModelMetadataResponse_TensorMetadata TensorMetadata
Definition: TritonData.h:37
std::string dname_
Definition: TritonData.h:127
TritonOutput< DT > fromServer() const
Definition: TritonData.cc:110
nic::InferResult Result
Definition: TritonData.h:36
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33