LArSoft  v10_06_00
Liquid Argon Software toolkit - https://larsoft.org/
NuGraphInference_module.cc
Go to the documentation of this file.
1 // Class: NuGraphInference
3 // Plugin Type: producer (Unknown Unknown)
4 // File: NuGraphInference_module.cc
5 //
6 // Generated at Tue Nov 14 14:41:30 2023 by Giuseppe Cerati using cetskelgen
7 // from version .
9 
19 #include "fhiclcpp/ParameterSet.h"
21 
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
25 
26 #include "delaunator-header-only.hpp"
27 #include <torch/script.h>
28 
32 #include "lardataobj/RecoBase/Vertex.h" //this creates a conflict with torch script if included before it...
33 
36 
37 class NuGraphInference;
38 
41 using recob::Hit;
42 using recob::SpacePoint;
43 using std::array;
44 using std::vector;
45 
namespace {
  /// Index of the first maximum element of @p vec.
  ///
  /// For an empty vector std::max_element returns end(), so this returns 0,
  /// which is NOT a valid index: callers must pass a non-empty vector.
  template <typename T, typename A>
  int arg_max(std::vector<T, A> const& vec)
  {
    return static_cast<int>(std::distance(vec.begin(), std::max_element(vec.begin(), vec.end())));
  }

  /// In-place, numerically stable softmax over @p arr.
  ///
  /// The largest element is subtracted before exponentiating, and the
  /// normalization is applied in log space (offset = max + log(sum)), so large
  /// inputs do not overflow. std::exp / std::log are used rather than
  /// expf / logf so that the computation is carried out in type T (double
  /// arrays keep double precision instead of being rounded through float).
  template <typename T, size_t N>
  void softmax(std::array<T, N>& arr)
  {
    T m = -std::numeric_limits<T>::max();
    for (T const v : arr) {
      if (v > m) { m = v; }
    }
    T sum = 0;
    for (T const v : arr) {
      sum += std::exp(v - m);
    }
    T const offset = m + std::log(sum);
    for (T& v : arr) {
      v = std::exp(v - offset);
    }
  }
}
71 
73 public:
  // Reads the module configuration, builds the loader and decoder tools, and
  // loads the TorchScript model named by the "modelFileName" parameter.
74  explicit NuGraphInference(fhicl::ParameterSet const& p);
75 
76  // Plugins should not be copied or assigned.
77  NuGraphInference(NuGraphInference const&) = delete;
81 
82  // Required functions.
  // Per event: loads the graph inputs through the loader tool, runs the model
  // (unless there are fewer than minHits hits) and hands the results to the
  // decoder tools, which write them to the event.
83  void produce(art::Event& e) override;
84 
85 private:
  // NOTE: declaration order fixes the order in which members are constructed
  // and destroyed; keep it consistent with the constructor's initializer list.
  // Plane names: suffixes of the avgs_/devs_ parameters and keys of the
  // per-plane model input dictionaries.
86  vector<std::string> planes;
  // Events with fewer loaded hits than this get empty outputs instead of inference.
87  size_t minHits;
  // Enables verbose std::cout printout; also forwarded to the tools.
88  bool debug;
  // Per-plane feature normalization, indexed [plane][feature] with features
  // (wire, time, integral, rms); each feature is mapped to (value - avg) / dev.
89  vector<vector<float>> avgs;
90  vector<vector<float>> devs;
  // Coordinate scale factors: [0] multiplies the hit wire, [1] the hit time.
91  vector<float> pos_norm;
  // TorchScript model, loaded in the constructor.
92  torch::jit::script::Module model;
93  // loader tool: fills the hit list and the named graph inputs for each event
94  std::unique_ptr<LoaderToolBase> _loaderTool;
95  // decoder tools: one per "DecoderTools" entry; they write outputs to the event
96  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
97 };
98 
100  : EDProducer{p}
101  , planes(p.get<vector<std::string>>("planes"))
102  , minHits(p.get<size_t>("minHits"))
103  , debug(p.get<bool>("debug"))
104  , pos_norm(p.get<vector<float>>("pos_norm"))
105 {
106 
107  for (size_t ip = 0; ip < planes.size(); ++ip) {
108  avgs.push_back(p.get<vector<float>>("avgs_" + planes[ip]));
109  devs.push_back(p.get<vector<float>>("devs_" + planes[ip]));
110  }
111 
112  // Loader Tool
113  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
114  _loaderTool->setDebugAndPlanes(debug, planes);
115 
116  // configure and construct Decoder Tools
117  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
118  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
119  std::cout << "decoder lablel: " << tool_pset_labels << std::endl;
120  auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
121  _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
122  _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
123  _decoderToolsVec.back()->declareProducts(producesCollector());
124  }
125 
126  cet::search_path sp("FW_SEARCH_PATH");
127  model = torch::jit::load(sp.find_file(p.get<std::string>("modelFileName")));
128 }
129 
131 {
132 
133  //
134  // Load the data and fill the graph inputs
135  //
136  vector<art::Ptr<Hit>> hitlist;
137  vector<vector<size_t>> idsmap;
138  vector<NuGraphInput> graphinputs;
139  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
140 
141  if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
142  if (hitlist.size() < minHits) {
143  // Writing the empty outputs to the output root file
144  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
145  _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
146  }
147  return;
148  }
149 
150  //
151  // libTorch-specific section: requires extracting inputs, create graph, run inference
152  //
153  const vector<int32_t>* spids = nullptr;
154  const vector<int32_t>* hitids_u = nullptr;
155  const vector<int32_t>* hitids_v = nullptr;
156  const vector<int32_t>* hitids_y = nullptr;
157  const vector<int32_t>* hit_plane = nullptr;
158  const vector<float>* hit_time = nullptr;
159  const vector<int32_t>* hit_wire = nullptr;
160  const vector<float>* hit_integral = nullptr;
161  const vector<float>* hit_rms = nullptr;
162  for (const auto& gi : graphinputs) {
163  if (gi.input_name == "spacepoint_table_spacepoint_id")
164  spids = &gi.input_int32_vec;
165  else if (gi.input_name == "spacepoint_table_hit_id_u")
166  hitids_u = &gi.input_int32_vec;
167  else if (gi.input_name == "spacepoint_table_hit_id_v")
168  hitids_v = &gi.input_int32_vec;
169  else if (gi.input_name == "spacepoint_table_hit_id_y")
170  hitids_y = &gi.input_int32_vec;
171  else if (gi.input_name == "hit_table_local_plane")
172  hit_plane = &gi.input_int32_vec;
173  else if (gi.input_name == "hit_table_local_time")
174  hit_time = &gi.input_float_vec;
175  else if (gi.input_name == "hit_table_local_wire")
176  hit_wire = &gi.input_int32_vec;
177  else if (gi.input_name == "hit_table_integral")
178  hit_integral = &gi.input_float_vec;
179  else if (gi.input_name == "hit_table_rms")
180  hit_rms = &gi.input_float_vec;
181  }
182 
183  // Reverse lookup from key to index in plane index
184  vector<size_t> idsmapRev(hitlist.size(), hitlist.size());
185  for (const auto& ipv : idsmap) {
186  for (size_t ih = 0; ih < ipv.size(); ih++) {
187  idsmapRev[ipv[ih]] = ih;
188  }
189  }
190 
191  struct Edge {
192  size_t n1;
193  size_t n2;
194  bool operator==(const Edge& other) const
195  {
196  if (this->n1 == other.n1 && this->n2 == other.n2)
197  return true;
198  else
199  return false;
200  };
201  };
202 
203  // Delauney graph construction
204  auto start_preprocess1 = std::chrono::high_resolution_clock::now();
205  vector<vector<Edge>> edge2d(planes.size(), vector<Edge>());
206  for (size_t p = 0; p < planes.size(); p++) {
207  vector<double> coords;
208  for (size_t i = 0; i < hit_plane->size(); ++i) {
209  if (size_t(hit_plane->at(i)) != p) continue;
210  coords.push_back(hit_time->at(i) * pos_norm[1]);
211  coords.push_back(hit_wire->at(i) * pos_norm[0]);
212  }
213  if (debug) std::cout << "Plane " << p << " has N hits=" << coords.size() / 2 << std::endl;
214  if (coords.size() / 2 < 3) { continue; }
215  delaunator::Delaunator d(coords);
216  if (debug) std::cout << "Found N triangles=" << d.triangles.size() / 3 << std::endl;
217  for (std::size_t i = 0; i < d.triangles.size(); i += 3) {
218  //create edges in both directions
219  Edge e;
220  e.n1 = d.triangles[i];
221  e.n2 = d.triangles[i + 1];
222  edge2d[p].push_back(e);
223  e.n1 = d.triangles[i + 1];
224  e.n2 = d.triangles[i];
225  edge2d[p].push_back(e);
226  //
227  e.n1 = d.triangles[i];
228  e.n2 = d.triangles[i + 2];
229  edge2d[p].push_back(e);
230  e.n1 = d.triangles[i + 2];
231  e.n2 = d.triangles[i];
232  edge2d[p].push_back(e);
233  //
234  e.n1 = d.triangles[i + 1];
235  e.n2 = d.triangles[i + 2];
236  edge2d[p].push_back(e);
237  e.n1 = d.triangles[i + 2];
238  e.n2 = d.triangles[i + 1];
239  edge2d[p].push_back(e);
240  //
241  }
242  //sort and cleanup duplicate edges
243  std::sort(edge2d[p].begin(), edge2d[p].end(), [](const auto& i, const auto& j) {
244  return (i.n1 != j.n1 ? i.n1 < j.n1 : i.n2 < j.n2);
245  });
246  if (debug) {
247  for (auto& e : edge2d[p]) {
248  std::cout << "sorted plane=" << p << " e1=" << e.n1 << " e2=" << e.n2 << std::endl;
249  }
250  }
251  edge2d[p].erase(std::unique(edge2d[p].begin(), edge2d[p].end()), edge2d[p].end());
252  }
253 
254  if (debug) {
255  for (size_t p = 0; p < planes.size(); p++) {
256  for (auto& e : edge2d[p]) {
257  std::cout << " plane=" << p << " e1=" << e.n1 << " e2=" << e.n2 << std::endl;
258  }
259  }
260  }
261  auto end_preprocess1 = std::chrono::high_resolution_clock::now();
262  std::chrono::duration<double> elapsed_preprocess1 = end_preprocess1 - start_preprocess1;
263 
264  // Nexus edges
265  auto start_preprocess2 = std::chrono::high_resolution_clock::now();
266  vector<vector<Edge>> edge3d(planes.size(), vector<Edge>());
267  for (size_t i = 0; i < spids->size(); ++i) {
268  if (hitids_u->at(i) >= 0) {
269  Edge e;
270  e.n1 = idsmapRev[hitids_u->at(i)];
271  e.n2 = spids->at(i);
272  edge3d[0].push_back(e);
273  }
274  if (hitids_v->at(i) >= 0) {
275  Edge e;
276  e.n1 = idsmapRev[hitids_v->at(i)];
277  e.n2 = spids->at(i);
278  edge3d[1].push_back(e);
279  }
280  if (hitids_y->at(i) >= 0) {
281  Edge e;
282  e.n1 = idsmapRev[hitids_y->at(i)];
283  e.n2 = spids->at(i);
284  edge3d[2].push_back(e);
285  }
286  }
287 
288  // Prepare inputs
289  auto x = torch::Dict<std::string, torch::Tensor>();
290  auto batch = torch::Dict<std::string, torch::Tensor>();
291  for (size_t p = 0; p < planes.size(); p++) {
292  vector<float> nodeft;
293  for (size_t i = 0; i < hit_plane->size(); ++i) {
294  if (size_t(hit_plane->at(i)) != p) continue;
295  nodeft.push_back((hit_wire->at(i) * pos_norm[0] - avgs[hit_plane->at(i)][0]) /
296  devs[hit_plane->at(i)][0]);
297  nodeft.push_back((hit_time->at(i) * pos_norm[1] - avgs[hit_plane->at(i)][1]) /
298  devs[hit_plane->at(i)][1]);
299  nodeft.push_back((hit_integral->at(i) - avgs[hit_plane->at(i)][2]) /
300  devs[hit_plane->at(i)][2]);
301  nodeft.push_back((hit_rms->at(i) - avgs[hit_plane->at(i)][3]) / devs[hit_plane->at(i)][3]);
302  }
303  long int dim = nodeft.size() / 4;
304  torch::Tensor ix = torch::zeros({dim, 4}, torch::dtype(torch::kFloat32));
305  if (debug) {
306  std::cout << "plane=" << p << std::endl;
307  std::cout << std::scientific;
308  for (size_t n = 0; n < nodeft.size(); n = n + 4) {
309  std::cout << nodeft[n] << " " << nodeft[n + 1] << " " << nodeft[n + 2] << " "
310  << nodeft[n + 3] << " " << std::endl;
311  }
312  }
313  for (size_t n = 0; n < nodeft.size(); n = n + 4) {
314  ix[n / 4][0] = nodeft[n];
315  ix[n / 4][1] = nodeft[n + 1];
316  ix[n / 4][2] = nodeft[n + 2];
317  ix[n / 4][3] = nodeft[n + 3];
318  }
319  x.insert(planes[p], ix);
320  torch::Tensor ib = torch::zeros({dim}, torch::dtype(torch::kInt64));
321  batch.insert(planes[p], ib);
322  }
323 
324  auto edge_index_plane = torch::Dict<std::string, torch::Tensor>();
325  for (size_t p = 0; p < planes.size(); p++) {
326  long int dim = edge2d[p].size();
327  torch::Tensor ix = torch::zeros({2, dim}, torch::dtype(torch::kInt64));
328  for (size_t n = 0; n < edge2d[p].size(); n++) {
329  ix[0][n] = int(edge2d[p][n].n1);
330  ix[1][n] = int(edge2d[p][n].n2);
331  }
332  edge_index_plane.insert(planes[p], ix);
333  if (debug) {
334  std::cout << "plane=" << p << std::endl;
335  std::cout << "2d edge size=" << edge2d[p].size() << std::endl;
336  for (size_t n = 0; n < edge2d[p].size(); n++) {
337  std::cout << edge2d[p][n].n1 << " ";
338  }
339  std::cout << std::endl;
340  for (size_t n = 0; n < edge2d[p].size(); n++) {
341  std::cout << edge2d[p][n].n2 << " ";
342  }
343  std::cout << std::endl;
344  }
345  }
346 
347  auto edge_index_nexus = torch::Dict<std::string, torch::Tensor>();
348  for (size_t p = 0; p < planes.size(); p++) {
349  long int dim = edge3d[p].size();
350  torch::Tensor ix = torch::zeros({2, dim}, torch::dtype(torch::kInt64));
351  for (size_t n = 0; n < edge3d[p].size(); n++) {
352  ix[0][n] = int(edge3d[p][n].n1);
353  ix[1][n] = int(edge3d[p][n].n2);
354  }
355  edge_index_nexus.insert(planes[p], ix);
356  if (debug) {
357  std::cout << "plane=" << p << std::endl;
358  std::cout << "3d edge size=" << edge3d[p].size() << std::endl;
359  for (size_t n = 0; n < edge3d[p].size(); n++) {
360  std::cout << edge3d[p][n].n1 << " ";
361  }
362  std::cout << std::endl;
363  for (size_t n = 0; n < edge3d[p].size(); n++) {
364  std::cout << edge3d[p][n].n2 << " ";
365  }
366  std::cout << std::endl;
367  }
368  }
369 
370  long int spdim = spids->size();
371  auto nexus = torch::empty({spdim, 0}, torch::dtype(torch::kFloat32));
372 
373  std::vector<torch::jit::IValue> inputs;
374  inputs.push_back(x);
375  inputs.push_back(edge_index_plane);
376  inputs.push_back(edge_index_nexus);
377  inputs.push_back(nexus);
378  inputs.push_back(batch);
379 
380  // Run inference
381  auto end_preprocess2 = std::chrono::high_resolution_clock::now();
382  std::chrono::duration<double> elapsed_preprocess2 = end_preprocess2 - start_preprocess2;
383  if (debug) std::cout << "FORWARD!" << std::endl;
384  auto start = std::chrono::high_resolution_clock::now();
385  auto outputs = model.forward(inputs).toGenericDict();
386  auto end = std::chrono::high_resolution_clock::now();
387  std::chrono::duration<double> elapsed = end - start;
388  if (debug) {
389  std::cout << "Time taken for inference: "
390  << elapsed_preprocess1.count() + elapsed_preprocess2.count() + elapsed.count()
391  << " seconds" << std::endl;
392  std::cout << "output =" << outputs << std::endl;
393  }
394 
395  //
396  // Get pointers to the result returned and write to the event
397  //
398  vector<NuGraphOutput> infer_output;
399  for (const auto& elem1 : outputs) {
400  if (elem1.value().isTensor()) {
401  torch::Tensor tensor = elem1.value().toTensor();
402  std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
403  infer_output.push_back(NuGraphOutput(elem1.key().to<std::string>(), vec));
404  }
405  else if (elem1.value().isGenericDict()) {
406  for (const auto& elem2 : elem1.value().toGenericDict()) {
407  torch::Tensor tensor = elem2.value().toTensor();
408  std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
409  infer_output.push_back(
410  NuGraphOutput(elem1.key().to<std::string>() + "_" + elem2.key().to<std::string>(), vec));
411  }
412  }
413  }
414 
415  // Write the outputs to the output root file
416  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
417  _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
418  }
419 }
420 
Float_t x
Definition: compare.C:6
NuGraphInference & operator=(NuGraphInference const &)=delete
torch::jit::script::Module model
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
Declaration of signal hit object.
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.cc:6
vector< std::string > planes
boost::graph_traits< ModuleGraph >::edge_descriptor Edge
Definition: ModuleGraph.h:24
void produce(art::Event &e) override
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
auto array(Array const &a)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:250
#define DEFINE_ART_MODULE(klass)
Definition: ModuleMacros.h:65
T get(std::string const &key) const
Definition: ParameterSet.h:314
vector< vector< float > > devs
Float_t d
Definition: plot.C:235
vector< vector< float > > avgs
ProducesCollector & producesCollector() noexcept
Char_t n[5]
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
NuGraphInference(fhicl::ParameterSet const &p)
2D representation of charge deposited in the TDC/wire plane
Definition: Hit.h:46
Double_t sum
Definition: plot.C:31
Float_t e
Definition: plot.C:35
bool operator==(infinite_endcount_iterator< T > const &, count_iterator< T > const &)
Definition: counter.h:277
decltype(auto) constexpr empty(T &&obj)
ADL-aware version of std::empty.
Definition: StdUtils.h:109
std::unique_ptr< LoaderToolBase > _loaderTool