LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlgTools::PointIdAlgTriton Class Reference
Inheritance diagram for PointIdAlgTools::PointIdAlgTriton:

Public Member Functions

 PointIdAlgTriton (fhicl::Table< Config > const &table)
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) const override
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) const override
 

Private Attributes

std::string fTritonModelName
 
std::string fTritonURL
 
bool fTritonVerbose
 
std::string fTritonModelVersion
 
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
 
inference::ModelMetadataResponse triton_modmet
 
inference::ModelConfigResponse triton_modcfg
 
std::vector< int64_t > triton_inpshape
 
nic::InferOptions triton_options
 

Detailed Description

Definition at line 24 of file PointIdAlgTriton_tool.cc.

Constructor & Destructor Documentation

PointIdAlgTools::PointIdAlgTriton::PointIdAlgTriton ( fhicl::Table< Config > const &  table)
explicit

Definition at line 46 of file PointIdAlgTriton_tool.cc.

References fTritonModelName, fTritonModelVersion, fTritonURL, fTritonVerbose, triton_client, triton_inpshape, triton_modcfg, triton_modmet, and triton_options.

47  : img::DataProviderAlg(table()), triton_options("")
48  {
49  // ... Get common config vars
50  fNNetOutputs = table().NNetOutputs();
51  fPatchSizeW = table().PatchSizeW();
52  fPatchSizeD = table().PatchSizeD();
53  fCurrentWireIdx = 99999;
54  fCurrentScaledDrift = 99999;
55 
56  // ... Get "optional" config vars specific to Triton interface
57  fTritonModelName = table().TritonModelName();
58  fTritonURL = table().TritonURL();
59  fTritonVerbose = table().TritonVerbose();
60  fTritonModelVersion = table().TritonModelVersion();
61 
62  // ... Create the Triton inference client
63  auto err = nic::InferenceServerGrpcClient::Create(&triton_client, fTritonURL, fTritonVerbose);
64  if (!err.IsOk()) {
65  throw cet::exception("PointIdAlgTriton")
66  << "error: unable to create client for inference: " << err << std::endl;
67  }
68 
69  // ... Get the model metadata and config information
71  if (!err.IsOk()) {
72  throw cet::exception("PointIdAlgTriton")
73  << "error: failed to get model metadata: " << err << std::endl;
74  }
76  if (!err.IsOk()) {
77  throw cet::exception("PointIdAlgTriton")
78  << "error: failed to get model config: " << err << std::endl;
79  }
80 
81  // ... Set up shape vector needed when creating inference input
82  triton_inpshape.push_back(1); // initialize batch_size to 1
83  triton_inpshape.push_back(triton_modmet.inputs(0).shape(1));
84  triton_inpshape.push_back(triton_modmet.inputs(0).shape(2));
85  triton_inpshape.push_back(triton_modmet.inputs(0).shape(3));
86 
87  // ... Set up Triton inference client options
88  triton_options.model_name_ = fTritonModelName;
89  triton_options.model_version_ = fTritonModelVersion;
90 
91  mf::LogInfo("PointIdAlgTriton") << "url: " << fTritonURL;
92  mf::LogInfo("PointIdAlgTriton") << "model name: " << fTritonModelName;
93  mf::LogInfo("PointIdAlgTriton") << "model version: " << fTritonModelVersion;
94  mf::LogInfo("PointIdAlgTriton") << "verbose: " << fTritonVerbose;
95 
96  mf::LogInfo("PointIdAlgTriton") << "tensorRT inference context created.";
97 
98  resizePatch();
99  }
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
inference::ModelMetadataResponse triton_modmet
inference::ModelConfigResponse triton_modcfg
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33

Member Function Documentation

std::vector< float > PointIdAlgTools::PointIdAlgTriton::Run ( std::vector< std::vector< float >> const &  inp2d) const
override

Definition at line 102 of file PointIdAlgTriton_tool.cc.

References util::begin(), util::end(), fa(), ncols, triton_client, triton_inpshape, triton_modmet, and triton_options.

103  {
104  size_t nrows = inp2d.size(), ncols = inp2d.front().size();
105 
106  triton_inpshape.at(0) = 1; // set batch size
107 
108  // ~~~~ Initialize the inputs
109 
110  nic::InferInput* triton_input;
111  auto err = nic::InferInput::Create(&triton_input,
112  triton_modmet.inputs(0).name(),
114  triton_modmet.inputs(0).datatype());
115  if (!err.IsOk()) {
116  throw cet::exception("PointIdAlgTriton") << "unable to get input: " << err << std::endl;
117  }
118  std::shared_ptr<nic::InferInput> triton_input_ptr(triton_input);
119  std::vector<nic::InferInput*> triton_inputs = {triton_input_ptr.get()};
120 
121  // ~~~~ Register the mem address of 1st byte of image and #bytes in image
122 
123  err = triton_input_ptr->Reset();
124  if (!err.IsOk()) {
125  throw cet::exception("PointIdAlgTriton")
126  << "failed resetting Triton model input: " << err << std::endl;
127  }
128 
129  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
130  std::vector<float> fa(sbuff_byte_size);
131 
132  // ..flatten the 2d array into contiguous 1d block
133  for (size_t ir = 0; ir < nrows; ++ir) {
134  std::copy(inp2d[ir].begin(), inp2d[ir].end(), fa.begin() + (ir * ncols));
135  }
136  err = triton_input_ptr->AppendRaw(reinterpret_cast<uint8_t*>(fa.data()), sbuff_byte_size);
137  if (!err.IsOk()) {
138  throw cet::exception("PointIdAlgTriton")
139  << "failed setting Triton input: " << err << std::endl;
140  }
141 
142  // ~~~~ Send inference request
143 
144  nic::InferResult* results;
145 
146  err = triton_client->Infer(&results, triton_options, triton_inputs);
147  if (!err.IsOk()) {
148  throw cet::exception("PointIdAlgTriton")
149  << "failed sending Triton synchronous infer request: " << err << std::endl;
150  }
151  std::shared_ptr<nic::InferResult> results_ptr;
152  results_ptr.reset(results);
153 
154  // ~~~~ Retrieve inference results
155 
156  std::vector<float> out;
157 
158  const float* prb0;
159  size_t rbuff0_byte_size; // size of result buffer in bytes
160  results_ptr->RawData(
161  triton_modmet.outputs(0).name(), (const uint8_t**)&prb0, &rbuff0_byte_size);
162  size_t ncat0 = rbuff0_byte_size / sizeof(float);
163 
164  const float* prb1;
165  size_t rbuff1_byte_size; // size of result buffer in bytes
166  results_ptr->RawData(
167  triton_modmet.outputs(1).name(), (const uint8_t**)&prb1, &rbuff1_byte_size);
168  size_t ncat1 = rbuff1_byte_size / sizeof(float);
169 
170  for (unsigned j = 0; j < ncat0; j++)
171  out.push_back(*(prb0 + j));
172  for (unsigned j = 0; j < ncat1; j++)
173  out.push_back(*(prb1 + j));
174 
175  return out;
176  }
inference::ModelMetadataResponse triton_modmet
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
Int_t ncols
Definition: plot.C:52
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
TFile fa("Li7.root")
std::vector< std::vector< float > > PointIdAlgTools::PointIdAlgTriton::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
) const
override

Definition at line 179 of file PointIdAlgTriton_tool.cc.

References util::begin(), DEFINE_ART_CLASS_TOOL, util::end(), fa(), ncols, triton_client, triton_inpshape, triton_modmet, and triton_options.

182  {
183  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
184  return std::vector<std::vector<float>>();
185  }
186 
187  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
188 
189  size_t usamples = samples;
190  size_t nrows = inps.front().size(), ncols = inps.front().front().size();
191 
192  triton_inpshape.at(0) = usamples; // set batch size
193 
194  // ~~~~ Initialize the inputs
195 
196  nic::InferInput* triton_input;
197  auto err = nic::InferInput::Create(&triton_input,
198  triton_modmet.inputs(0).name(),
200  triton_modmet.inputs(0).datatype());
201  if (!err.IsOk()) {
202  throw cet::exception("PointIdAlgTriton") << "unable to get input: " << err << std::endl;
203  }
204  std::shared_ptr<nic::InferInput> triton_input_ptr(triton_input);
205  std::vector<nic::InferInput*> triton_inputs = {triton_input_ptr.get()};
206 
207  // ~~~~ For each sample, register the mem address of 1st byte of image and #bytes in image
208  err = triton_input_ptr->Reset();
209  if (!err.IsOk()) {
210  throw cet::exception("PointIdAlgTriton")
211  << "failed resetting Triton model input: " << err << std::endl;
212  }
213 
214  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
215  std::vector<std::vector<float>> fa(usamples, std::vector<float>(sbuff_byte_size));
216 
217  for (size_t idx = 0; idx < usamples; ++idx) {
218  // ..first flatten the 2d array into contiguous 1d block
219  for (size_t ir = 0; ir < nrows; ++ir) {
220  std::copy(inps[idx][ir].begin(), inps[idx][ir].end(), fa[idx].begin() + (ir * ncols));
221  }
222  err =
223  triton_input_ptr->AppendRaw(reinterpret_cast<uint8_t*>(fa[idx].data()), sbuff_byte_size);
224  if (!err.IsOk()) {
225  throw cet::exception("PointIdAlgTriton")
226  << "failed setting Triton input: " << err << std::endl;
227  }
228  }
229 
230  // ~~~~ Send inference request
231 
232  nic::InferResult* results;
233 
234  err = triton_client->Infer(&results, triton_options, triton_inputs);
235  if (!err.IsOk()) {
236  throw cet::exception("PointIdAlgTriton")
237  << "failed sending Triton synchronous infer request: " << err << std::endl;
238  }
239  std::shared_ptr<nic::InferResult> results_ptr;
240  results_ptr.reset(results);
241 
242  // ~~~~ Retrieve inference results
243 
244  std::vector<std::vector<float>> out;
245 
246  const float* prb0;
247  size_t rbuff0_byte_size; // size of result buffer in bytes
248  results_ptr->RawData(
249  triton_modmet.outputs(0).name(), (const uint8_t**)&prb0, &rbuff0_byte_size);
250  size_t ncat0 = rbuff0_byte_size / (usamples * sizeof(float));
251 
252  const float* prb1;
253  size_t rbuff1_byte_size; // size of result buffer in bytes
254  results_ptr->RawData(
255  triton_modmet.outputs(1).name(), (const uint8_t**)&prb1, &rbuff1_byte_size);
256  size_t ncat1 = rbuff1_byte_size / (usamples * sizeof(float));
257 
258  for (unsigned i = 0; i < usamples; i++) {
259  std::vector<float> vprb;
260  for (unsigned j = 0; j < ncat0; j++)
261  vprb.push_back(*(prb0 + i * ncat0 + j));
262  for (unsigned j = 0; j < ncat1; j++)
263  vprb.push_back(*(prb1 + i * ncat1 + j));
264  out.push_back(vprb);
265  }
266 
267  return out;
268  }
inference::ModelMetadataResponse triton_modmet
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
Int_t ncols
Definition: plot.C:52
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
TFile fa("Li7.root")

Member Data Documentation

std::string PointIdAlgTools::PointIdAlgTriton::fTritonModelName
private

Definition at line 33 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton().

std::string PointIdAlgTools::PointIdAlgTriton::fTritonModelVersion
private

Definition at line 36 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton().

std::string PointIdAlgTools::PointIdAlgTriton::fTritonURL
private

Definition at line 34 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton().

bool PointIdAlgTools::PointIdAlgTriton::fTritonVerbose
private

Definition at line 35 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton().

std::unique_ptr<nic::InferenceServerGrpcClient> PointIdAlgTools::PointIdAlgTriton::triton_client
private

Definition at line 38 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton(), and Run().

std::vector<int64_t> PointIdAlgTools::PointIdAlgTriton::triton_inpshape
mutable private

Definition at line 41 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton(), and Run().

inference::ModelConfigResponse PointIdAlgTools::PointIdAlgTriton::triton_modcfg
private

Definition at line 40 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton().

inference::ModelMetadataResponse PointIdAlgTools::PointIdAlgTriton::triton_modmet
private

Definition at line 39 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton(), and Run().

nic::InferOptions PointIdAlgTools::PointIdAlgTriton::triton_options
private

Definition at line 42 of file PointIdAlgTriton_tool.cc.

Referenced by PointIdAlgTriton(), and Run().


The documentation for this class was generated from the following file:
PointIdAlgTriton_tool.cc