LArSoft  v10_06_00
Liquid Argon Software toolkit - https://larsoft.org/
NuGraphInferenceTriton_module.cc
Go to the documentation of this file.
1 // Class: NuGraphInferenceTriton
3 // Plugin Type: producer (Unknown Unknown)
4 // File: NuGraphInferenceTriton_module.cc
5 //
6 // Generated at Tue Nov 14 14:41:30 2023 by Giuseppe Cerati using cetskelgen
7 // from version .
9 
15 #include "fhiclcpp/ParameterSet.h"
16 #include "fhiclcpp/types/Table.h"
18 
19 #include <array>
20 #include <limits>
21 #include <memory>
22 
28 
29 #include <chrono>
30 #include <fstream>
31 #include <iostream>
32 #include <sstream>
33 #include <string>
34 #include <vector>
35 
36 #include "grpc_client.h"
37 
39 
42 using recob::Hit;
43 using std::array;
44 using std::vector;
45 
// Evaluates the Triton client call X exactly once. If the returned tc::Error is not OK,
// prints "error: <MSG>: <error>" to std::cerr and terminates the process with exit(1).
//
// The do/while(0) wrapper makes the expansion a single statement, so the macro can be
// used in an unbraced if/else. The status variable has a macro-specific name so that a
// MSG expression that mentions a caller variable called `err` is not captured by it.
//
// NOTE(review): exit(1) ends the job without unwinding the stack; consider throwing an
// exception instead.
#define FAIL_IF_ERR(X, MSG)                                                              \
  do {                                                                                   \
    tc::Error fail_if_err_status_ = (X);                                                 \
    if (!fail_if_err_status_.IsOk()) {                                                   \
      std::cerr << "error: " << (MSG) << ": " << fail_if_err_status_ << std::endl;       \
      exit(1);                                                                           \
    }                                                                                    \
  } while (0)
54 namespace tc = triton::client;
55 
// Producer module: loads hits through a loader tool, sends the graph inputs to a Triton
// inference server over gRPC, and hands the returned outputs to the decoder tools, which
// write them to the event.
// NOTE(review): the class header, the constructor declaration and the members that follow
// the "should not be copied" comment are not visible in this listing.
57 public:
59 
60  // Plugins should not be copied or assigned.
65 
66  // Required functions.
    // Runs the load -> inference -> decode sequence for one event.
67  void produce(art::Event& e) override;
68 
69 private:
    // Events whose hit list is smaller than this skip inference; the decoder tools write
    // empty outputs for them instead.
70  size_t minHits;
    // Enables the "Hits size" printout in produce(); also forwarded to the tools.
71  bool debug;
    // Plane labels read from the "planes" parameter and forwarded to the tools.
72  vector<std::string> planes;
    // Server address, read from TritonConfig.serverURL.
73  std::string inference_url;
    // Model name, read from TritonConfig.modelName and used to build the inference options.
74  std::string inference_model_name;
    // Model version, read from TritonConfig.modelVersion (empty string when not configured).
75  std::string model_version;
    // NOTE(review): the constructor and produce() also use a bool `inference_ssl`
    // (TritonConfig.ssl); its declaration is not visible in this listing - confirm it exists.
    // SSL material copied into tc::SslOptions when SSL is enabled; each defaults to "".
77  std::string ssl_root_certificates;
78  std::string ssl_private_key;
79  std::string ssl_certificate_chain;
    // Passed to the gRPC client factory as its verbose flag.
80  bool verbose;
    // Assigned to the inference options' client_timeout_ field; defaults to 0.
81  uint32_t client_timeout;
82 
83  // loader tool
84  std::unique_ptr<LoaderToolBase> _loaderTool;
85  // decoder tools
86  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
87 };
88 
90  : EDProducer{p}
91  , minHits(p.get<size_t>("minHits"))
92  , debug(p.get<bool>("debug"))
93  , planes(p.get<vector<std::string>>("planes"))
94 {
95 
96  fhicl::ParameterSet tritonPset = p.get<fhicl::ParameterSet>("TritonConfig");
97  inference_url = tritonPset.get<std::string>("serverURL");
98  inference_model_name = tritonPset.get<std::string>("modelName");
99  inference_ssl = tritonPset.get<bool>("ssl");
100  ssl_root_certificates = tritonPset.get<std::string>("sslRootCertificates", "");
101  ssl_private_key = tritonPset.get<std::string>("sslPrivateKey", "");
102  ssl_certificate_chain = tritonPset.get<std::string>("sslCertificateChain", "");
103  verbose = tritonPset.get<bool>("verbose", "false");
104  model_version = tritonPset.get<std::string>("modelVersion", "");
105  client_timeout = tritonPset.get<unsigned>("timeout", 0);
106 
107  // Loader Tool
108  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
109  _loaderTool->setDebugAndPlanes(debug, planes);
110 
111  // configure and construct Decoder Tools
112  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
113  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
114  std::cout << "decoder lablel: " << tool_pset_labels << std::endl;
115  auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
116  _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
117  _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
118  _decoderToolsVec.back()->declareProducts(producesCollector());
119  }
120 }
121 
123 {
124 
125  //
126  // Load the data and fill the graph inputs
127  //
128  vector<art::Ptr<Hit>> hitlist;
129  vector<vector<size_t>> idsmap;
130  vector<NuGraphInput> graphinputs;
131  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
132 
133  if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
134  if (hitlist.size() < minHits) {
135  // Writing the empty outputs to the output root file
136  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
137  _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
138  }
139  return;
140  }
141 
142  //
143  // Triton-specific section
144  //
145  const vector<int32_t>* hit_table_hit_id_data = nullptr;
146  const vector<int32_t>* hit_table_local_plane_data = nullptr;
147  const vector<float>* hit_table_local_time_data = nullptr;
148  const vector<int32_t>* hit_table_local_wire_data = nullptr;
149  const vector<float>* hit_table_integral_data = nullptr;
150  const vector<float>* hit_table_rms_data = nullptr;
151  const vector<int32_t>* spacepoint_table_spacepoint_id_data = nullptr;
152  const vector<int32_t>* spacepoint_table_hit_id_u_data = nullptr;
153  const vector<int32_t>* spacepoint_table_hit_id_v_data = nullptr;
154  const vector<int32_t>* spacepoint_table_hit_id_y_data = nullptr;
155  for (const auto& gi : graphinputs) {
156  if (gi.input_name == "hit_table_hit_id")
157  hit_table_hit_id_data = &gi.input_int32_vec;
158  else if (gi.input_name == "hit_table_local_plane")
159  hit_table_local_plane_data = &gi.input_int32_vec;
160  else if (gi.input_name == "hit_table_local_time")
161  hit_table_local_time_data = &gi.input_float_vec;
162  else if (gi.input_name == "hit_table_local_wire")
163  hit_table_local_wire_data = &gi.input_int32_vec;
164  else if (gi.input_name == "hit_table_integral")
165  hit_table_integral_data = &gi.input_float_vec;
166  else if (gi.input_name == "hit_table_rms")
167  hit_table_rms_data = &gi.input_float_vec;
168  else if (gi.input_name == "spacepoint_table_spacepoint_id")
169  spacepoint_table_spacepoint_id_data = &gi.input_int32_vec;
170  else if (gi.input_name == "spacepoint_table_hit_id_u")
171  spacepoint_table_hit_id_u_data = &gi.input_int32_vec;
172  else if (gi.input_name == "spacepoint_table_hit_id_v")
173  spacepoint_table_hit_id_v_data = &gi.input_int32_vec;
174  else if (gi.input_name == "spacepoint_table_hit_id_y")
175  spacepoint_table_hit_id_y_data = &gi.input_int32_vec;
176  }
177 
178  //Here the input should be sent to Triton
179  tc::Headers http_headers;
180  grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
181  bool test_use_cached_channel = false;
182  bool use_cached_channel = true;
183 
184  // Create a InferenceServerGrpcClient instance to communicate with the
185  // server using gRPC protocol.
186  std::unique_ptr<tc::InferenceServerGrpcClient> client;
187  tc::SslOptions ssl_options = tc::SslOptions();
188  std::string err;
189  if (inference_ssl) {
190  ssl_options.root_certificates = ssl_root_certificates;
191  ssl_options.private_key = ssl_private_key;
192  ssl_options.certificate_chain = ssl_certificate_chain;
193  err = "unable to create secure grpc client";
194  }
195  else {
196  err = "unable to create grpc client";
197  }
198  // Run with the same name to ensure cached channel is not used
199  int numRuns = test_use_cached_channel ? 2 : 1;
200  for (int i = 0; i < numRuns; ++i) {
201  FAIL_IF_ERR(tc::InferenceServerGrpcClient::Create(&client,
203  verbose,
205  ssl_options,
206  tc::KeepAliveOptions(),
207  use_cached_channel),
208  err);
209 
210  std::vector<int64_t> hit_table_shape{int64_t(hit_table_hit_id_data->size())};
211  std::vector<int64_t> spacepoint_table_shape{
212  int64_t(spacepoint_table_spacepoint_id_data->size())};
213 
214  // Initialize the inputs with the data.
215  tc::InferInput* hit_table_hit_id;
216  tc::InferInput* hit_table_local_plane;
217  tc::InferInput* hit_table_local_time;
218  tc::InferInput* hit_table_local_wire;
219  tc::InferInput* hit_table_integral;
220  tc::InferInput* hit_table_rms;
221 
222  tc::InferInput* spacepoint_table_spacepoint_id;
223  tc::InferInput* spacepoint_table_hit_id_u;
224  tc::InferInput* spacepoint_table_hit_id_v;
225  tc::InferInput* spacepoint_table_hit_id_y;
226 
227  FAIL_IF_ERR(
228  tc::InferInput::Create(&hit_table_hit_id, "hit_table_hit_id", hit_table_shape, "INT32"),
229  "unable to get hit_table_hit_id");
230  std::shared_ptr<tc::InferInput> hit_table_hit_id_ptr;
231  hit_table_hit_id_ptr.reset(hit_table_hit_id);
232 
233  FAIL_IF_ERR(tc::InferInput::Create(
234  &hit_table_local_plane, "hit_table_local_plane", hit_table_shape, "INT32"),
235  "unable to get hit_table_local_plane");
236  std::shared_ptr<tc::InferInput> hit_table_local_plane_ptr;
237  hit_table_local_plane_ptr.reset(hit_table_local_plane);
238 
239  FAIL_IF_ERR(tc::InferInput::Create(
240  &hit_table_local_time, "hit_table_local_time", hit_table_shape, "FP32"),
241  "unable to get hit_table_local_time");
242  std::shared_ptr<tc::InferInput> hit_table_local_time_ptr;
243  hit_table_local_time_ptr.reset(hit_table_local_time);
244 
245  FAIL_IF_ERR(tc::InferInput::Create(
246  &hit_table_local_wire, "hit_table_local_wire", hit_table_shape, "INT32"),
247  "unable to get hit_table_local_wire");
248  std::shared_ptr<tc::InferInput> hit_table_local_wire_ptr;
249  hit_table_local_wire_ptr.reset(hit_table_local_wire);
250 
251  FAIL_IF_ERR(
252  tc::InferInput::Create(&hit_table_integral, "hit_table_integral", hit_table_shape, "FP32"),
253  "unable to get hit_table_integral");
254  std::shared_ptr<tc::InferInput> hit_table_integral_ptr;
255  hit_table_integral_ptr.reset(hit_table_integral);
256 
257  FAIL_IF_ERR(tc::InferInput::Create(&hit_table_rms, "hit_table_rms", hit_table_shape, "FP32"),
258  "unable to get hit_table_rms");
259  std::shared_ptr<tc::InferInput> hit_table_rms_ptr;
260  hit_table_rms_ptr.reset(hit_table_rms);
261 
262  FAIL_IF_ERR(tc::InferInput::Create(&spacepoint_table_spacepoint_id,
263  "spacepoint_table_spacepoint_id",
264  spacepoint_table_shape,
265  "INT32"),
266  "unable to get spacepoint_table_spacepoint_id");
267  std::shared_ptr<tc::InferInput> spacepoint_table_spacepoint_id_ptr;
268  spacepoint_table_spacepoint_id_ptr.reset(spacepoint_table_spacepoint_id);
269 
270  FAIL_IF_ERR(
271  tc::InferInput::Create(
272  &spacepoint_table_hit_id_u, "spacepoint_table_hit_id_u", spacepoint_table_shape, "INT32"),
273  "unable to get spacepoint_table_spacepoint_hit_id_u");
274  std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_u_ptr;
275  spacepoint_table_hit_id_u_ptr.reset(spacepoint_table_hit_id_u);
276 
277  FAIL_IF_ERR(
278  tc::InferInput::Create(
279  &spacepoint_table_hit_id_v, "spacepoint_table_hit_id_v", spacepoint_table_shape, "INT32"),
280  "unable to get spacepoint_table_spacepoint_hit_id_v");
281  std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_v_ptr;
282  spacepoint_table_hit_id_v_ptr.reset(spacepoint_table_hit_id_v);
283 
284  FAIL_IF_ERR(
285  tc::InferInput::Create(
286  &spacepoint_table_hit_id_y, "spacepoint_table_hit_id_y", spacepoint_table_shape, "INT32"),
287  "unable to get spacepoint_table_spacepoint_hit_id_y");
288  std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_y_ptr;
289  spacepoint_table_hit_id_y_ptr.reset(spacepoint_table_hit_id_y);
290 
291  FAIL_IF_ERR(hit_table_hit_id_ptr->AppendRaw(
292  reinterpret_cast<const uint8_t*>(hit_table_hit_id_data->data()),
293  hit_table_hit_id_data->size() * sizeof(int32_t)),
294  "unable to set data for hit_table_hit_id");
295 
296  FAIL_IF_ERR(hit_table_local_plane_ptr->AppendRaw(
297  reinterpret_cast<const uint8_t*>(hit_table_local_plane_data->data()),
298  hit_table_local_plane_data->size() * sizeof(int32_t)),
299  "unable to set data for hit_table_local_plane");
300 
301  FAIL_IF_ERR(hit_table_local_time_ptr->AppendRaw(
302  reinterpret_cast<const uint8_t*>(hit_table_local_time_data->data()),
303  hit_table_local_time_data->size() * sizeof(float)),
304  "unable to set data for hit_table_local_time");
305 
306  FAIL_IF_ERR(hit_table_local_wire_ptr->AppendRaw(
307  reinterpret_cast<const uint8_t*>(hit_table_local_wire_data->data()),
308  hit_table_local_wire_data->size() * sizeof(int32_t)),
309  "unable to set data for hit_table_local_wire");
310 
311  FAIL_IF_ERR(hit_table_integral_ptr->AppendRaw(
312  reinterpret_cast<const uint8_t*>(hit_table_integral_data->data()),
313  hit_table_integral_data->size() * sizeof(float)),
314  "unable to set data for hit_table_integral");
315 
316  FAIL_IF_ERR(
317  hit_table_rms_ptr->AppendRaw(reinterpret_cast<const uint8_t*>(hit_table_rms_data->data()),
318  hit_table_rms_data->size() * sizeof(float)),
319  "unable to set data for hit_table_rms");
320 
321  FAIL_IF_ERR(spacepoint_table_spacepoint_id_ptr->AppendRaw(
322  reinterpret_cast<const uint8_t*>(spacepoint_table_spacepoint_id_data->data()),
323  spacepoint_table_spacepoint_id_data->size() * sizeof(int32_t)),
324  "unable to set data for spacepoint_table_spacepoint_id");
325 
326  FAIL_IF_ERR(spacepoint_table_hit_id_u_ptr->AppendRaw(
327  reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_u_data->data()),
328  spacepoint_table_hit_id_u_data->size() * sizeof(int32_t)),
329  "unable to set data for spacepoint_table_hit_id_u");
330 
331  FAIL_IF_ERR(spacepoint_table_hit_id_v_ptr->AppendRaw(
332  reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_v_data->data()),
333  spacepoint_table_hit_id_v_data->size() * sizeof(int32_t)),
334  "unable to set data for spacepoint_table_hit_id_v");
335 
336  FAIL_IF_ERR(spacepoint_table_hit_id_y_ptr->AppendRaw(
337  reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_y_data->data()),
338  spacepoint_table_hit_id_y_data->size() * sizeof(int32_t)),
339  "unable to set data for spacepoint_table_hit_id_y");
340 
341  // Generate the outputs to be requested.
342  tc::InferRequestedOutput* x_semantic_u;
343  tc::InferRequestedOutput* x_semantic_v;
344  tc::InferRequestedOutput* x_semantic_y;
345  tc::InferRequestedOutput* x_filter_u;
346  tc::InferRequestedOutput* x_filter_v;
347  tc::InferRequestedOutput* x_filter_y;
348 
349  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_u, "x_semantic_u"),
350  "unable to get 'x_semantic_u'");
351  std::shared_ptr<tc::InferRequestedOutput> x_semantic_u_ptr;
352  x_semantic_u_ptr.reset(x_semantic_u);
353 
354  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_v, "x_semantic_v"),
355  "unable to get 'x_semantic_v'");
356  std::shared_ptr<tc::InferRequestedOutput> x_semantic_v_ptr;
357  x_semantic_v_ptr.reset(x_semantic_v);
358 
359  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_y, "x_semantic_y"),
360  "unable to get 'x_semantic_y'");
361  std::shared_ptr<tc::InferRequestedOutput> x_semantic_y_ptr;
362  x_semantic_y_ptr.reset(x_semantic_y);
363 
364  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_u, "x_filter_u"),
365  "unable to get 'x_filter_u'");
366  std::shared_ptr<tc::InferRequestedOutput> x_filter_u_ptr;
367  x_filter_u_ptr.reset(x_filter_u);
368 
369  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_v, "x_filter_v"),
370  "unable to get 'x_filter_v'");
371  std::shared_ptr<tc::InferRequestedOutput> x_filter_v_ptr;
372  x_filter_v_ptr.reset(x_filter_v);
373 
374  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_y, "x_filter_y"),
375  "unable to get 'x_filter_y'");
376  std::shared_ptr<tc::InferRequestedOutput> x_filter_y_ptr;
377  x_filter_y_ptr.reset(x_filter_y);
378 
379  // The inference settings. Will be using default for now.
380  tc::InferOptions options(inference_model_name);
381  options.model_version_ = model_version;
382  options.client_timeout_ = client_timeout;
383 
384  std::vector<tc::InferInput*> inputs = {hit_table_hit_id_ptr.get(),
385  hit_table_local_plane_ptr.get(),
386  hit_table_local_time_ptr.get(),
387  hit_table_local_wire_ptr.get(),
388  hit_table_integral_ptr.get(),
389  hit_table_rms_ptr.get(),
390  spacepoint_table_spacepoint_id_ptr.get(),
391  spacepoint_table_hit_id_u_ptr.get(),
392  spacepoint_table_hit_id_v_ptr.get(),
393  spacepoint_table_hit_id_y_ptr.get()};
394 
395  std::vector<const tc::InferRequestedOutput*> outputs = {x_semantic_u_ptr.get(),
396  x_semantic_v_ptr.get(),
397  x_semantic_y_ptr.get(),
398  x_filter_u_ptr.get(),
399  x_filter_v_ptr.get(),
400  x_filter_y_ptr.get()};
401 
402  tc::InferResult* results;
403  auto start = std::chrono::high_resolution_clock::now();
404  FAIL_IF_ERR(
405  client->Infer(&results, options, inputs, outputs, http_headers, compression_algorithm),
406  "unable to run model");
407  auto end = std::chrono::high_resolution_clock::now();
408  std::chrono::duration<double> elapsed = end - start;
409  std::cout << "Time taken for inference: " << elapsed.count() << " seconds" << std::endl;
410  std::shared_ptr<tc::InferResult> results_ptr;
411  results_ptr.reset(results);
412 
413  //
414  // Get pointers to the result returned and write to the event
415  //
416  vector<NuGraphOutput> infer_output;
417  vector<string> outnames = {
418  "x_semantic_u", "x_semantic_v", "x_semantic_y", "x_filter_u", "x_filter_v", "x_filter_y"};
419  for (const auto& name : outnames) {
420  const float* _data;
421  size_t _byte_size;
422  FAIL_IF_ERR(results_ptr->RawData(name, (const uint8_t**)&_data, &_byte_size),
423  "unable to get result data for " + name);
424  size_t n_elements = _byte_size / sizeof(float);
425  std::vector<float> out_data(_data, _data + n_elements);
426  infer_output.push_back(NuGraphOutput(name, out_data));
427  }
428 
429  // Write the outputs
430  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
431  _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
432  }
433  }
434 }
Declaration of signal hit object.
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.cc:6
#define FAIL_IF_ERR(X, MSG)
NuGraphInferenceTriton & operator=(NuGraphInferenceTriton const &)=delete
NuGraphInferenceTriton(fhicl::ParameterSet const &p)
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
auto array(Array const &a)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:250
#define DEFINE_ART_MODULE(klass)
Definition: ModuleMacros.h:65
T get(std::string const &key) const
Definition: ParameterSet.h:314
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
std::unique_ptr< LoaderToolBase > _loaderTool
ProducesCollector & producesCollector() noexcept
2D representation of charge deposited in the TDC/wire plane
Definition: Hit.h:46
void produce(art::Event &e) override
Float_t e
Definition: plot.C:35