LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
NuGraphInferenceSonicTriton_module.cc
Go to the documentation of this file.
1 // Class: NuGraphInferenceSonicTriton
3 // Plugin Type: producer (Unknown Unknown)
4 // File: NuGraphInferenceSonicTriton_module.cc
5 //
6 // Generated at Tue Nov 14 14:41:30 2023 by Giuseppe Cerati using cetskelgen
7 // from version .
9 
14 #include "fhiclcpp/ParameterSet.h"
15 #include "fhiclcpp/types/Table.h"
17 
18 #include <limits>
19 #include <memory>
20 
26 
29 
30 #include <chrono>
31 #include <fstream>
32 #include <iostream>
33 #include <sstream>
34 #include <string>
35 #include <vector>
36 
38 
41 using recob::Hit;
42 using std::vector;
43 
45 public:
47 
48  // Plugins should not be copied or assigned.
53 
54  // Required functions.
55  void produce(art::Event& e) override;
56 
57 private:
58  size_t minHits;
59  bool debug;
60  vector<std::string> planes;
62  std::unique_ptr<lartriton::TritonClient> triton_client;
63 
64  // loader tool
65  std::unique_ptr<LoaderToolBase> _loaderTool;
66  // decoder tools
67  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
68 
69  template <class T>
71  vector<T>& vec,
72  size_t batchSize)
73  {
74  triton_input.setShape({static_cast<long int>(vec.size())});
75  triton_input.toServer(
76  std::make_shared<lartriton::TritonInput<T>>(lartriton::TritonInput<T>(batchSize, vec)));
77  }
78 };
79 
81  : EDProducer{p}
82  , minHits(p.get<size_t>("minHits"))
83  , debug(p.get<bool>("debug"))
84  , planes(p.get<vector<std::string>>("planes"))
85  , tritonPset(p.get<fhicl::ParameterSet>("TritonConfig"))
86 {
87 
88  // ... Create the Triton inference client
89  if (debug) std::cout << "TritonConfig: " << tritonPset.to_string() << std::endl;
90  triton_client = std::make_unique<lartriton::TritonClient>(tritonPset);
91 
92  // Loader Tool
93  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
94  _loaderTool->setDebugAndPlanes(debug, planes);
95 
96  // configure and construct Decoder Tools
97  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
98  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
99  std::cout << "decoder lablel: " << tool_pset_labels << std::endl;
100  auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
101  _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
102  _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
103  _decoderToolsVec.back()->declareProducts(producesCollector());
104  }
105 }
106 
108 {
109 
110  //
111  // Load the data and fill the graph inputs
112  //
113  vector<art::Ptr<Hit>> hitlist;
114  vector<vector<size_t>> idsmap;
115  vector<NuGraphInput> graphinputs;
116  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
117 
118  if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
119  if (hitlist.size() < minHits) {
120  // Writing the empty outputs to the output root file
121  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
122  _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
123  }
124  return;
125  }
126 
127  //
128  // NuSonic Triton Server section
129  //
130  auto start = std::chrono::high_resolution_clock::now();
131  //
132  //Here the input should be sent to Triton
133  triton_client->reset();
134  size_t batchSize = 1; //the code below assumes/has only been tested for batch size = 1
135  triton_client->setBatchSize(batchSize); // set batch size
136  //
137  auto& inputs = triton_client->input();
138  for (auto& input_pair : inputs) {
139  const std::string& key = input_pair.first;
140  auto& triton_input = input_pair.second;
141  //
142  for (auto& gi : graphinputs) {
143  if (key != gi.input_name) continue;
144  if (gi.isInt)
145  setShapeAndToServer(triton_input, gi.input_int32_vec, batchSize);
146  else
147  setShapeAndToServer(triton_input, gi.input_float_vec, batchSize);
148  }
149  }
150  // ~~~~ Send inference request
151  triton_client->dispatch();
152  // ~~~~ Retrieve inference results
153  auto& infer_result = triton_client->output();
154  auto end = std::chrono::high_resolution_clock::now();
155  std::chrono::duration<double> elapsed = end - start;
156  std::cout << "Time taken for inference: " << elapsed.count() << " seconds" << std::endl;
157 
158  //
159  // Get pointers to the result returned and write to the event
160  //
161  vector<NuGraphOutput> infer_output;
162  for (const auto& [name, data] : infer_result) {
163  const auto& prob = data.fromServer<float>();
164  std::vector<float> out_data(prob[0].begin(), prob[0].end());
165  infer_output.emplace_back(name, std::move(out_data));
166  }
167 
168  // Write the outputs to the output root file
169  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
170  _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
171  }
172 }
173 
bool setShape(const ShapeType &newShape)
Definition: TritonData.h:45
std::unique_ptr< LoaderToolBase > _loaderTool
Declaration of signal hit object.
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.cc:6
NuGraphInferenceSonicTriton & operator=(NuGraphInferenceSonicTriton const &)=delete
void toServer(std::shared_ptr< TritonInput< DT >> ptr)
Definition: TritonData.h:50
NuGraphInferenceSonicTriton(fhicl::ParameterSet const &p)
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
#define DEFINE_ART_MODULE(klass)
Definition: ModuleMacros.h:65
void setShapeAndToServer(lartriton::TritonData< triton::client::InferInput > &triton_input, vector< T > &vec, size_t batchSize)
std::vector< std::vector< DT >> TritonInput
Definition: TritonData.h:28
T get(std::string const &key) const
Definition: ParameterSet.h:314
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
ProducesCollector & producesCollector() noexcept
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
2D representation of charge deposited in the TDC/wire plane
Definition: Hit.h:46
Float_t e
Definition: plot.C:35
std::unique_ptr< lartriton::TritonClient > triton_client
std::string to_string() const
Definition: ParameterSet.h:196