LArSoft  v10_06_00
Liquid Argon Software toolkit - https://larsoft.org/
NuGraphInferenceSonicTriton_module.cc
Go to the documentation of this file.
1 // Class: NuGraphInferenceSonicTriton
3 // Plugin Type: producer (Unknown Unknown)
4 // File: NuGraphInferenceSonicTriton_module.cc
5 //
6 // Generated at Tue Nov 14 14:41:30 2023 by Giuseppe Cerati using cetskelgen
7 // from version .
9 
15 #include "fhiclcpp/ParameterSet.h"
16 #include "fhiclcpp/types/Table.h"
18 
19 #include <limits>
20 #include <memory>
21 
27 
30 
31 #include <chrono>
32 #include <fstream>
33 #include <iostream>
34 #include <sstream>
35 #include <string>
36 #include <vector>
37 
39 
42 using recob::Hit;
43 using std::vector;
44 
46 public:
48 
49  // Plugins should not be copied or assigned.
54 
55  // Required functions.
 // Per-event entry point: loads the hits and graph inputs through the loader
 // tool, runs inference through the Triton client and passes the results to
 // every decoder tool so they can write their products to the event.
56  void produce(art::Event& e) override;
57 
58 private:
 // Events with fewer hits than this skip inference and get empty outputs
 // (read from the "minHits" parameter).
59  size_t minHits;
 // Enables the std::cout diagnostics guarded by `if (debug)`; also forwarded
 // to the loader and decoder tools via setDebugAndPlanes().
60  bool debug;
 // Plane names read from the "planes" parameter and forwarded to the loader
 // and decoder tools via setDebugAndPlanes().
61  vector<std::string> planes;
 // Triton client built in the constructor from the "TritonConfig" parameter
 // set; reused for every event.
63  std::unique_ptr<lartriton::TritonClient> triton_client;
64 
65  // loader tool
66  std::unique_ptr<LoaderToolBase> _loaderTool;
67  // decoder tools
68  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
69 
 // setShapeAndToServer: declares the shape of one Triton input and uploads
 // its data.
 //  triton_input  the client input to fill
 //  vec           values to send; its length becomes the single shape dimension
 //  batchSize     number of entries in the TritonInput that is sent
70  template <class T>
72  vector<T>& vec,
73  size_t batchSize)
74  {
 // The shape is one-dimensional: just the number of elements in vec.
75  triton_input.setShape({static_cast<long int>(vec.size())});
 // TritonInput<T>(batchSize, vec) is built and handed over as a shared_ptr.
 // NOTE(review): the listing's cross-reference gives TritonInput as a vector
 // of vectors, in which case this holds batchSize copies of vec - confirm.
76  triton_input.toServer(
77  std::make_shared<lartriton::TritonInput<T>>(lartriton::TritonInput<T>(batchSize, vec)));
78  }
79 };
80 
82  : EDProducer{p}
 // Constructor: reads "minHits", "debug", "planes" and the nested
 // "TritonConfig" table from the module parameter set p, then builds the
 // Triton client, the loader tool and one decoder tool per configured entry.
83  , minHits(p.get<size_t>("minHits"))
84  , debug(p.get<bool>("debug"))
85  , planes(p.get<vector<std::string>>("planes"))
86  , tritonPset(p.get<fhicl::ParameterSet>("TritonConfig"))
87 {
88 
89  // ... Create the Triton inference client
90  if (debug) std::cout << "TritonConfig: " << tritonPset.to_string() << std::endl;
91  triton_client = std::make_unique<lartriton::TritonClient>(tritonPset);
92 
93  // Loader Tool
 // Built from the "LoaderTool" table, then given the debug flag and planes.
94  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
95  _loaderTool->setDebugAndPlanes(debug, planes);
96 
97  // configure and construct Decoder Tools
 // Every nested parameter set inside "DecoderTools" yields one decoder tool.
98  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
99  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
 // Printed for every decoder regardless of the debug flag.
100  std::cout << "decoder label: " << tool_pset_labels << std::endl;
101  auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
102  _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
103  _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
 // Each decoder declares the data products it will later put into the event.
104  _decoderToolsVec.back()->declareProducts(producesCollector());
105  }
106 }
107 
109 {
 // Processes one event e: load hits and graph inputs, send the inputs to the
 // Triton server, collect the outputs and let each decoder tool write them.
110 
111  //
112  // Load the data and fill the graph inputs
113  //
114  vector<art::Ptr<Hit>> hitlist;
115  vector<vector<size_t>> idsmap;
116  vector<NuGraphInput> graphinputs;
 // The loader tool fills all three containers from the event.
117  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
118 
119  if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
 // Too few hits: skip inference, but still let every decoder write its
 // empty output before returning.
120  if (hitlist.size() < minHits) {
121  // Writing the empty outputs to the output root file
122  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
123  _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
124  }
125  return;
126  }
127 
128  //
129  // NuSonic Triton Server section
130  //
 // The timer spans client reset, input upload, dispatch and output retrieval.
131  auto start = std::chrono::high_resolution_clock::now();
132  //
133  //Here the input should be sent to Triton
134  triton_client->reset();
135  size_t batchSize = 1; //the code below assumes/has only been tested for batch size = 1
136  triton_client->setBatchSize(batchSize); // set batch size
137  //
 // For every input the client exposes, find the graph inputs with the same
 // name and upload the int32 or the float vector depending on isInt.
138  auto& inputs = triton_client->input();
139  for (auto& input_pair : inputs) {
140  const std::string& key = input_pair.first;
141  auto& triton_input = input_pair.second;
142  //
143  for (auto& gi : graphinputs) {
144  if (key != gi.input_name) continue;
145  if (gi.isInt)
146  setShapeAndToServer(triton_input, gi.input_int32_vec, batchSize);
147  else
148  setShapeAndToServer(triton_input, gi.input_float_vec, batchSize);
149  }
150  }
151  // ~~~~ Send inference request
152  triton_client->dispatch();
153  // ~~~~ Retrieve inference results
154  auto& infer_result = triton_client->output();
155  auto end = std::chrono::high_resolution_clock::now();
156  std::chrono::duration<double> elapsed = end - start;
 // Printed for every event that reaches inference, independent of debug.
157  std::cout << "Time taken for inference: " << elapsed.count() << " seconds" << std::endl;
158 
159  //
160  // Get pointers to the result returned and write to the event
161  //
 // Each named output is read as float and its first entry is copied into an
 // owned vector paired with the output name.
 // NOTE(review): prob[0] is read without checking that prob is non-empty;
 // presumably guaranteed by batch size 1 - confirm.
162  vector<NuGraphOutput> infer_output;
163  for (const auto& [name, data] : infer_result) {
164  const auto& prob = data.fromServer<float>();
165  std::vector<float> out_data(prob[0].begin(), prob[0].end());
166  infer_output.emplace_back(name, std::move(out_data));
167  }
168 
169  // Write the outputs to the output root file
170  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
171  _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
172  }
173 }
174 
bool setShape(const ShapeType &newShape)
Definition: TritonData.h:45
std::unique_ptr< LoaderToolBase > _loaderTool
Declaration of signal hit object.
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.cc:6
NuGraphInferenceSonicTriton & operator=(NuGraphInferenceSonicTriton const &)=delete
void toServer(std::shared_ptr< TritonInput< DT >> ptr)
Definition: TritonData.h:50
NuGraphInferenceSonicTriton(fhicl::ParameterSet const &p)
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
#define DEFINE_ART_MODULE(klass)
Definition: ModuleMacros.h:65
void setShapeAndToServer(lartriton::TritonData< triton::client::InferInput > &triton_input, vector< T > &vec, size_t batchSize)
std::vector< std::vector< DT >> TritonInput
Definition: TritonData.h:28
T get(std::string const &key) const
Definition: ParameterSet.h:314
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
ProducesCollector & producesCollector() noexcept
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
2D representation of charge deposited in the TDC/wire plane
Definition: Hit.h:46
Float_t e
Definition: plot.C:35
std::unique_ptr< lartriton::TritonClient > triton_client
std::string to_string() const
Definition: ParameterSet.h:196