LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
NuGraphInferenceTriton_module.cc
Go to the documentation of this file.
1 // Class: NuGraphInferenceTriton
3 // Plugin Type: producer (Unknown Unknown)
4 // File: NuGraphInferenceTriton_module.cc
5 //
6 // Generated at Tue Nov 14 14:41:30 2023 by Giuseppe Cerati using cetskelgen
7 // from version .
9 
14 #include "fhiclcpp/ParameterSet.h"
15 #include "fhiclcpp/types/Table.h"
17 
18 #include <array>
19 #include <limits>
20 #include <memory>
21 
27 
28 #include <chrono>
29 #include <fstream>
30 #include <iostream>
31 #include <sstream>
32 #include <string>
33 #include <vector>
34 
35 #include "grpc_client.h"
36 
38 
41 using recob::Hit;
42 using std::array;
43 using std::vector;
44 
// Evaluate the Triton client call X exactly once. If the returned tc::Error is
// not OK, print "error: <MSG>: <error>" to std::cerr and terminate the process
// with exit status 1.
//
// The status is stored under a deliberately unusual name: MSG is expanded inside
// this scope, so a common name such as `err` would capture a caller variable of
// the same name and print the Triton error instead of the caller's message.
// The do { } while (0) wrapper makes the expansion a single statement, so the
// macro can be used in an unbraced if/else followed by a semicolon.
//
// NOTE(review): exit(1) ends the whole job without unwinding; consider throwing
// an exception instead if the framework should handle the failure.
#define FAIL_IF_ERR(X, MSG)                                                         \
  do {                                                                              \
    tc::Error fail_if_err_status_ = (X);                                            \
    if (!fail_if_err_status_.IsOk()) {                                              \
      std::cerr << "error: " << (MSG) << ": " << fail_if_err_status_ << std::endl;  \
      exit(1);                                                                      \
    }                                                                               \
  } while (0)
53 namespace tc = triton::client;
54 
56 public:
58 
59  // Plugins should not be copied or assigned.
64 
65  // Required functions.
 // Per-event entry point: loads the graph inputs through the loader tool, runs the
 // model on a Triton server over gRPC and hands the outputs to the decoder tools.
66  void produce(art::Event& e) override;
67 
68 private:
69  size_t minHits; // "minHits": events with fewer hits skip inference and get empty outputs
70  bool debug; // "debug": enables the hit-count printout and is forwarded to the tools
71  vector<std::string> planes; // "planes": forwarded to the loader and decoder tools
72  std::string inference_url; // TritonConfig.serverURL
73  std::string inference_model_name; // TritonConfig.modelName, used to build the inference options
74  std::string model_version; // TritonConfig.modelVersion (default: empty string)
 // NOTE(review): the constructor also assigns `inference_ssl` (TritonConfig.ssl), but its
 // declaration is not visible in this listing -- confirm it is declared here.
76  std::string ssl_root_certificates; // TritonConfig.sslRootCertificates (default: empty); copied to the SSL options when ssl is on
77  std::string ssl_private_key; // TritonConfig.sslPrivateKey (default: empty); copied to the SSL options when ssl is on
78  std::string ssl_certificate_chain; // TritonConfig.sslCertificateChain (default: empty); copied to the SSL options when ssl is on
79  bool verbose; // TritonConfig.verbose, passed to the gRPC client on creation
80  uint32_t client_timeout; // TritonConfig.timeout (default 0), stored in the inference options
81 
82  // loader tool: fills the hit list, the graph inputs and the ids map for each event
83  std::unique_ptr<LoaderToolBase> _loaderTool;
84  // decoder tools: one per parameter set in "DecoderTools"; each writes its products to the event
85  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
86 };
87 
89  : EDProducer{p}
90  , minHits(p.get<size_t>("minHits"))
91  , debug(p.get<bool>("debug"))
92  , planes(p.get<vector<std::string>>("planes"))
93 {
94 
95  fhicl::ParameterSet tritonPset = p.get<fhicl::ParameterSet>("TritonConfig");
96  inference_url = tritonPset.get<std::string>("serverURL");
97  inference_model_name = tritonPset.get<std::string>("modelName");
98  inference_ssl = tritonPset.get<bool>("ssl");
99  ssl_root_certificates = tritonPset.get<std::string>("sslRootCertificates", "");
100  ssl_private_key = tritonPset.get<std::string>("sslPrivateKey", "");
101  ssl_certificate_chain = tritonPset.get<std::string>("sslCertificateChain", "");
102  verbose = tritonPset.get<bool>("verbose", "false");
103  model_version = tritonPset.get<std::string>("modelVersion", "");
104  client_timeout = tritonPset.get<unsigned>("timeout", 0);
105 
106  // Loader Tool
107  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
108  _loaderTool->setDebugAndPlanes(debug, planes);
109 
110  // configure and construct Decoder Tools
111  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
112  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
113  std::cout << "decoder lablel: " << tool_pset_labels << std::endl;
114  auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
115  _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
116  _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
117  _decoderToolsVec.back()->declareProducts(producesCollector());
118  }
119 }
120 
122 {
123 
124  //
125  // Load the data and fill the graph inputs
126  //
127  vector<art::Ptr<Hit>> hitlist;
128  vector<vector<size_t>> idsmap;
129  vector<NuGraphInput> graphinputs;
130  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
131 
132  if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
133  if (hitlist.size() < minHits) {
134  // Writing the empty outputs to the output root file
135  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
136  _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
137  }
138  return;
139  }
140 
141  //
142  // Triton-specific section
143  //
144  const vector<int32_t>* hit_table_hit_id_data = nullptr;
145  const vector<int32_t>* hit_table_local_plane_data = nullptr;
146  const vector<float>* hit_table_local_time_data = nullptr;
147  const vector<int32_t>* hit_table_local_wire_data = nullptr;
148  const vector<float>* hit_table_integral_data = nullptr;
149  const vector<float>* hit_table_rms_data = nullptr;
150  const vector<int32_t>* spacepoint_table_spacepoint_id_data = nullptr;
151  const vector<int32_t>* spacepoint_table_hit_id_u_data = nullptr;
152  const vector<int32_t>* spacepoint_table_hit_id_v_data = nullptr;
153  const vector<int32_t>* spacepoint_table_hit_id_y_data = nullptr;
154  for (const auto& gi : graphinputs) {
155  if (gi.input_name == "hit_table_hit_id")
156  hit_table_hit_id_data = &gi.input_int32_vec;
157  else if (gi.input_name == "hit_table_local_plane")
158  hit_table_local_plane_data = &gi.input_int32_vec;
159  else if (gi.input_name == "hit_table_local_time")
160  hit_table_local_time_data = &gi.input_float_vec;
161  else if (gi.input_name == "hit_table_local_wire")
162  hit_table_local_wire_data = &gi.input_int32_vec;
163  else if (gi.input_name == "hit_table_integral")
164  hit_table_integral_data = &gi.input_float_vec;
165  else if (gi.input_name == "hit_table_rms")
166  hit_table_rms_data = &gi.input_float_vec;
167  else if (gi.input_name == "spacepoint_table_spacepoint_id")
168  spacepoint_table_spacepoint_id_data = &gi.input_int32_vec;
169  else if (gi.input_name == "spacepoint_table_hit_id_u")
170  spacepoint_table_hit_id_u_data = &gi.input_int32_vec;
171  else if (gi.input_name == "spacepoint_table_hit_id_v")
172  spacepoint_table_hit_id_v_data = &gi.input_int32_vec;
173  else if (gi.input_name == "spacepoint_table_hit_id_y")
174  spacepoint_table_hit_id_y_data = &gi.input_int32_vec;
175  }
176 
177  //Here the input should be sent to Triton
178  tc::Headers http_headers;
179  grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
180  bool test_use_cached_channel = false;
181  bool use_cached_channel = true;
182 
183  // Create a InferenceServerGrpcClient instance to communicate with the
184  // server using gRPC protocol.
185  std::unique_ptr<tc::InferenceServerGrpcClient> client;
186  tc::SslOptions ssl_options = tc::SslOptions();
187  std::string err;
188  if (inference_ssl) {
189  ssl_options.root_certificates = ssl_root_certificates;
190  ssl_options.private_key = ssl_private_key;
191  ssl_options.certificate_chain = ssl_certificate_chain;
192  err = "unable to create secure grpc client";
193  }
194  else {
195  err = "unable to create grpc client";
196  }
197  // Run with the same name to ensure cached channel is not used
198  int numRuns = test_use_cached_channel ? 2 : 1;
199  for (int i = 0; i < numRuns; ++i) {
200  FAIL_IF_ERR(tc::InferenceServerGrpcClient::Create(&client,
202  verbose,
204  ssl_options,
205  tc::KeepAliveOptions(),
206  use_cached_channel),
207  err);
208 
209  std::vector<int64_t> hit_table_shape{int64_t(hit_table_hit_id_data->size())};
210  std::vector<int64_t> spacepoint_table_shape{
211  int64_t(spacepoint_table_spacepoint_id_data->size())};
212 
213  // Initialize the inputs with the data.
214  tc::InferInput* hit_table_hit_id;
215  tc::InferInput* hit_table_local_plane;
216  tc::InferInput* hit_table_local_time;
217  tc::InferInput* hit_table_local_wire;
218  tc::InferInput* hit_table_integral;
219  tc::InferInput* hit_table_rms;
220 
221  tc::InferInput* spacepoint_table_spacepoint_id;
222  tc::InferInput* spacepoint_table_hit_id_u;
223  tc::InferInput* spacepoint_table_hit_id_v;
224  tc::InferInput* spacepoint_table_hit_id_y;
225 
226  FAIL_IF_ERR(
227  tc::InferInput::Create(&hit_table_hit_id, "hit_table_hit_id", hit_table_shape, "INT32"),
228  "unable to get hit_table_hit_id");
229  std::shared_ptr<tc::InferInput> hit_table_hit_id_ptr;
230  hit_table_hit_id_ptr.reset(hit_table_hit_id);
231 
232  FAIL_IF_ERR(tc::InferInput::Create(
233  &hit_table_local_plane, "hit_table_local_plane", hit_table_shape, "INT32"),
234  "unable to get hit_table_local_plane");
235  std::shared_ptr<tc::InferInput> hit_table_local_plane_ptr;
236  hit_table_local_plane_ptr.reset(hit_table_local_plane);
237 
238  FAIL_IF_ERR(tc::InferInput::Create(
239  &hit_table_local_time, "hit_table_local_time", hit_table_shape, "FP32"),
240  "unable to get hit_table_local_time");
241  std::shared_ptr<tc::InferInput> hit_table_local_time_ptr;
242  hit_table_local_time_ptr.reset(hit_table_local_time);
243 
244  FAIL_IF_ERR(tc::InferInput::Create(
245  &hit_table_local_wire, "hit_table_local_wire", hit_table_shape, "INT32"),
246  "unable to get hit_table_local_wire");
247  std::shared_ptr<tc::InferInput> hit_table_local_wire_ptr;
248  hit_table_local_wire_ptr.reset(hit_table_local_wire);
249 
250  FAIL_IF_ERR(
251  tc::InferInput::Create(&hit_table_integral, "hit_table_integral", hit_table_shape, "FP32"),
252  "unable to get hit_table_integral");
253  std::shared_ptr<tc::InferInput> hit_table_integral_ptr;
254  hit_table_integral_ptr.reset(hit_table_integral);
255 
256  FAIL_IF_ERR(tc::InferInput::Create(&hit_table_rms, "hit_table_rms", hit_table_shape, "FP32"),
257  "unable to get hit_table_rms");
258  std::shared_ptr<tc::InferInput> hit_table_rms_ptr;
259  hit_table_rms_ptr.reset(hit_table_rms);
260 
261  FAIL_IF_ERR(tc::InferInput::Create(&spacepoint_table_spacepoint_id,
262  "spacepoint_table_spacepoint_id",
263  spacepoint_table_shape,
264  "INT32"),
265  "unable to get spacepoint_table_spacepoint_id");
266  std::shared_ptr<tc::InferInput> spacepoint_table_spacepoint_id_ptr;
267  spacepoint_table_spacepoint_id_ptr.reset(spacepoint_table_spacepoint_id);
268 
269  FAIL_IF_ERR(
270  tc::InferInput::Create(
271  &spacepoint_table_hit_id_u, "spacepoint_table_hit_id_u", spacepoint_table_shape, "INT32"),
272  "unable to get spacepoint_table_spacepoint_hit_id_u");
273  std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_u_ptr;
274  spacepoint_table_hit_id_u_ptr.reset(spacepoint_table_hit_id_u);
275 
276  FAIL_IF_ERR(
277  tc::InferInput::Create(
278  &spacepoint_table_hit_id_v, "spacepoint_table_hit_id_v", spacepoint_table_shape, "INT32"),
279  "unable to get spacepoint_table_spacepoint_hit_id_v");
280  std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_v_ptr;
281  spacepoint_table_hit_id_v_ptr.reset(spacepoint_table_hit_id_v);
282 
283  FAIL_IF_ERR(
284  tc::InferInput::Create(
285  &spacepoint_table_hit_id_y, "spacepoint_table_hit_id_y", spacepoint_table_shape, "INT32"),
286  "unable to get spacepoint_table_spacepoint_hit_id_y");
287  std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_y_ptr;
288  spacepoint_table_hit_id_y_ptr.reset(spacepoint_table_hit_id_y);
289 
290  FAIL_IF_ERR(hit_table_hit_id_ptr->AppendRaw(
291  reinterpret_cast<const uint8_t*>(hit_table_hit_id_data->data()),
292  hit_table_hit_id_data->size() * sizeof(int32_t)),
293  "unable to set data for hit_table_hit_id");
294 
295  FAIL_IF_ERR(hit_table_local_plane_ptr->AppendRaw(
296  reinterpret_cast<const uint8_t*>(hit_table_local_plane_data->data()),
297  hit_table_local_plane_data->size() * sizeof(int32_t)),
298  "unable to set data for hit_table_local_plane");
299 
300  FAIL_IF_ERR(hit_table_local_time_ptr->AppendRaw(
301  reinterpret_cast<const uint8_t*>(hit_table_local_time_data->data()),
302  hit_table_local_time_data->size() * sizeof(float)),
303  "unable to set data for hit_table_local_time");
304 
305  FAIL_IF_ERR(hit_table_local_wire_ptr->AppendRaw(
306  reinterpret_cast<const uint8_t*>(hit_table_local_wire_data->data()),
307  hit_table_local_wire_data->size() * sizeof(int32_t)),
308  "unable to set data for hit_table_local_wire");
309 
310  FAIL_IF_ERR(hit_table_integral_ptr->AppendRaw(
311  reinterpret_cast<const uint8_t*>(hit_table_integral_data->data()),
312  hit_table_integral_data->size() * sizeof(float)),
313  "unable to set data for hit_table_integral");
314 
315  FAIL_IF_ERR(
316  hit_table_rms_ptr->AppendRaw(reinterpret_cast<const uint8_t*>(hit_table_rms_data->data()),
317  hit_table_rms_data->size() * sizeof(float)),
318  "unable to set data for hit_table_rms");
319 
320  FAIL_IF_ERR(spacepoint_table_spacepoint_id_ptr->AppendRaw(
321  reinterpret_cast<const uint8_t*>(spacepoint_table_spacepoint_id_data->data()),
322  spacepoint_table_spacepoint_id_data->size() * sizeof(int32_t)),
323  "unable to set data for spacepoint_table_spacepoint_id");
324 
325  FAIL_IF_ERR(spacepoint_table_hit_id_u_ptr->AppendRaw(
326  reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_u_data->data()),
327  spacepoint_table_hit_id_u_data->size() * sizeof(int32_t)),
328  "unable to set data for spacepoint_table_hit_id_u");
329 
330  FAIL_IF_ERR(spacepoint_table_hit_id_v_ptr->AppendRaw(
331  reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_v_data->data()),
332  spacepoint_table_hit_id_v_data->size() * sizeof(int32_t)),
333  "unable to set data for spacepoint_table_hit_id_v");
334 
335  FAIL_IF_ERR(spacepoint_table_hit_id_y_ptr->AppendRaw(
336  reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_y_data->data()),
337  spacepoint_table_hit_id_y_data->size() * sizeof(int32_t)),
338  "unable to set data for spacepoint_table_hit_id_y");
339 
340  // Generate the outputs to be requested.
341  tc::InferRequestedOutput* x_semantic_u;
342  tc::InferRequestedOutput* x_semantic_v;
343  tc::InferRequestedOutput* x_semantic_y;
344  tc::InferRequestedOutput* x_filter_u;
345  tc::InferRequestedOutput* x_filter_v;
346  tc::InferRequestedOutput* x_filter_y;
347 
348  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_u, "x_semantic_u"),
349  "unable to get 'x_semantic_u'");
350  std::shared_ptr<tc::InferRequestedOutput> x_semantic_u_ptr;
351  x_semantic_u_ptr.reset(x_semantic_u);
352 
353  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_v, "x_semantic_v"),
354  "unable to get 'x_semantic_v'");
355  std::shared_ptr<tc::InferRequestedOutput> x_semantic_v_ptr;
356  x_semantic_v_ptr.reset(x_semantic_v);
357 
358  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_y, "x_semantic_y"),
359  "unable to get 'x_semantic_y'");
360  std::shared_ptr<tc::InferRequestedOutput> x_semantic_y_ptr;
361  x_semantic_y_ptr.reset(x_semantic_y);
362 
363  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_u, "x_filter_u"),
364  "unable to get 'x_filter_u'");
365  std::shared_ptr<tc::InferRequestedOutput> x_filter_u_ptr;
366  x_filter_u_ptr.reset(x_filter_u);
367 
368  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_v, "x_filter_v"),
369  "unable to get 'x_filter_v'");
370  std::shared_ptr<tc::InferRequestedOutput> x_filter_v_ptr;
371  x_filter_v_ptr.reset(x_filter_v);
372 
373  FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_y, "x_filter_y"),
374  "unable to get 'x_filter_y'");
375  std::shared_ptr<tc::InferRequestedOutput> x_filter_y_ptr;
376  x_filter_y_ptr.reset(x_filter_y);
377 
378  // The inference settings. Will be using default for now.
379  tc::InferOptions options(inference_model_name);
380  options.model_version_ = model_version;
381  options.client_timeout_ = client_timeout;
382 
383  std::vector<tc::InferInput*> inputs = {hit_table_hit_id_ptr.get(),
384  hit_table_local_plane_ptr.get(),
385  hit_table_local_time_ptr.get(),
386  hit_table_local_wire_ptr.get(),
387  hit_table_integral_ptr.get(),
388  hit_table_rms_ptr.get(),
389  spacepoint_table_spacepoint_id_ptr.get(),
390  spacepoint_table_hit_id_u_ptr.get(),
391  spacepoint_table_hit_id_v_ptr.get(),
392  spacepoint_table_hit_id_y_ptr.get()};
393 
394  std::vector<const tc::InferRequestedOutput*> outputs = {x_semantic_u_ptr.get(),
395  x_semantic_v_ptr.get(),
396  x_semantic_y_ptr.get(),
397  x_filter_u_ptr.get(),
398  x_filter_v_ptr.get(),
399  x_filter_y_ptr.get()};
400 
401  tc::InferResult* results;
402  auto start = std::chrono::high_resolution_clock::now();
403  FAIL_IF_ERR(
404  client->Infer(&results, options, inputs, outputs, http_headers, compression_algorithm),
405  "unable to run model");
406  auto end = std::chrono::high_resolution_clock::now();
407  std::chrono::duration<double> elapsed = end - start;
408  std::cout << "Time taken for inference: " << elapsed.count() << " seconds" << std::endl;
409  std::shared_ptr<tc::InferResult> results_ptr;
410  results_ptr.reset(results);
411 
412  //
413  // Get pointers to the result returned and write to the event
414  //
415  vector<NuGraphOutput> infer_output;
416  vector<string> outnames = {
417  "x_semantic_u", "x_semantic_v", "x_semantic_y", "x_filter_u", "x_filter_v", "x_filter_y"};
418  for (const auto& name : outnames) {
419  const float* _data;
420  size_t _byte_size;
421  FAIL_IF_ERR(results_ptr->RawData(name, (const uint8_t**)&_data, &_byte_size),
422  "unable to get result data for " + name);
423  size_t n_elements = _byte_size / sizeof(float);
424  std::vector<float> out_data(_data, _data + n_elements);
425  infer_output.push_back(NuGraphOutput(name, out_data));
426  }
427 
428  // Write the outputs
429  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
430  _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
431  }
432  }
433 }
Declaration of signal hit object.
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.cc:6
#define FAIL_IF_ERR(X, MSG)
NuGraphInferenceTriton & operator=(NuGraphInferenceTriton const &)=delete
NuGraphInferenceTriton(fhicl::ParameterSet const &p)
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
auto array(Array const &a)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:250
#define DEFINE_ART_MODULE(klass)
Definition: ModuleMacros.h:65
T get(std::string const &key) const
Definition: ParameterSet.h:314
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
std::unique_ptr< LoaderToolBase > _loaderTool
ProducesCollector & producesCollector() noexcept
2D representation of charge deposited in the TDC/wire plane
Definition: Hit.h:46
void produce(art::Event &e) override
Float_t e
Definition: plot.C:35