LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
NuGraphInference_module.cc
Go to the documentation of this file.
1 // Class: NuGraphInference
3 // Plugin Type: producer (Unknown Unknown)
4 // File: NuGraphInference_module.cc
5 //
6 // Generated at Tue Nov 14 14:41:30 2023 by Giuseppe Cerati using cetskelgen
7 // from version .
9 
18 #include "fhiclcpp/ParameterSet.h"
20 
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <ios>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
24 
25 #include "delaunator-header-only.hpp"
26 #include <torch/script.h>
27 
31 #include "lardataobj/RecoBase/Vertex.h" //this creates a conflict with torch script if included before it...
32 
35 
36 class NuGraphInference;
37 
40 using recob::Hit;
41 using recob::SpacePoint;
42 using std::array;
43 using std::vector;
44 
45 namespace {
46  template <typename T, typename A>
47  int arg_max(std::vector<T, A> const& vec)
48  {
49  return static_cast<int>(std::distance(vec.begin(), max_element(vec.begin(), vec.end())));
50  }
51 
52  template <typename T, size_t N>
53  void softmax(std::array<T, N>& arr)
54  {
55  T m = -std::numeric_limits<T>::max();
56  for (size_t i = 0; i < arr.size(); i++) {
57  if (arr[i] > m) { m = arr[i]; }
58  }
59  T sum = 0.0;
60  for (size_t i = 0; i < arr.size(); i++) {
61  sum += expf(arr[i] - m);
62  }
63  T offset = m + logf(sum);
64  for (size_t i = 0; i < arr.size(); i++) {
65  arr[i] = expf(arr[i] - offset);
66  }
67  return;
68  }
69 }
70 
72 public:
73  explicit NuGraphInference(fhicl::ParameterSet const& p);
74 
75  // Plugins should not be copied or assigned.
76  NuGraphInference(NuGraphInference const&) = delete;
80 
81  // Required functions.
82  void produce(art::Event& e) override;
83 
84 private:
85  vector<std::string> planes;
86  size_t minHits;
87  bool debug;
88  vector<vector<float>> avgs;
89  vector<vector<float>> devs;
90  vector<float> pos_norm;
91  torch::jit::script::Module model;
92  // loader tool
93  std::unique_ptr<LoaderToolBase> _loaderTool;
94  // decoder tools
95  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
96 };
97 
99  : EDProducer{p}
100  , planes(p.get<vector<std::string>>("planes"))
101  , minHits(p.get<size_t>("minHits"))
102  , debug(p.get<bool>("debug"))
103  , pos_norm(p.get<vector<float>>("pos_norm"))
104 {
105 
106  for (size_t ip = 0; ip < planes.size(); ++ip) {
107  avgs.push_back(p.get<vector<float>>("avgs_" + planes[ip]));
108  devs.push_back(p.get<vector<float>>("devs_" + planes[ip]));
109  }
110 
111  // Loader Tool
112  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
113  _loaderTool->setDebugAndPlanes(debug, planes);
114 
115  // configure and construct Decoder Tools
116  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
117  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
118  std::cout << "decoder lablel: " << tool_pset_labels << std::endl;
119  auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
120  _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
121  _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
122  _decoderToolsVec.back()->declareProducts(producesCollector());
123  }
124 
125  cet::search_path sp("FW_SEARCH_PATH");
126  model = torch::jit::load(sp.find_file(p.get<std::string>("modelFileName")));
127 }
128 
130 {
131 
132  //
133  // Load the data and fill the graph inputs
134  //
135  vector<art::Ptr<Hit>> hitlist;
136  vector<vector<size_t>> idsmap;
137  vector<NuGraphInput> graphinputs;
138  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
139 
140  if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
141  if (hitlist.size() < minHits) {
142  // Writing the empty outputs to the output root file
143  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
144  _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
145  }
146  return;
147  }
148 
149  //
150  // libTorch-specific section: requires extracting inputs, create graph, run inference
151  //
152  const vector<int32_t>* spids = nullptr;
153  const vector<int32_t>* hitids_u = nullptr;
154  const vector<int32_t>* hitids_v = nullptr;
155  const vector<int32_t>* hitids_y = nullptr;
156  const vector<int32_t>* hit_plane = nullptr;
157  const vector<float>* hit_time = nullptr;
158  const vector<int32_t>* hit_wire = nullptr;
159  const vector<float>* hit_integral = nullptr;
160  const vector<float>* hit_rms = nullptr;
161  for (const auto& gi : graphinputs) {
162  if (gi.input_name == "spacepoint_table_spacepoint_id")
163  spids = &gi.input_int32_vec;
164  else if (gi.input_name == "spacepoint_table_hit_id_u")
165  hitids_u = &gi.input_int32_vec;
166  else if (gi.input_name == "spacepoint_table_hit_id_v")
167  hitids_v = &gi.input_int32_vec;
168  else if (gi.input_name == "spacepoint_table_hit_id_y")
169  hitids_y = &gi.input_int32_vec;
170  else if (gi.input_name == "hit_table_local_plane")
171  hit_plane = &gi.input_int32_vec;
172  else if (gi.input_name == "hit_table_local_time")
173  hit_time = &gi.input_float_vec;
174  else if (gi.input_name == "hit_table_local_wire")
175  hit_wire = &gi.input_int32_vec;
176  else if (gi.input_name == "hit_table_integral")
177  hit_integral = &gi.input_float_vec;
178  else if (gi.input_name == "hit_table_rms")
179  hit_rms = &gi.input_float_vec;
180  }
181 
182  // Reverse lookup from key to index in plane index
183  vector<size_t> idsmapRev(hitlist.size(), hitlist.size());
184  for (const auto& ipv : idsmap) {
185  for (size_t ih = 0; ih < ipv.size(); ih++) {
186  idsmapRev[ipv[ih]] = ih;
187  }
188  }
189 
190  struct Edge {
191  size_t n1;
192  size_t n2;
193  bool operator==(const Edge& other) const
194  {
195  if (this->n1 == other.n1 && this->n2 == other.n2)
196  return true;
197  else
198  return false;
199  };
200  };
201 
202  // Delauney graph construction
203  auto start_preprocess1 = std::chrono::high_resolution_clock::now();
204  vector<vector<Edge>> edge2d(planes.size(), vector<Edge>());
205  for (size_t p = 0; p < planes.size(); p++) {
206  vector<double> coords;
207  for (size_t i = 0; i < hit_plane->size(); ++i) {
208  if (size_t(hit_plane->at(i)) != p) continue;
209  coords.push_back(hit_time->at(i) * pos_norm[1]);
210  coords.push_back(hit_wire->at(i) * pos_norm[0]);
211  }
212  if (debug) std::cout << "Plane " << p << " has N hits=" << coords.size() / 2 << std::endl;
213  if (coords.size() / 2 < 3) { continue; }
214  delaunator::Delaunator d(coords);
215  if (debug) std::cout << "Found N triangles=" << d.triangles.size() / 3 << std::endl;
216  for (std::size_t i = 0; i < d.triangles.size(); i += 3) {
217  //create edges in both directions
218  Edge e;
219  e.n1 = d.triangles[i];
220  e.n2 = d.triangles[i + 1];
221  edge2d[p].push_back(e);
222  e.n1 = d.triangles[i + 1];
223  e.n2 = d.triangles[i];
224  edge2d[p].push_back(e);
225  //
226  e.n1 = d.triangles[i];
227  e.n2 = d.triangles[i + 2];
228  edge2d[p].push_back(e);
229  e.n1 = d.triangles[i + 2];
230  e.n2 = d.triangles[i];
231  edge2d[p].push_back(e);
232  //
233  e.n1 = d.triangles[i + 1];
234  e.n2 = d.triangles[i + 2];
235  edge2d[p].push_back(e);
236  e.n1 = d.triangles[i + 2];
237  e.n2 = d.triangles[i + 1];
238  edge2d[p].push_back(e);
239  //
240  }
241  //sort and cleanup duplicate edges
242  std::sort(edge2d[p].begin(), edge2d[p].end(), [](const auto& i, const auto& j) {
243  return (i.n1 != j.n1 ? i.n1 < j.n1 : i.n2 < j.n2);
244  });
245  if (debug) {
246  for (auto& e : edge2d[p]) {
247  std::cout << "sorted plane=" << p << " e1=" << e.n1 << " e2=" << e.n2 << std::endl;
248  }
249  }
250  edge2d[p].erase(std::unique(edge2d[p].begin(), edge2d[p].end()), edge2d[p].end());
251  }
252 
253  if (debug) {
254  for (size_t p = 0; p < planes.size(); p++) {
255  for (auto& e : edge2d[p]) {
256  std::cout << " plane=" << p << " e1=" << e.n1 << " e2=" << e.n2 << std::endl;
257  }
258  }
259  }
260  auto end_preprocess1 = std::chrono::high_resolution_clock::now();
261  std::chrono::duration<double> elapsed_preprocess1 = end_preprocess1 - start_preprocess1;
262 
263  // Nexus edges
264  auto start_preprocess2 = std::chrono::high_resolution_clock::now();
265  vector<vector<Edge>> edge3d(planes.size(), vector<Edge>());
266  for (size_t i = 0; i < spids->size(); ++i) {
267  if (hitids_u->at(i) >= 0) {
268  Edge e;
269  e.n1 = idsmapRev[hitids_u->at(i)];
270  e.n2 = spids->at(i);
271  edge3d[0].push_back(e);
272  }
273  if (hitids_v->at(i) >= 0) {
274  Edge e;
275  e.n1 = idsmapRev[hitids_v->at(i)];
276  e.n2 = spids->at(i);
277  edge3d[1].push_back(e);
278  }
279  if (hitids_y->at(i) >= 0) {
280  Edge e;
281  e.n1 = idsmapRev[hitids_y->at(i)];
282  e.n2 = spids->at(i);
283  edge3d[2].push_back(e);
284  }
285  }
286 
287  // Prepare inputs
288  auto x = torch::Dict<std::string, torch::Tensor>();
289  auto batch = torch::Dict<std::string, torch::Tensor>();
290  for (size_t p = 0; p < planes.size(); p++) {
291  vector<float> nodeft;
292  for (size_t i = 0; i < hit_plane->size(); ++i) {
293  if (size_t(hit_plane->at(i)) != p) continue;
294  nodeft.push_back((hit_wire->at(i) * pos_norm[0] - avgs[hit_plane->at(i)][0]) /
295  devs[hit_plane->at(i)][0]);
296  nodeft.push_back((hit_time->at(i) * pos_norm[1] - avgs[hit_plane->at(i)][1]) /
297  devs[hit_plane->at(i)][1]);
298  nodeft.push_back((hit_integral->at(i) - avgs[hit_plane->at(i)][2]) /
299  devs[hit_plane->at(i)][2]);
300  nodeft.push_back((hit_rms->at(i) - avgs[hit_plane->at(i)][3]) / devs[hit_plane->at(i)][3]);
301  }
302  long int dim = nodeft.size() / 4;
303  torch::Tensor ix = torch::zeros({dim, 4}, torch::dtype(torch::kFloat32));
304  if (debug) {
305  std::cout << "plane=" << p << std::endl;
306  std::cout << std::scientific;
307  for (size_t n = 0; n < nodeft.size(); n = n + 4) {
308  std::cout << nodeft[n] << " " << nodeft[n + 1] << " " << nodeft[n + 2] << " "
309  << nodeft[n + 3] << " " << std::endl;
310  }
311  }
312  for (size_t n = 0; n < nodeft.size(); n = n + 4) {
313  ix[n / 4][0] = nodeft[n];
314  ix[n / 4][1] = nodeft[n + 1];
315  ix[n / 4][2] = nodeft[n + 2];
316  ix[n / 4][3] = nodeft[n + 3];
317  }
318  x.insert(planes[p], ix);
319  torch::Tensor ib = torch::zeros({dim}, torch::dtype(torch::kInt64));
320  batch.insert(planes[p], ib);
321  }
322 
323  auto edge_index_plane = torch::Dict<std::string, torch::Tensor>();
324  for (size_t p = 0; p < planes.size(); p++) {
325  long int dim = edge2d[p].size();
326  torch::Tensor ix = torch::zeros({2, dim}, torch::dtype(torch::kInt64));
327  for (size_t n = 0; n < edge2d[p].size(); n++) {
328  ix[0][n] = int(edge2d[p][n].n1);
329  ix[1][n] = int(edge2d[p][n].n2);
330  }
331  edge_index_plane.insert(planes[p], ix);
332  if (debug) {
333  std::cout << "plane=" << p << std::endl;
334  std::cout << "2d edge size=" << edge2d[p].size() << std::endl;
335  for (size_t n = 0; n < edge2d[p].size(); n++) {
336  std::cout << edge2d[p][n].n1 << " ";
337  }
338  std::cout << std::endl;
339  for (size_t n = 0; n < edge2d[p].size(); n++) {
340  std::cout << edge2d[p][n].n2 << " ";
341  }
342  std::cout << std::endl;
343  }
344  }
345 
346  auto edge_index_nexus = torch::Dict<std::string, torch::Tensor>();
347  for (size_t p = 0; p < planes.size(); p++) {
348  long int dim = edge3d[p].size();
349  torch::Tensor ix = torch::zeros({2, dim}, torch::dtype(torch::kInt64));
350  for (size_t n = 0; n < edge3d[p].size(); n++) {
351  ix[0][n] = int(edge3d[p][n].n1);
352  ix[1][n] = int(edge3d[p][n].n2);
353  }
354  edge_index_nexus.insert(planes[p], ix);
355  if (debug) {
356  std::cout << "plane=" << p << std::endl;
357  std::cout << "3d edge size=" << edge3d[p].size() << std::endl;
358  for (size_t n = 0; n < edge3d[p].size(); n++) {
359  std::cout << edge3d[p][n].n1 << " ";
360  }
361  std::cout << std::endl;
362  for (size_t n = 0; n < edge3d[p].size(); n++) {
363  std::cout << edge3d[p][n].n2 << " ";
364  }
365  std::cout << std::endl;
366  }
367  }
368 
369  long int spdim = spids->size();
370  auto nexus = torch::empty({spdim, 0}, torch::dtype(torch::kFloat32));
371 
372  std::vector<torch::jit::IValue> inputs;
373  inputs.push_back(x);
374  inputs.push_back(edge_index_plane);
375  inputs.push_back(edge_index_nexus);
376  inputs.push_back(nexus);
377  inputs.push_back(batch);
378 
379  // Run inference
380  auto end_preprocess2 = std::chrono::high_resolution_clock::now();
381  std::chrono::duration<double> elapsed_preprocess2 = end_preprocess2 - start_preprocess2;
382  if (debug) std::cout << "FORWARD!" << std::endl;
383  auto start = std::chrono::high_resolution_clock::now();
384  auto outputs = model.forward(inputs).toGenericDict();
385  auto end = std::chrono::high_resolution_clock::now();
386  std::chrono::duration<double> elapsed = end - start;
387  if (debug) {
388  std::cout << "Time taken for inference: "
389  << elapsed_preprocess1.count() + elapsed_preprocess2.count() + elapsed.count()
390  << " seconds" << std::endl;
391  std::cout << "output =" << outputs << std::endl;
392  }
393 
394  //
395  // Get pointers to the result returned and write to the event
396  //
397  vector<NuGraphOutput> infer_output;
398  for (const auto& elem1 : outputs) {
399  if (elem1.value().isTensor()) {
400  torch::Tensor tensor = elem1.value().toTensor();
401  std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
402  infer_output.push_back(NuGraphOutput(elem1.key().to<std::string>(), vec));
403  }
404  else if (elem1.value().isGenericDict()) {
405  for (const auto& elem2 : elem1.value().toGenericDict()) {
406  torch::Tensor tensor = elem2.value().toTensor();
407  std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
408  infer_output.push_back(
409  NuGraphOutput(elem1.key().to<std::string>() + "_" + elem2.key().to<std::string>(), vec));
410  }
411  }
412  }
413 
414  // Write the outputs to the output root file
415  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
416  _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
417  }
418 }
419 
Float_t x
Definition: compare.C:6
NuGraphInference & operator=(NuGraphInference const &)=delete
torch::jit::script::Module model
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
Declaration of signal hit object.
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.cc:6
vector< std::string > planes
boost::graph_traits< ModuleGraph >::edge_descriptor Edge
Definition: ModuleGraph.h:24
void produce(art::Event &e) override
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
auto array(Array const &a)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:250
#define DEFINE_ART_MODULE(klass)
Definition: ModuleMacros.h:65
T get(std::string const &key) const
Definition: ParameterSet.h:314
vector< vector< float > > devs
Float_t d
Definition: plot.C:235
vector< vector< float > > avgs
ProducesCollector & producesCollector() noexcept
Char_t n[5]
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
NuGraphInference(fhicl::ParameterSet const &p)
2D representation of charge deposited in the TDC/wire plane
Definition: Hit.h:46
Double_t sum
Definition: plot.C:31
Float_t e
Definition: plot.C:35
bool operator==(infinite_endcount_iterator< T > const &, count_iterator< T > const &)
Definition: counter.h:277
decltype(auto) constexpr empty(T &&obj)
ADL-aware version of std::empty.
Definition: StdUtils.h:109
std::unique_ptr< LoaderToolBase > _loaderTool