LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlgTriton_tool.cc
Go to the documentation of this file.
1 // Class: PointIdAlgTriton_tool
3 // Authors: M.Wang, FNAL, 2021: Nvidia Triton inf client
5 
7 #include "cetlib_except/exception.h"
9 #include "fhiclcpp/types/Table.h"
10 #include "larrecodnn/ImagePatternAlgs/ToolInterfaces/IPointIdAlg.h"
11 
12 // Nvidia Triton inference server client includes
13 #include "grpc_client.h"
14 
15 namespace nic = triton::client;
16 
17 #include <algorithm>
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 namespace PointIdAlgTools {
23 
24  class PointIdAlgTriton : public IPointIdAlg {
25  public:
26  explicit PointIdAlgTriton(fhicl::Table<Config> const& table);
27 
28  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const override;
29  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
30  int samples = -1) const override;
31 
32  private:
33  std::string fTritonModelName;
34  std::string fTritonURL;
36  std::string fTritonModelVersion;
37 
38  std::unique_ptr<nic::InferenceServerGrpcClient> triton_client;
39  inference::ModelMetadataResponse triton_modmet;
40  inference::ModelConfigResponse triton_modcfg;
41  mutable std::vector<int64_t> triton_inpshape;
42  nic::InferOptions triton_options;
43  };
44 
45  // ------------------------------------------------------
47  : img::DataProviderAlg(table()), triton_options("")
48  {
49  // ... Get common config vars
50  fNNetOutputs = table().NNetOutputs();
51  fPatchSizeW = table().PatchSizeW();
52  fPatchSizeD = table().PatchSizeD();
53  fCurrentWireIdx = 99999;
54  fCurrentScaledDrift = 99999;
55 
56  // ... Get "optional" config vars specific to Triton interface
57  fTritonModelName = table().TritonModelName();
58  fTritonURL = table().TritonURL();
59  fTritonVerbose = table().TritonVerbose();
60  fTritonModelVersion = table().TritonModelVersion();
61 
62  // ... Create the Triton inference client
63  auto err = nic::InferenceServerGrpcClient::Create(&triton_client, fTritonURL, fTritonVerbose);
64  if (!err.IsOk()) {
65  throw cet::exception("PointIdAlgTriton")
66  << "error: unable to create client for inference: " << err << std::endl;
67  }
68 
69  // ... Get the model metadata and config information
71  if (!err.IsOk()) {
72  throw cet::exception("PointIdAlgTriton")
73  << "error: failed to get model metadata: " << err << std::endl;
74  }
76  if (!err.IsOk()) {
77  throw cet::exception("PointIdAlgTriton")
78  << "error: failed to get model config: " << err << std::endl;
79  }
80 
81  // ... Set up shape vector needed when creating inference input
82  triton_inpshape.push_back(1); // initialize batch_size to 1
83  triton_inpshape.push_back(triton_modmet.inputs(0).shape(1));
84  triton_inpshape.push_back(triton_modmet.inputs(0).shape(2));
85  triton_inpshape.push_back(triton_modmet.inputs(0).shape(3));
86 
87  // ... Set up Triton inference client options
88  triton_options.model_name_ = fTritonModelName;
89  triton_options.model_version_ = fTritonModelVersion;
90 
91  mf::LogInfo("PointIdAlgTriton") << "url: " << fTritonURL;
92  mf::LogInfo("PointIdAlgTriton") << "model name: " << fTritonModelName;
93  mf::LogInfo("PointIdAlgTriton") << "model version: " << fTritonModelVersion;
94  mf::LogInfo("PointIdAlgTriton") << "verbose: " << fTritonVerbose;
95 
96  mf::LogInfo("PointIdAlgTriton") << "tensorRT inference context created.";
97 
98  resizePatch();
99  }
100 
101  // ------------------------------------------------------
102  std::vector<float> PointIdAlgTriton::Run(std::vector<std::vector<float>> const& inp2d) const
103  {
104  size_t nrows = inp2d.size(), ncols = inp2d.front().size();
105 
106  triton_inpshape.at(0) = 1; // set batch size
107 
108  // ~~~~ Initialize the inputs
109 
110  nic::InferInput* triton_input;
111  auto err = nic::InferInput::Create(&triton_input,
112  triton_modmet.inputs(0).name(),
114  triton_modmet.inputs(0).datatype());
115  if (!err.IsOk()) {
116  throw cet::exception("PointIdAlgTriton") << "unable to get input: " << err << std::endl;
117  }
118  std::shared_ptr<nic::InferInput> triton_input_ptr(triton_input);
119  std::vector<nic::InferInput*> triton_inputs = {triton_input_ptr.get()};
120 
121  // ~~~~ Register the mem address of 1st byte of image and #bytes in image
122 
123  err = triton_input_ptr->Reset();
124  if (!err.IsOk()) {
125  throw cet::exception("PointIdAlgTriton")
126  << "failed resetting Triton model input: " << err << std::endl;
127  }
128 
129  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
130  std::vector<float> fa(sbuff_byte_size);
131 
132  // ..flatten the 2d array into contiguous 1d block
133  for (size_t ir = 0; ir < nrows; ++ir) {
134  std::copy(inp2d[ir].begin(), inp2d[ir].end(), fa.begin() + (ir * ncols));
135  }
136  err = triton_input_ptr->AppendRaw(reinterpret_cast<uint8_t*>(fa.data()), sbuff_byte_size);
137  if (!err.IsOk()) {
138  throw cet::exception("PointIdAlgTriton")
139  << "failed setting Triton input: " << err << std::endl;
140  }
141 
142  // ~~~~ Send inference request
143 
144  nic::InferResult* results;
145 
146  err = triton_client->Infer(&results, triton_options, triton_inputs);
147  if (!err.IsOk()) {
148  throw cet::exception("PointIdAlgTriton")
149  << "failed sending Triton synchronous infer request: " << err << std::endl;
150  }
151  std::shared_ptr<nic::InferResult> results_ptr;
152  results_ptr.reset(results);
153 
154  // ~~~~ Retrieve inference results
155 
156  std::vector<float> out;
157 
158  const float* prb0;
159  size_t rbuff0_byte_size; // size of result buffer in bytes
160  results_ptr->RawData(
161  triton_modmet.outputs(0).name(), (const uint8_t**)&prb0, &rbuff0_byte_size);
162  size_t ncat0 = rbuff0_byte_size / sizeof(float);
163 
164  const float* prb1;
165  size_t rbuff1_byte_size; // size of result buffer in bytes
166  results_ptr->RawData(
167  triton_modmet.outputs(1).name(), (const uint8_t**)&prb1, &rbuff1_byte_size);
168  size_t ncat1 = rbuff1_byte_size / sizeof(float);
169 
170  for (unsigned j = 0; j < ncat0; j++)
171  out.push_back(*(prb0 + j));
172  for (unsigned j = 0; j < ncat1; j++)
173  out.push_back(*(prb1 + j));
174 
175  return out;
176  }
177 
178  // ------------------------------------------------------
179  std::vector<std::vector<float>> PointIdAlgTriton::Run(
180  std::vector<std::vector<std::vector<float>>> const& inps,
181  int samples) const
182  {
183  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
184  return std::vector<std::vector<float>>();
185  }
186 
187  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
188 
189  size_t usamples = samples;
190  size_t nrows = inps.front().size(), ncols = inps.front().front().size();
191 
192  triton_inpshape.at(0) = usamples; // set batch size
193 
194  // ~~~~ Initialize the inputs
195 
196  nic::InferInput* triton_input;
197  auto err = nic::InferInput::Create(&triton_input,
198  triton_modmet.inputs(0).name(),
200  triton_modmet.inputs(0).datatype());
201  if (!err.IsOk()) {
202  throw cet::exception("PointIdAlgTriton") << "unable to get input: " << err << std::endl;
203  }
204  std::shared_ptr<nic::InferInput> triton_input_ptr(triton_input);
205  std::vector<nic::InferInput*> triton_inputs = {triton_input_ptr.get()};
206 
207  // ~~~~ For each sample, register the mem address of 1st byte of image and #bytes in image
208  err = triton_input_ptr->Reset();
209  if (!err.IsOk()) {
210  throw cet::exception("PointIdAlgTriton")
211  << "failed resetting Triton model input: " << err << std::endl;
212  }
213 
214  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
215  std::vector<std::vector<float>> fa(usamples, std::vector<float>(sbuff_byte_size));
216 
217  for (size_t idx = 0; idx < usamples; ++idx) {
218  // ..first flatten the 2d array into contiguous 1d block
219  for (size_t ir = 0; ir < nrows; ++ir) {
220  std::copy(inps[idx][ir].begin(), inps[idx][ir].end(), fa[idx].begin() + (ir * ncols));
221  }
222  err =
223  triton_input_ptr->AppendRaw(reinterpret_cast<uint8_t*>(fa[idx].data()), sbuff_byte_size);
224  if (!err.IsOk()) {
225  throw cet::exception("PointIdAlgTriton")
226  << "failed setting Triton input: " << err << std::endl;
227  }
228  }
229 
230  // ~~~~ Send inference request
231 
232  nic::InferResult* results;
233 
234  err = triton_client->Infer(&results, triton_options, triton_inputs);
235  if (!err.IsOk()) {
236  throw cet::exception("PointIdAlgTriton")
237  << "failed sending Triton synchronous infer request: " << err << std::endl;
238  }
239  std::shared_ptr<nic::InferResult> results_ptr;
240  results_ptr.reset(results);
241 
242  // ~~~~ Retrieve inference results
243 
244  std::vector<std::vector<float>> out;
245 
246  const float* prb0;
247  size_t rbuff0_byte_size; // size of result buffer in bytes
248  results_ptr->RawData(
249  triton_modmet.outputs(0).name(), (const uint8_t**)&prb0, &rbuff0_byte_size);
250  size_t ncat0 = rbuff0_byte_size / (usamples * sizeof(float));
251 
252  const float* prb1;
253  size_t rbuff1_byte_size; // size of result buffer in bytes
254  results_ptr->RawData(
255  triton_modmet.outputs(1).name(), (const uint8_t**)&prb1, &rbuff1_byte_size);
256  size_t ncat1 = rbuff1_byte_size / (usamples * sizeof(float));
257 
258  for (unsigned i = 0; i < usamples; i++) {
259  std::vector<float> vprb;
260  for (unsigned j = 0; j < ncat0; j++)
261  vprb.push_back(*(prb0 + i * ncat0 + j));
262  for (unsigned j = 0; j < ncat1; j++)
263  vprb.push_back(*(prb1 + i * ncat1 + j));
264  out.push_back(vprb);
265  }
266 
267  return out;
268  }
269 
270 }
PointIdAlgTriton(fhicl::Table< Config > const &table)
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const override
inference::ModelMetadataResponse triton_modmet
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
inference::ModelConfigResponse triton_modcfg
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
Int_t ncols
Definition: plot.C:52
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
TFile fa("Li7.root")