LArSoft  v10_06_00
Liquid Argon Software toolkit - https://larsoft.org/
NuGraphInference Class Reference
Inheritance diagram for NuGraphInference:
art::EDProducer art::detail::Producer art::detail::LegacyModule art::Modifier art::ModuleBase art::ProductRegistryHelper

Public Types

using ModuleType = EDProducer
 
template<typename UserConfig , typename KeysToIgnore = void>
using Table = Modifier::Table< UserConfig, KeysToIgnore >
 

Public Member Functions

 NuGraphInference (fhicl::ParameterSet const &p)
 
 NuGraphInference (NuGraphInference const &)=delete
 
 NuGraphInference (NuGraphInference &&)=delete
 
NuGraphInference & operator= (NuGraphInference const &)=delete
 
NuGraphInference & operator= (NuGraphInference &&)=delete
 
void produce (art::Event &e) override
 
void doBeginJob (SharedResources const &resources)
 
void doEndJob ()
 
void doRespondToOpenInputFile (FileBlock const &fb)
 
void doRespondToCloseInputFile (FileBlock const &fb)
 
void doRespondToOpenOutputFiles (FileBlock const &fb)
 
void doRespondToCloseOutputFiles (FileBlock const &fb)
 
bool doBeginRun (RunPrincipal &rp, ModuleContext const &mc)
 
bool doEndRun (RunPrincipal &rp, ModuleContext const &mc)
 
bool doBeginSubRun (SubRunPrincipal &srp, ModuleContext const &mc)
 
bool doEndSubRun (SubRunPrincipal &srp, ModuleContext const &mc)
 
bool doEvent (EventPrincipal &ep, ModuleContext const &mc, std::atomic< std::size_t > &counts_run, std::atomic< std::size_t > &counts_passed, std::atomic< std::size_t > &counts_failed)
 
void fillProductDescriptions ()
 
void registerProducts (ProductDescriptions &productsToRegister)
 
ModuleDescription const & moduleDescription () const
 
void setModuleDescription (ModuleDescription const &)
 
std::array< std::vector< ProductInfo >, NumBranchTypes > const & getConsumables () const
 
void sortConsumables (std::string const &current_process_name)
 
std::unique_ptr< Worker > makeWorker (WorkerParams const &wp)
 
template<typename T , BranchType BT>
ViewToken< T > consumesView (InputTag const &tag)
 
template<typename T , BranchType BT>
ViewToken< T > mayConsumeView (InputTag const &tag)
 

Protected Member Functions

ConsumesCollector & consumesCollector ()
 
template<typename T , BranchType = InEvent>
ProductToken< T > consumes (InputTag const &)
 
template<typename Element , BranchType = InEvent>
ViewToken< Element > consumesView (InputTag const &)
 
template<typename T , BranchType = InEvent>
void consumesMany ()
 
template<typename T , BranchType = InEvent>
ProductToken< T > mayConsume (InputTag const &)
 
template<typename Element , BranchType = InEvent>
ViewToken< Element > mayConsumeView (InputTag const &)
 
template<typename T , BranchType = InEvent>
void mayConsumeMany ()
 

Private Attributes

vector< std::string > planes
 
size_t minHits
 
bool debug
 
vector< vector< float > > avgs
 
vector< vector< float > > devs
 
vector< float > pos_norm
 
torch::jit::script::Module model
 
std::unique_ptr< LoaderToolBase > _loaderTool
 
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
 

Detailed Description

Definition at line 72 of file NuGraphInference_module.cc.

Member Typedef Documentation

using art::EDProducer::ModuleType = EDProducer
inherited

Definition at line 17 of file EDProducer.h.

template<typename UserConfig , typename KeysToIgnore = void>
using art::detail::Producer::Table = Modifier::Table<UserConfig, KeysToIgnore>
inherited

Definition at line 26 of file Producer.h.

Constructor & Destructor Documentation

NuGraphInference::NuGraphInference ( fhicl::ParameterSet const &  p)
explicit

Definition at line 99 of file NuGraphInference_module.cc.

References _decoderToolsVec, _loaderTool, avgs, debug, devs, fhicl::ParameterSet::get(), minHits, model, planes, pos_norm, and art::ProductRegistryHelper::producesCollector().

100  : EDProducer{p}
101  , planes(p.get<vector<std::string>>("planes"))
102  , minHits(p.get<size_t>("minHits"))
103  , debug(p.get<bool>("debug"))
104  , pos_norm(p.get<vector<float>>("pos_norm"))
105 {
106 
107  for (size_t ip = 0; ip < planes.size(); ++ip) {
108  avgs.push_back(p.get<vector<float>>("avgs_" + planes[ip]));
109  devs.push_back(p.get<vector<float>>("devs_" + planes[ip]));
110  }
111 
112  // Loader Tool
113  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
114  _loaderTool->setDebugAndPlanes(debug, planes);
115 
116  // configure and construct Decoder Tools
117  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
118  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
119  std::cout << "decoder lablel: " << tool_pset_labels << std::endl;
120  auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
121  _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
122  _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
123  _decoderToolsVec.back()->declareProducts(producesCollector());
124  }
125 
126  cet::search_path sp("FW_SEARCH_PATH");
127  model = torch::jit::load(sp.find_file(p.get<std::string>("modelFileName")));
128 }
torch::jit::script::Module model
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.cc:6
vector< std::string > planes
T get(std::string const &key) const
Definition: ParameterSet.h:314
vector< vector< float > > devs
vector< vector< float > > avgs
ProducesCollector & producesCollector() noexcept
std::unique_ptr< LoaderToolBase > _loaderTool
NuGraphInference::NuGraphInference ( NuGraphInference const &  )
delete
NuGraphInference::NuGraphInference ( NuGraphInference &&  )
delete

Member Function Documentation

template<typename T , BranchType BT>
ProductToken< T > art::ModuleBase::consumes ( InputTag const &  tag)
protectedinherited

Definition at line 61 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::consumes().

62  {
63  return collector_.consumes<T, BT>(tag);
64  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
ProductToken< T > consumes(InputTag const &)
ConsumesCollector & art::ModuleBase::consumesCollector ( )
protectedinherited

Definition at line 57 of file ModuleBase.cc.

References art::ModuleBase::collector_.

58  {
59  return collector_;
60  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
template<typename T , BranchType BT>
void art::ModuleBase::consumesMany ( )
protectedinherited

Definition at line 75 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::consumesMany().

76  {
77  collector_.consumesMany<T, BT>();
78  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
template<typename Element , BranchType = InEvent>
ViewToken<Element> art::ModuleBase::consumesView ( InputTag const &  )
protectedinherited
template<typename T , BranchType BT>
ViewToken<T> art::ModuleBase::consumesView ( InputTag const &  tag)
inherited

Definition at line 68 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::consumesView().

69  {
70  return collector_.consumesView<T, BT>(tag);
71  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
ViewToken< Element > consumesView(InputTag const &)
void art::detail::Producer::doBeginJob ( SharedResources const &  resources)
inherited

Definition at line 22 of file Producer.cc.

References art::detail::Producer::beginJobWithFrame(), and art::detail::Producer::setupQueues().

23  {
24  setupQueues(resources);
25  ProcessingFrame const frame{ScheduleID{}};
26  beginJobWithFrame(frame);
27  }
virtual void setupQueues(SharedResources const &)=0
virtual void beginJobWithFrame(ProcessingFrame const &)=0
bool art::detail::Producer::doBeginRun ( RunPrincipal &  rp,
ModuleContext const &  mc 
)
inherited

Definition at line 65 of file Producer.cc.

References art::detail::Producer::beginRunWithFrame(), art::RangeSet::forRun(), art::RunPrincipal::makeRun(), r, art::RunPrincipal::runID(), and art::ModuleContext::scheduleID().

66  {
67  auto r = rp.makeRun(mc, RangeSet::forRun(rp.runID()));
68  ProcessingFrame const frame{mc.scheduleID()};
69  beginRunWithFrame(r, frame);
70  r.commitProducts();
71  return true;
72  }
TRandom r
Definition: spectrum.C:23
virtual void beginRunWithFrame(Run &, ProcessingFrame const &)=0
static RangeSet forRun(RunID)
Definition: RangeSet.cc:51
bool art::detail::Producer::doBeginSubRun ( SubRunPrincipal &  srp,
ModuleContext const &  mc 
)
inherited

Definition at line 85 of file Producer.cc.

References art::detail::Producer::beginSubRunWithFrame(), art::RangeSet::forSubRun(), art::SubRunPrincipal::makeSubRun(), art::ModuleContext::scheduleID(), and art::SubRunPrincipal::subRunID().

86  {
87  auto sr = srp.makeSubRun(mc, RangeSet::forSubRun(srp.subRunID()));
88  ProcessingFrame const frame{mc.scheduleID()};
89  beginSubRunWithFrame(sr, frame);
90  sr.commitProducts();
91  return true;
92  }
virtual void beginSubRunWithFrame(SubRun &, ProcessingFrame const &)=0
static RangeSet forSubRun(SubRunID)
Definition: RangeSet.cc:57
void art::detail::Producer::doEndJob ( )
inherited

Definition at line 30 of file Producer.cc.

References art::detail::Producer::endJobWithFrame().

31  {
32  ProcessingFrame const frame{ScheduleID{}};
33  endJobWithFrame(frame);
34  }
virtual void endJobWithFrame(ProcessingFrame const &)=0
bool art::detail::Producer::doEndRun ( RunPrincipal &  rp,
ModuleContext const &  mc 
)
inherited

Definition at line 75 of file Producer.cc.

References art::detail::Producer::endRunWithFrame(), art::RunPrincipal::makeRun(), r, art::ModuleContext::scheduleID(), and art::Principal::seenRanges().

76  {
77  auto r = rp.makeRun(mc, rp.seenRanges());
78  ProcessingFrame const frame{mc.scheduleID()};
79  endRunWithFrame(r, frame);
80  r.commitProducts();
81  return true;
82  }
TRandom r
Definition: spectrum.C:23
virtual void endRunWithFrame(Run &, ProcessingFrame const &)=0
bool art::detail::Producer::doEndSubRun ( SubRunPrincipal &  srp,
ModuleContext const &  mc 
)
inherited

Definition at line 95 of file Producer.cc.

References art::detail::Producer::endSubRunWithFrame(), art::SubRunPrincipal::makeSubRun(), art::ModuleContext::scheduleID(), and art::Principal::seenRanges().

96  {
97  auto sr = srp.makeSubRun(mc, srp.seenRanges());
98  ProcessingFrame const frame{mc.scheduleID()};
99  endSubRunWithFrame(sr, frame);
100  sr.commitProducts();
101  return true;
102  }
virtual void endSubRunWithFrame(SubRun &, ProcessingFrame const &)=0
bool art::detail::Producer::doEvent ( EventPrincipal &  ep,
ModuleContext const &  mc,
std::atomic< std::size_t > &  counts_run,
std::atomic< std::size_t > &  counts_passed,
std::atomic< std::size_t > &  counts_failed 
)
inherited

Definition at line 105 of file Producer.cc.

References art::detail::Producer::checkPutProducts_, e, art::EventPrincipal::makeEvent(), art::detail::Producer::produceWithFrame(), and art::ModuleContext::scheduleID().

110  {
111  auto e = ep.makeEvent(mc);
112  ++counts_run;
113  ProcessingFrame const frame{mc.scheduleID()};
114  produceWithFrame(e, frame);
115  e.commitProducts(checkPutProducts_, &expectedProducts<InEvent>());
116  ++counts_passed;
117  return true;
118  }
bool const checkPutProducts_
Definition: Producer.h:70
Float_t e
Definition: plot.C:35
virtual void produceWithFrame(Event &, ProcessingFrame const &)=0
void art::detail::Producer::doRespondToCloseInputFile ( FileBlock const &  fb)
inherited

Definition at line 44 of file Producer.cc.

References art::detail::Producer::respondToCloseInputFileWithFrame().

45  {
46  ProcessingFrame const frame{ScheduleID{}};
47  respondToCloseInputFileWithFrame(fb, frame);
48  }
virtual void respondToCloseInputFileWithFrame(FileBlock const &, ProcessingFrame const &)=0
TFile fb("Li6.root")
void art::detail::Producer::doRespondToCloseOutputFiles ( FileBlock const &  fb)
inherited

Definition at line 58 of file Producer.cc.

References art::detail::Producer::respondToCloseOutputFilesWithFrame().

59  {
60  ProcessingFrame const frame{ScheduleID{}};
61  respondToCloseOutputFilesWithFrame(fb, frame);
62  }
virtual void respondToCloseOutputFilesWithFrame(FileBlock const &, ProcessingFrame const &)=0
TFile fb("Li6.root")
void art::detail::Producer::doRespondToOpenInputFile ( FileBlock const &  fb)
inherited

Definition at line 37 of file Producer.cc.

References art::detail::Producer::respondToOpenInputFileWithFrame().

38  {
39  ProcessingFrame const frame{ScheduleID{}};
40  respondToOpenInputFileWithFrame(fb, frame);
41  }
virtual void respondToOpenInputFileWithFrame(FileBlock const &, ProcessingFrame const &)=0
TFile fb("Li6.root")
void art::detail::Producer::doRespondToOpenOutputFiles ( FileBlock const &  fb)
inherited

Definition at line 51 of file Producer.cc.

References art::detail::Producer::respondToOpenOutputFilesWithFrame().

52  {
53  ProcessingFrame const frame{ScheduleID{}};
54  respondToOpenOutputFilesWithFrame(fb, frame);
55  }
virtual void respondToOpenOutputFilesWithFrame(FileBlock const &, ProcessingFrame const &)=0
TFile fb("Li6.root")
void art::Modifier::fillProductDescriptions ( )
inherited

Definition at line 10 of file Modifier.cc.

References art::ProductRegistryHelper::fillDescriptions(), and art::ModuleBase::moduleDescription().

11  {
12  fillDescriptions(moduleDescription());
13  }
void fillDescriptions(ModuleDescription const &md)
ModuleDescription const & moduleDescription() const
Definition: ModuleBase.cc:13
std::array< std::vector< ProductInfo >, NumBranchTypes > const & art::ModuleBase::getConsumables ( ) const
inherited

Definition at line 43 of file ModuleBase.cc.

References art::ModuleBase::collector_, and art::ConsumesCollector::getConsumables().

44  {
45  return collector_.getConsumables();
46  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
std::array< std::vector< ProductInfo >, NumBranchTypes > const & getConsumables() const
std::unique_ptr< Worker > art::ModuleBase::makeWorker ( WorkerParams const &  wp)
inherited

Definition at line 37 of file ModuleBase.cc.

References art::ModuleBase::doMakeWorker(), and art::NumBranchTypes.

38  {
39  return doMakeWorker(wp);
40  }
virtual std::unique_ptr< Worker > doMakeWorker(WorkerParams const &wp)=0
template<typename T , BranchType BT>
ProductToken< T > art::ModuleBase::mayConsume ( InputTag const &  tag)
protectedinherited

Definition at line 82 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::mayConsume().

83  {
84  return collector_.mayConsume<T, BT>(tag);
85  }
ProductToken< T > mayConsume(InputTag const &)
ConsumesCollector collector_
Definition: ModuleBase.h:56
template<typename T , BranchType BT>
void art::ModuleBase::mayConsumeMany ( )
protectedinherited

Definition at line 96 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::mayConsumeMany().

97  {
98  collector_.mayConsumeMany<T, BT>();
99  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
template<typename Element , BranchType = InEvent>
ViewToken<Element> art::ModuleBase::mayConsumeView ( InputTag const &  )
protectedinherited
template<typename T , BranchType BT>
ViewToken<T> art::ModuleBase::mayConsumeView ( InputTag const &  tag)
inherited

Definition at line 89 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::mayConsumeView().

90  {
91  return collector_.mayConsumeView<T, BT>(tag);
92  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
ViewToken< Element > mayConsumeView(InputTag const &)
ModuleDescription const & art::ModuleBase::moduleDescription ( ) const
inherited

Definition at line 13 of file ModuleBase.cc.

References art::errors::LogicError.

Referenced by art::OutputModule::doRespondToOpenInputFile(), art::OutputModule::doWriteEvent(), art::Modifier::fillProductDescriptions(), art::OutputModule::makePlugins_(), art::OutputWorker::OutputWorker(), reco::shower::LArPandoraModularShowerCreation::produce(), art::Modifier::registerProducts(), and art::OutputModule::registerProducts().

14  {
15  if (md_.has_value()) {
16  return *md_;
17  }
18 
19  throw Exception{errors::LogicError,
20  "There was an error while calling moduleDescription().\n"}
21  << "The moduleDescription() base-class member function cannot be called\n"
22  "during module construction. To determine which module is "
23  "responsible\n"
24  "for calling it, find the '<module type>:<module "
25  "label>@Construction'\n"
26  "tag in the message prefix above. Please contact artists@fnal.gov\n"
27  "for guidance.\n";
28  }
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::optional< ModuleDescription > md_
Definition: ModuleBase.h:55
NuGraphInference& NuGraphInference::operator= ( NuGraphInference const &  )
delete
NuGraphInference& NuGraphInference::operator= ( NuGraphInference &&  )
delete
void NuGraphInference::produce ( art::Event &  e)
overridevirtual

Implements art::EDProducer.

Definition at line 130 of file NuGraphInference_module.cc.

References _decoderToolsVec, _loaderTool, avgs, util::begin(), d, debug, DEFINE_ART_MODULE, devs, e, util::empty(), util::end(), minHits, model, n, util::details::operator==(), fhicl::other, planes, pos_norm, and x.

131 {
132 
133  //
134  // Load the data and fill the graph inputs
135  //
136  vector<art::Ptr<Hit>> hitlist;
137  vector<vector<size_t>> idsmap;
138  vector<NuGraphInput> graphinputs;
139  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
140 
141  if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
142  if (hitlist.size() < minHits) {
143  // Writing the empty outputs to the output root file
144  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
145  _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
146  }
147  return;
148  }
149 
150  //
151  // libTorch-specific section: requires extracting inputs, create graph, run inference
152  //
153  const vector<int32_t>* spids = nullptr;
154  const vector<int32_t>* hitids_u = nullptr;
155  const vector<int32_t>* hitids_v = nullptr;
156  const vector<int32_t>* hitids_y = nullptr;
157  const vector<int32_t>* hit_plane = nullptr;
158  const vector<float>* hit_time = nullptr;
159  const vector<int32_t>* hit_wire = nullptr;
160  const vector<float>* hit_integral = nullptr;
161  const vector<float>* hit_rms = nullptr;
162  for (const auto& gi : graphinputs) {
163  if (gi.input_name == "spacepoint_table_spacepoint_id")
164  spids = &gi.input_int32_vec;
165  else if (gi.input_name == "spacepoint_table_hit_id_u")
166  hitids_u = &gi.input_int32_vec;
167  else if (gi.input_name == "spacepoint_table_hit_id_v")
168  hitids_v = &gi.input_int32_vec;
169  else if (gi.input_name == "spacepoint_table_hit_id_y")
170  hitids_y = &gi.input_int32_vec;
171  else if (gi.input_name == "hit_table_local_plane")
172  hit_plane = &gi.input_int32_vec;
173  else if (gi.input_name == "hit_table_local_time")
174  hit_time = &gi.input_float_vec;
175  else if (gi.input_name == "hit_table_local_wire")
176  hit_wire = &gi.input_int32_vec;
177  else if (gi.input_name == "hit_table_integral")
178  hit_integral = &gi.input_float_vec;
179  else if (gi.input_name == "hit_table_rms")
180  hit_rms = &gi.input_float_vec;
181  }
182 
183  // Reverse lookup from key to index in plane index
184  vector<size_t> idsmapRev(hitlist.size(), hitlist.size());
185  for (const auto& ipv : idsmap) {
186  for (size_t ih = 0; ih < ipv.size(); ih++) {
187  idsmapRev[ipv[ih]] = ih;
188  }
189  }
190 
191  struct Edge {
192  size_t n1;
193  size_t n2;
194  bool operator==(const Edge& other) const
195  {
196  if (this->n1 == other.n1 && this->n2 == other.n2)
197  return true;
198  else
199  return false;
200  };
201  };
202 
203  // Delauney graph construction
204  auto start_preprocess1 = std::chrono::high_resolution_clock::now();
205  vector<vector<Edge>> edge2d(planes.size(), vector<Edge>());
206  for (size_t p = 0; p < planes.size(); p++) {
207  vector<double> coords;
208  for (size_t i = 0; i < hit_plane->size(); ++i) {
209  if (size_t(hit_plane->at(i)) != p) continue;
210  coords.push_back(hit_time->at(i) * pos_norm[1]);
211  coords.push_back(hit_wire->at(i) * pos_norm[0]);
212  }
213  if (debug) std::cout << "Plane " << p << " has N hits=" << coords.size() / 2 << std::endl;
214  if (coords.size() / 2 < 3) { continue; }
215  delaunator::Delaunator d(coords);
216  if (debug) std::cout << "Found N triangles=" << d.triangles.size() / 3 << std::endl;
217  for (std::size_t i = 0; i < d.triangles.size(); i += 3) {
218  //create edges in both directions
219  Edge e;
220  e.n1 = d.triangles[i];
221  e.n2 = d.triangles[i + 1];
222  edge2d[p].push_back(e);
223  e.n1 = d.triangles[i + 1];
224  e.n2 = d.triangles[i];
225  edge2d[p].push_back(e);
226  //
227  e.n1 = d.triangles[i];
228  e.n2 = d.triangles[i + 2];
229  edge2d[p].push_back(e);
230  e.n1 = d.triangles[i + 2];
231  e.n2 = d.triangles[i];
232  edge2d[p].push_back(e);
233  //
234  e.n1 = d.triangles[i + 1];
235  e.n2 = d.triangles[i + 2];
236  edge2d[p].push_back(e);
237  e.n1 = d.triangles[i + 2];
238  e.n2 = d.triangles[i + 1];
239  edge2d[p].push_back(e);
240  //
241  }
242  //sort and cleanup duplicate edges
243  std::sort(edge2d[p].begin(), edge2d[p].end(), [](const auto& i, const auto& j) {
244  return (i.n1 != j.n1 ? i.n1 < j.n1 : i.n2 < j.n2);
245  });
246  if (debug) {
247  for (auto& e : edge2d[p]) {
248  std::cout << "sorted plane=" << p << " e1=" << e.n1 << " e2=" << e.n2 << std::endl;
249  }
250  }
251  edge2d[p].erase(std::unique(edge2d[p].begin(), edge2d[p].end()), edge2d[p].end());
252  }
253 
254  if (debug) {
255  for (size_t p = 0; p < planes.size(); p++) {
256  for (auto& e : edge2d[p]) {
257  std::cout << " plane=" << p << " e1=" << e.n1 << " e2=" << e.n2 << std::endl;
258  }
259  }
260  }
261  auto end_preprocess1 = std::chrono::high_resolution_clock::now();
262  std::chrono::duration<double> elapsed_preprocess1 = end_preprocess1 - start_preprocess1;
263 
264  // Nexus edges
265  auto start_preprocess2 = std::chrono::high_resolution_clock::now();
266  vector<vector<Edge>> edge3d(planes.size(), vector<Edge>());
267  for (size_t i = 0; i < spids->size(); ++i) {
268  if (hitids_u->at(i) >= 0) {
269  Edge e;
270  e.n1 = idsmapRev[hitids_u->at(i)];
271  e.n2 = spids->at(i);
272  edge3d[0].push_back(e);
273  }
274  if (hitids_v->at(i) >= 0) {
275  Edge e;
276  e.n1 = idsmapRev[hitids_v->at(i)];
277  e.n2 = spids->at(i);
278  edge3d[1].push_back(e);
279  }
280  if (hitids_y->at(i) >= 0) {
281  Edge e;
282  e.n1 = idsmapRev[hitids_y->at(i)];
283  e.n2 = spids->at(i);
284  edge3d[2].push_back(e);
285  }
286  }
287 
288  // Prepare inputs
289  auto x = torch::Dict<std::string, torch::Tensor>();
290  auto batch = torch::Dict<std::string, torch::Tensor>();
291  for (size_t p = 0; p < planes.size(); p++) {
292  vector<float> nodeft;
293  for (size_t i = 0; i < hit_plane->size(); ++i) {
294  if (size_t(hit_plane->at(i)) != p) continue;
295  nodeft.push_back((hit_wire->at(i) * pos_norm[0] - avgs[hit_plane->at(i)][0]) /
296  devs[hit_plane->at(i)][0]);
297  nodeft.push_back((hit_time->at(i) * pos_norm[1] - avgs[hit_plane->at(i)][1]) /
298  devs[hit_plane->at(i)][1]);
299  nodeft.push_back((hit_integral->at(i) - avgs[hit_plane->at(i)][2]) /
300  devs[hit_plane->at(i)][2]);
301  nodeft.push_back((hit_rms->at(i) - avgs[hit_plane->at(i)][3]) / devs[hit_plane->at(i)][3]);
302  }
303  long int dim = nodeft.size() / 4;
304  torch::Tensor ix = torch::zeros({dim, 4}, torch::dtype(torch::kFloat32));
305  if (debug) {
306  std::cout << "plane=" << p << std::endl;
307  std::cout << std::scientific;
308  for (size_t n = 0; n < nodeft.size(); n = n + 4) {
309  std::cout << nodeft[n] << " " << nodeft[n + 1] << " " << nodeft[n + 2] << " "
310  << nodeft[n + 3] << " " << std::endl;
311  }
312  }
313  for (size_t n = 0; n < nodeft.size(); n = n + 4) {
314  ix[n / 4][0] = nodeft[n];
315  ix[n / 4][1] = nodeft[n + 1];
316  ix[n / 4][2] = nodeft[n + 2];
317  ix[n / 4][3] = nodeft[n + 3];
318  }
319  x.insert(planes[p], ix);
320  torch::Tensor ib = torch::zeros({dim}, torch::dtype(torch::kInt64));
321  batch.insert(planes[p], ib);
322  }
323 
324  auto edge_index_plane = torch::Dict<std::string, torch::Tensor>();
325  for (size_t p = 0; p < planes.size(); p++) {
326  long int dim = edge2d[p].size();
327  torch::Tensor ix = torch::zeros({2, dim}, torch::dtype(torch::kInt64));
328  for (size_t n = 0; n < edge2d[p].size(); n++) {
329  ix[0][n] = int(edge2d[p][n].n1);
330  ix[1][n] = int(edge2d[p][n].n2);
331  }
332  edge_index_plane.insert(planes[p], ix);
333  if (debug) {
334  std::cout << "plane=" << p << std::endl;
335  std::cout << "2d edge size=" << edge2d[p].size() << std::endl;
336  for (size_t n = 0; n < edge2d[p].size(); n++) {
337  std::cout << edge2d[p][n].n1 << " ";
338  }
339  std::cout << std::endl;
340  for (size_t n = 0; n < edge2d[p].size(); n++) {
341  std::cout << edge2d[p][n].n2 << " ";
342  }
343  std::cout << std::endl;
344  }
345  }
346 
347  auto edge_index_nexus = torch::Dict<std::string, torch::Tensor>();
348  for (size_t p = 0; p < planes.size(); p++) {
349  long int dim = edge3d[p].size();
350  torch::Tensor ix = torch::zeros({2, dim}, torch::dtype(torch::kInt64));
351  for (size_t n = 0; n < edge3d[p].size(); n++) {
352  ix[0][n] = int(edge3d[p][n].n1);
353  ix[1][n] = int(edge3d[p][n].n2);
354  }
355  edge_index_nexus.insert(planes[p], ix);
356  if (debug) {
357  std::cout << "plane=" << p << std::endl;
358  std::cout << "3d edge size=" << edge3d[p].size() << std::endl;
359  for (size_t n = 0; n < edge3d[p].size(); n++) {
360  std::cout << edge3d[p][n].n1 << " ";
361  }
362  std::cout << std::endl;
363  for (size_t n = 0; n < edge3d[p].size(); n++) {
364  std::cout << edge3d[p][n].n2 << " ";
365  }
366  std::cout << std::endl;
367  }
368  }
369 
370  long int spdim = spids->size();
371  auto nexus = torch::empty({spdim, 0}, torch::dtype(torch::kFloat32));
372 
373  std::vector<torch::jit::IValue> inputs;
374  inputs.push_back(x);
375  inputs.push_back(edge_index_plane);
376  inputs.push_back(edge_index_nexus);
377  inputs.push_back(nexus);
378  inputs.push_back(batch);
379 
380  // Run inference
381  auto end_preprocess2 = std::chrono::high_resolution_clock::now();
382  std::chrono::duration<double> elapsed_preprocess2 = end_preprocess2 - start_preprocess2;
383  if (debug) std::cout << "FORWARD!" << std::endl;
384  auto start = std::chrono::high_resolution_clock::now();
385  auto outputs = model.forward(inputs).toGenericDict();
386  auto end = std::chrono::high_resolution_clock::now();
387  std::chrono::duration<double> elapsed = end - start;
388  if (debug) {
389  std::cout << "Time taken for inference: "
390  << elapsed_preprocess1.count() + elapsed_preprocess2.count() + elapsed.count()
391  << " seconds" << std::endl;
392  std::cout << "output =" << outputs << std::endl;
393  }
394 
395  //
396  // Get pointers to the result returned and write to the event
397  //
398  vector<NuGraphOutput> infer_output;
399  for (const auto& elem1 : outputs) {
400  if (elem1.value().isTensor()) {
401  torch::Tensor tensor = elem1.value().toTensor();
402  std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
403  infer_output.push_back(NuGraphOutput(elem1.key().to<std::string>(), vec));
404  }
405  else if (elem1.value().isGenericDict()) {
406  for (const auto& elem2 : elem1.value().toGenericDict()) {
407  torch::Tensor tensor = elem2.value().toTensor();
408  std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
409  infer_output.push_back(
410  NuGraphOutput(elem1.key().to<std::string>() + "_" + elem2.key().to<std::string>(), vec));
411  }
412  }
413  }
414 
415  // Write the outputs to the output root file
416  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
417  _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
418  }
419 }
Float_t x
Definition: compare.C:6
torch::jit::script::Module model
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
vector< std::string > planes
boost::graph_traits< ModuleGraph >::edge_descriptor Edge
Definition: ModuleGraph.h:24
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
vector< vector< float > > devs
Float_t d
Definition: plot.C:235
vector< vector< float > > avgs
Char_t n[5]
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
Float_t e
Definition: plot.C:35
bool operator==(infinite_endcount_iterator< T > const &, count_iterator< T > const &)
Definition: counter.h:277
decltype(auto) constexpr empty(T &&obj)
ADL-aware version of std::empty.
Definition: StdUtils.h:109
std::unique_ptr< LoaderToolBase > _loaderTool
void art::Modifier::registerProducts ( ProductDescriptions &  productsToRegister)
inherited

Definition at line 16 of file Modifier.cc.

References art::ModuleBase::moduleDescription(), and art::ProductRegistryHelper::registerProducts().

17  {
18  ProductRegistryHelper::registerProducts(productsToRegister,
19  moduleDescription());
20  }
void registerProducts(ProductDescriptions &productsToRegister, ModuleDescription const &md)
ModuleDescription const & moduleDescription() const
Definition: ModuleBase.cc:13
void art::ModuleBase::setModuleDescription ( ModuleDescription const &  md)
inherited

Definition at line 31 of file ModuleBase.cc.

References art::ModuleBase::md_.

32  {
33  md_ = md;
34  }
std::optional< ModuleDescription > md_
Definition: ModuleBase.h:55
void art::ModuleBase::sortConsumables ( std::string const &  current_process_name)
inherited

Definition at line 49 of file ModuleBase.cc.

References art::ModuleBase::collector_, and art::ConsumesCollector::sortConsumables().

50  {
51  // Now that we know we have seen all the consumes declarations,
52  // sort the results for fast lookup later.
53  collector_.sortConsumables(current_process_name);
54  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
void sortConsumables(std::string const &current_process_name)

Member Data Documentation

std::vector<std::unique_ptr<DecoderToolBase> > NuGraphInference::_decoderToolsVec
private

Definition at line 96 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

std::unique_ptr<LoaderToolBase> NuGraphInference::_loaderTool
private

Definition at line 94 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

vector<vector<float> > NuGraphInference::avgs
private

Definition at line 89 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

bool NuGraphInference::debug
private

Definition at line 88 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

vector<vector<float> > NuGraphInference::devs
private

Definition at line 90 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

size_t NuGraphInference::minHits
private

Definition at line 87 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

torch::jit::script::Module NuGraphInference::model
private

Definition at line 92 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

vector<std::string> NuGraphInference::planes
private

Definition at line 86 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

vector<float> NuGraphInference::pos_norm
private

Definition at line 91 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().


The documentation for this class was generated from the following file: