LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
NuGraphInference Class Reference
Inheritance diagram for NuGraphInference:
art::EDProducer art::detail::Producer art::detail::LegacyModule art::Modifier art::ModuleBase art::ProductRegistryHelper

Public Types

using ModuleType = EDProducer
 
template<typename UserConfig , typename KeysToIgnore = void>
using Table = Modifier::Table< UserConfig, KeysToIgnore >
 

Public Member Functions

 NuGraphInference (fhicl::ParameterSet const &p)
 
 NuGraphInference (NuGraphInference const &)=delete
 
 NuGraphInference (NuGraphInference &&)=delete
 
NuGraphInference & operator= (NuGraphInference const &)=delete
 
NuGraphInference & operator= (NuGraphInference &&)=delete
 
void produce (art::Event &e) override
 
void doBeginJob (SharedResources const &resources)
 
void doEndJob ()
 
void doRespondToOpenInputFile (FileBlock const &fb)
 
void doRespondToCloseInputFile (FileBlock const &fb)
 
void doRespondToOpenOutputFiles (FileBlock const &fb)
 
void doRespondToCloseOutputFiles (FileBlock const &fb)
 
bool doBeginRun (RunPrincipal &rp, ModuleContext const &mc)
 
bool doEndRun (RunPrincipal &rp, ModuleContext const &mc)
 
bool doBeginSubRun (SubRunPrincipal &srp, ModuleContext const &mc)
 
bool doEndSubRun (SubRunPrincipal &srp, ModuleContext const &mc)
 
bool doEvent (EventPrincipal &ep, ModuleContext const &mc, std::atomic< std::size_t > &counts_run, std::atomic< std::size_t > &counts_passed, std::atomic< std::size_t > &counts_failed)
 
void fillProductDescriptions ()
 
void registerProducts (ProductDescriptions &productsToRegister)
 
ModuleDescription const & moduleDescription () const
 
void setModuleDescription (ModuleDescription const &)
 
std::array< std::vector< ProductInfo >, NumBranchTypes > const & getConsumables () const
 
void sortConsumables (std::string const &current_process_name)
 
std::unique_ptr< Worker > makeWorker (WorkerParams const &wp)
 
template<typename T , BranchType BT>
ViewToken< T > consumesView (InputTag const &tag)
 
template<typename T , BranchType BT>
ViewToken< T > mayConsumeView (InputTag const &tag)
 

Protected Member Functions

ConsumesCollector & consumesCollector ()
 
template<typename T , BranchType = InEvent>
ProductToken< T > consumes (InputTag const &)
 
template<typename Element , BranchType = InEvent>
ViewToken< Element > consumesView (InputTag const &)
 
template<typename T , BranchType = InEvent>
void consumesMany ()
 
template<typename T , BranchType = InEvent>
ProductToken< T > mayConsume (InputTag const &)
 
template<typename Element , BranchType = InEvent>
ViewToken< Element > mayConsumeView (InputTag const &)
 
template<typename T , BranchType = InEvent>
void mayConsumeMany ()
 

Private Attributes

vector< std::string > planes
 
size_t minHits
 
bool debug
 
vector< vector< float > > avgs
 
vector< vector< float > > devs
 
vector< float > pos_norm
 
torch::jit::script::Module model
 
std::unique_ptr< LoaderToolBase > _loaderTool
 
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
 

Detailed Description

Definition at line 71 of file NuGraphInference_module.cc.

Member Typedef Documentation

using art::EDProducer::ModuleType = EDProducer
inherited

Definition at line 17 of file EDProducer.h.

template<typename UserConfig , typename KeysToIgnore = void>
using art::detail::Producer::Table = Modifier::Table<UserConfig, KeysToIgnore>
inherited

Definition at line 26 of file Producer.h.

Constructor & Destructor Documentation

NuGraphInference::NuGraphInference ( fhicl::ParameterSet const &  p)
explicit

Definition at line 98 of file NuGraphInference_module.cc.

References _decoderToolsVec, _loaderTool, avgs, debug, devs, fhicl::ParameterSet::get(), minHits, model, planes, pos_norm, and art::ProductRegistryHelper::producesCollector().

99  : EDProducer{p}
100  , planes(p.get<vector<std::string>>("planes"))
101  , minHits(p.get<size_t>("minHits"))
102  , debug(p.get<bool>("debug"))
103  , pos_norm(p.get<vector<float>>("pos_norm"))
104 {
105 
106  for (size_t ip = 0; ip < planes.size(); ++ip) {
107  avgs.push_back(p.get<vector<float>>("avgs_" + planes[ip]));
108  devs.push_back(p.get<vector<float>>("devs_" + planes[ip]));
109  }
110 
111  // Loader Tool
112  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
113  _loaderTool->setDebugAndPlanes(debug, planes);
114 
115  // configure and construct Decoder Tools
116  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
117  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
118  std::cout << "decoder lablel: " << tool_pset_labels << std::endl;
119  auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
120  _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
121  _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
122  _decoderToolsVec.back()->declareProducts(producesCollector());
123  }
124 
125  cet::search_path sp("FW_SEARCH_PATH");
126  model = torch::jit::load(sp.find_file(p.get<std::string>("modelFileName")));
127 }
torch::jit::script::Module model
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.cc:6
vector< std::string > planes
T get(std::string const &key) const
Definition: ParameterSet.h:314
vector< vector< float > > devs
vector< vector< float > > avgs
ProducesCollector & producesCollector() noexcept
std::unique_ptr< LoaderToolBase > _loaderTool
NuGraphInference::NuGraphInference ( NuGraphInference const &  )
delete
NuGraphInference::NuGraphInference ( NuGraphInference &&  )
delete

Member Function Documentation

template<typename T , BranchType BT>
ProductToken< T > art::ModuleBase::consumes ( InputTag const &  tag)
protectedinherited

Definition at line 61 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::consumes().

62  { // Forward the consumes<T, BT> declaration for tag to this module's ConsumesCollector.
63  return collector_.consumes<T, BT>(tag);
64  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
ProductToken< T > consumes(InputTag const &)
ConsumesCollector & art::ModuleBase::consumesCollector ( )
protectedinherited

Definition at line 57 of file ModuleBase.cc.

References art::ModuleBase::collector_.

58  { // Return the module's ConsumesCollector member by reference.
59  return collector_;
60  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
template<typename T , BranchType BT>
void art::ModuleBase::consumesMany ( )
protectedinherited

Definition at line 75 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::consumesMany().

76  { // Record a consumesMany<T, BT> declaration on the module's ConsumesCollector.
77  collector_.consumesMany<T, BT>();
78  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
template<typename Element , BranchType = InEvent>
ViewToken<Element> art::ModuleBase::consumesView ( InputTag const &  )
protectedinherited
template<typename T , BranchType BT>
ViewToken<T> art::ModuleBase::consumesView ( InputTag const &  tag)
inherited

Definition at line 68 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::consumesView().

69  { // Forward the view-consumption declaration for tag to the module's ConsumesCollector.
70  return collector_.consumesView<T, BT>(tag);
71  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
ViewToken< Element > consumesView(InputTag const &)
void art::detail::Producer::doBeginJob ( SharedResources const &  resources)
inherited

Definition at line 22 of file Producer.cc.

References art::detail::Producer::beginJobWithFrame(), and art::detail::Producer::setupQueues().

23  { // Set up the module's queues, then call the begin-job hook.
24  setupQueues(resources);
25  ProcessingFrame const frame{ScheduleID{}}; // frame built from a default-constructed ScheduleID
26  beginJobWithFrame(frame);
27  }
virtual void setupQueues(SharedResources const &)=0
virtual void beginJobWithFrame(ProcessingFrame const &)=0
bool art::detail::Producer::doBeginRun ( RunPrincipal &  rp,
ModuleContext const &  mc 
)
inherited

Definition at line 65 of file Producer.cc.

References art::detail::Producer::beginRunWithFrame(), art::RangeSet::forRun(), art::RunPrincipal::makeRun(), r, art::RunPrincipal::runID(), and art::ModuleContext::scheduleID().

66  { // Make a Run with RangeSet::forRun(runID), call the begin-run hook, then commit its products.
67  auto r = rp.makeRun(mc, RangeSet::forRun(rp.runID()));
68  ProcessingFrame const frame{mc.scheduleID()};
69  beginRunWithFrame(r, frame);
70  r.commitProducts();
71  return true; // unconditionally reports success once the hook returns
72  }
TRandom r
Definition: spectrum.C:23
virtual void beginRunWithFrame(Run &, ProcessingFrame const &)=0
static RangeSet forRun(RunID)
Definition: RangeSet.cc:51
bool art::detail::Producer::doBeginSubRun ( SubRunPrincipal &  srp,
ModuleContext const &  mc 
)
inherited

Definition at line 85 of file Producer.cc.

References art::detail::Producer::beginSubRunWithFrame(), art::RangeSet::forSubRun(), art::SubRunPrincipal::makeSubRun(), art::ModuleContext::scheduleID(), and art::SubRunPrincipal::subRunID().

86  { // Make a SubRun with RangeSet::forSubRun(subRunID), call the begin-subrun hook, then commit.
87  auto sr = srp.makeSubRun(mc, RangeSet::forSubRun(srp.subRunID()));
88  ProcessingFrame const frame{mc.scheduleID()};
89  beginSubRunWithFrame(sr, frame);
90  sr.commitProducts();
91  return true; // unconditionally reports success once the hook returns
92  }
virtual void beginSubRunWithFrame(SubRun &, ProcessingFrame const &)=0
static RangeSet forSubRun(SubRunID)
Definition: RangeSet.cc:57
void art::detail::Producer::doEndJob ( )
inherited

Definition at line 30 of file Producer.cc.

References art::detail::Producer::endJobWithFrame().

31  { // Call the end-job hook with a frame built from a default-constructed ScheduleID.
32  ProcessingFrame const frame{ScheduleID{}};
33  endJobWithFrame(frame);
34  }
virtual void endJobWithFrame(ProcessingFrame const &)=0
bool art::detail::Producer::doEndRun ( RunPrincipal &  rp,
ModuleContext const &  mc 
)
inherited

Definition at line 75 of file Producer.cc.

References art::detail::Producer::endRunWithFrame(), art::RunPrincipal::makeRun(), r, art::ModuleContext::scheduleID(), and art::Principal::seenRanges().

76  { // Make a Run over the principal's seenRanges(), call the end-run hook, then commit its products.
77  auto r = rp.makeRun(mc, rp.seenRanges());
78  ProcessingFrame const frame{mc.scheduleID()};
79  endRunWithFrame(r, frame);
80  r.commitProducts();
81  return true; // unconditionally reports success once the hook returns
82  }
TRandom r
Definition: spectrum.C:23
virtual void endRunWithFrame(Run &, ProcessingFrame const &)=0
bool art::detail::Producer::doEndSubRun ( SubRunPrincipal &  srp,
ModuleContext const &  mc 
)
inherited

Definition at line 95 of file Producer.cc.

References art::detail::Producer::endSubRunWithFrame(), art::SubRunPrincipal::makeSubRun(), art::ModuleContext::scheduleID(), and art::Principal::seenRanges().

96  { // Make a SubRun over the principal's seenRanges(), call the end-subrun hook, then commit.
97  auto sr = srp.makeSubRun(mc, srp.seenRanges());
98  ProcessingFrame const frame{mc.scheduleID()};
99  endSubRunWithFrame(sr, frame);
100  sr.commitProducts();
101  return true; // unconditionally reports success once the hook returns
102  }
virtual void endSubRunWithFrame(SubRun &, ProcessingFrame const &)=0
bool art::detail::Producer::doEvent ( EventPrincipal &  ep,
ModuleContext const &  mc,
std::atomic< std::size_t > &  counts_run,
std::atomic< std::size_t > &  counts_passed,
std::atomic< std::size_t > &  counts_failed 
)
inherited

Definition at line 105 of file Producer.cc.

References art::detail::Producer::checkPutProducts_, e, art::EventPrincipal::makeEvent(), art::detail::Producer::produceWithFrame(), and art::ModuleContext::scheduleID().

110  { // Make an Event, run produceWithFrame() on it, commit its products and update the counters.
111  auto e = ep.makeEvent(mc);
112  ++counts_run;
113  ProcessingFrame const frame{mc.scheduleID()};
114  produceWithFrame(e, frame);
115  e.commitProducts(checkPutProducts_, &expectedProducts<InEvent>()); // passes checkPutProducts_ and the expected InEvent products
116  ++counts_passed; // reached only if the calls above return normally; counts_failed is not touched here
117  return true;
118  }
bool const checkPutProducts_
Definition: Producer.h:70
Float_t e
Definition: plot.C:35
virtual void produceWithFrame(Event &, ProcessingFrame const &)=0
void art::detail::Producer::doRespondToCloseInputFile ( FileBlock const &  fb)
inherited

Definition at line 44 of file Producer.cc.

References art::detail::Producer::respondToCloseInputFileWithFrame().

45  {
46  ProcessingFrame const frame{ScheduleID{}};
48  }
virtual void respondToCloseInputFileWithFrame(FileBlock const &, ProcessingFrame const &)=0
TFile fb("Li6.root")
void art::detail::Producer::doRespondToCloseOutputFiles ( FileBlock const &  fb)
inherited

Definition at line 58 of file Producer.cc.

References art::detail::Producer::respondToCloseOutputFilesWithFrame().

59  {
60  ProcessingFrame const frame{ScheduleID{}};
62  }
virtual void respondToCloseOutputFilesWithFrame(FileBlock const &, ProcessingFrame const &)=0
TFile fb("Li6.root")
void art::detail::Producer::doRespondToOpenInputFile ( FileBlock const &  fb)
inherited

Definition at line 37 of file Producer.cc.

References art::detail::Producer::respondToOpenInputFileWithFrame().

38  {
39  ProcessingFrame const frame{ScheduleID{}};
41  }
virtual void respondToOpenInputFileWithFrame(FileBlock const &, ProcessingFrame const &)=0
TFile fb("Li6.root")
void art::detail::Producer::doRespondToOpenOutputFiles ( FileBlock const &  fb)
inherited

Definition at line 51 of file Producer.cc.

References art::detail::Producer::respondToOpenOutputFilesWithFrame().

52  {
53  ProcessingFrame const frame{ScheduleID{}};
55  }
virtual void respondToOpenOutputFilesWithFrame(FileBlock const &, ProcessingFrame const &)=0
TFile fb("Li6.root")
void art::Modifier::fillProductDescriptions ( )
inherited

Definition at line 10 of file Modifier.cc.

References art::ProductRegistryHelper::fillDescriptions(), and art::ModuleBase::moduleDescription().

11  {
13  }
void fillDescriptions(ModuleDescription const &md)
ModuleDescription const & moduleDescription() const
Definition: ModuleBase.cc:13
std::array< std::vector< ProductInfo >, NumBranchTypes > const & art::ModuleBase::getConsumables ( ) const
inherited

Definition at line 43 of file ModuleBase.cc.

References art::ModuleBase::collector_, and art::ConsumesCollector::getConsumables().

44  { // Return the consumables recorded by the module's ConsumesCollector, per branch type.
45  return collector_.getConsumables();
46  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
std::array< std::vector< ProductInfo >, NumBranchTypes > const & getConsumables() const
std::unique_ptr< Worker > art::ModuleBase::makeWorker ( WorkerParams const &  wp)
inherited

Definition at line 37 of file ModuleBase.cc.

References art::ModuleBase::doMakeWorker(), and art::NumBranchTypes.

38  { // Delegate worker creation to the pure-virtual doMakeWorker().
39  return doMakeWorker(wp);
40  }
virtual std::unique_ptr< Worker > doMakeWorker(WorkerParams const &wp)=0
template<typename T , BranchType BT>
ProductToken< T > art::ModuleBase::mayConsume ( InputTag const &  tag)
protectedinherited

Definition at line 82 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::mayConsume().

83  { // Forward the mayConsume<T, BT> declaration for tag to the module's ConsumesCollector.
84  return collector_.mayConsume<T, BT>(tag);
85  }
ProductToken< T > mayConsume(InputTag const &)
ConsumesCollector collector_
Definition: ModuleBase.h:56
template<typename T , BranchType BT>
void art::ModuleBase::mayConsumeMany ( )
protectedinherited

Definition at line 96 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::mayConsumeMany().

97  { // Record a mayConsumeMany<T, BT> declaration on the module's ConsumesCollector.
98  collector_.mayConsumeMany<T, BT>();
99  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
template<typename Element , BranchType = InEvent>
ViewToken<Element> art::ModuleBase::mayConsumeView ( InputTag const &  )
protectedinherited
template<typename T , BranchType BT>
ViewToken<T> art::ModuleBase::mayConsumeView ( InputTag const &  tag)
inherited

Definition at line 89 of file ModuleBase.h.

References art::ModuleBase::collector_, and art::ConsumesCollector::mayConsumeView().

90  { // Forward the optional view-consumption declaration for tag to the module's ConsumesCollector.
91  return collector_.mayConsumeView<T, BT>(tag);
92  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
ViewToken< Element > mayConsumeView(InputTag const &)
ModuleDescription const & art::ModuleBase::moduleDescription ( ) const
inherited

Definition at line 13 of file ModuleBase.cc.

References art::errors::LogicError.

Referenced by art::OutputModule::doRespondToOpenInputFile(), art::OutputModule::doWriteEvent(), art::Modifier::fillProductDescriptions(), art::OutputModule::makePlugins_(), art::OutputWorker::OutputWorker(), reco::shower::LArPandoraModularShowerCreation::produce(), art::Modifier::registerProducts(), and art::OutputModule::registerProducts().

14  {
15  if (md_.has_value()) {
16  return *md_;
17  }
18 
20  "There was an error while calling moduleDescription().\n"}
21  << "The moduleDescription() base-class member function cannot be called\n"
22  "during module construction. To determine which module is "
23  "responsible\n"
24  "for calling it, find the '<module type>:<module "
25  "label>@Construction'\n"
26  "tag in the message prefix above. Please contact artists@fnal.gov\n"
27  "for guidance.\n";
28  }
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::optional< ModuleDescription > md_
Definition: ModuleBase.h:55
NuGraphInference& NuGraphInference::operator= ( NuGraphInference const &  )
delete
NuGraphInference& NuGraphInference::operator= ( NuGraphInference &&  )
delete
void NuGraphInference::produce ( art::Event &  e)
overridevirtual

Implements art::EDProducer.

Definition at line 129 of file NuGraphInference_module.cc.

References _decoderToolsVec, _loaderTool, avgs, util::begin(), d, debug, DEFINE_ART_MODULE, devs, e, util::empty(), util::end(), minHits, model, n, util::details::operator==(), fhicl::other, planes, pos_norm, and x.

130 {
131 
132  //
133  // Load the data and fill the graph inputs
134  //
135  vector<art::Ptr<Hit>> hitlist;
136  vector<vector<size_t>> idsmap;
137  vector<NuGraphInput> graphinputs;
138  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
139 
140  if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
141  if (hitlist.size() < minHits) {
142  // Writing the empty outputs to the output root file
143  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
144  _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
145  }
146  return;
147  }
148 
149  //
150  // libTorch-specific section: requires extracting inputs, create graph, run inference
151  //
152  const vector<int32_t>* spids = nullptr;
153  const vector<int32_t>* hitids_u = nullptr;
154  const vector<int32_t>* hitids_v = nullptr;
155  const vector<int32_t>* hitids_y = nullptr;
156  const vector<int32_t>* hit_plane = nullptr;
157  const vector<float>* hit_time = nullptr;
158  const vector<int32_t>* hit_wire = nullptr;
159  const vector<float>* hit_integral = nullptr;
160  const vector<float>* hit_rms = nullptr;
161  for (const auto& gi : graphinputs) {
162  if (gi.input_name == "spacepoint_table_spacepoint_id")
163  spids = &gi.input_int32_vec;
164  else if (gi.input_name == "spacepoint_table_hit_id_u")
165  hitids_u = &gi.input_int32_vec;
166  else if (gi.input_name == "spacepoint_table_hit_id_v")
167  hitids_v = &gi.input_int32_vec;
168  else if (gi.input_name == "spacepoint_table_hit_id_y")
169  hitids_y = &gi.input_int32_vec;
170  else if (gi.input_name == "hit_table_local_plane")
171  hit_plane = &gi.input_int32_vec;
172  else if (gi.input_name == "hit_table_local_time")
173  hit_time = &gi.input_float_vec;
174  else if (gi.input_name == "hit_table_local_wire")
175  hit_wire = &gi.input_int32_vec;
176  else if (gi.input_name == "hit_table_integral")
177  hit_integral = &gi.input_float_vec;
178  else if (gi.input_name == "hit_table_rms")
179  hit_rms = &gi.input_float_vec;
180  }
181 
182  // Reverse lookup from key to index in plane index
183  vector<size_t> idsmapRev(hitlist.size(), hitlist.size());
184  for (const auto& ipv : idsmap) {
185  for (size_t ih = 0; ih < ipv.size(); ih++) {
186  idsmapRev[ipv[ih]] = ih;
187  }
188  }
189 
190  struct Edge {
191  size_t n1;
192  size_t n2;
193  bool operator==(const Edge& other) const
194  {
195  if (this->n1 == other.n1 && this->n2 == other.n2)
196  return true;
197  else
198  return false;
199  };
200  };
201 
202  // Delauney graph construction
203  auto start_preprocess1 = std::chrono::high_resolution_clock::now();
204  vector<vector<Edge>> edge2d(planes.size(), vector<Edge>());
205  for (size_t p = 0; p < planes.size(); p++) {
206  vector<double> coords;
207  for (size_t i = 0; i < hit_plane->size(); ++i) {
208  if (size_t(hit_plane->at(i)) != p) continue;
209  coords.push_back(hit_time->at(i) * pos_norm[1]);
210  coords.push_back(hit_wire->at(i) * pos_norm[0]);
211  }
212  if (debug) std::cout << "Plane " << p << " has N hits=" << coords.size() / 2 << std::endl;
213  if (coords.size() / 2 < 3) { continue; }
214  delaunator::Delaunator d(coords);
215  if (debug) std::cout << "Found N triangles=" << d.triangles.size() / 3 << std::endl;
216  for (std::size_t i = 0; i < d.triangles.size(); i += 3) {
217  //create edges in both directions
218  Edge e;
219  e.n1 = d.triangles[i];
220  e.n2 = d.triangles[i + 1];
221  edge2d[p].push_back(e);
222  e.n1 = d.triangles[i + 1];
223  e.n2 = d.triangles[i];
224  edge2d[p].push_back(e);
225  //
226  e.n1 = d.triangles[i];
227  e.n2 = d.triangles[i + 2];
228  edge2d[p].push_back(e);
229  e.n1 = d.triangles[i + 2];
230  e.n2 = d.triangles[i];
231  edge2d[p].push_back(e);
232  //
233  e.n1 = d.triangles[i + 1];
234  e.n2 = d.triangles[i + 2];
235  edge2d[p].push_back(e);
236  e.n1 = d.triangles[i + 2];
237  e.n2 = d.triangles[i + 1];
238  edge2d[p].push_back(e);
239  //
240  }
241  //sort and cleanup duplicate edges
242  std::sort(edge2d[p].begin(), edge2d[p].end(), [](const auto& i, const auto& j) {
243  return (i.n1 != j.n1 ? i.n1 < j.n1 : i.n2 < j.n2);
244  });
245  if (debug) {
246  for (auto& e : edge2d[p]) {
247  std::cout << "sorted plane=" << p << " e1=" << e.n1 << " e2=" << e.n2 << std::endl;
248  }
249  }
250  edge2d[p].erase(std::unique(edge2d[p].begin(), edge2d[p].end()), edge2d[p].end());
251  }
252 
253  if (debug) {
254  for (size_t p = 0; p < planes.size(); p++) {
255  for (auto& e : edge2d[p]) {
256  std::cout << " plane=" << p << " e1=" << e.n1 << " e2=" << e.n2 << std::endl;
257  }
258  }
259  }
260  auto end_preprocess1 = std::chrono::high_resolution_clock::now();
261  std::chrono::duration<double> elapsed_preprocess1 = end_preprocess1 - start_preprocess1;
262 
263  // Nexus edges
264  auto start_preprocess2 = std::chrono::high_resolution_clock::now();
265  vector<vector<Edge>> edge3d(planes.size(), vector<Edge>());
266  for (size_t i = 0; i < spids->size(); ++i) {
267  if (hitids_u->at(i) >= 0) {
268  Edge e;
269  e.n1 = idsmapRev[hitids_u->at(i)];
270  e.n2 = spids->at(i);
271  edge3d[0].push_back(e);
272  }
273  if (hitids_v->at(i) >= 0) {
274  Edge e;
275  e.n1 = idsmapRev[hitids_v->at(i)];
276  e.n2 = spids->at(i);
277  edge3d[1].push_back(e);
278  }
279  if (hitids_y->at(i) >= 0) {
280  Edge e;
281  e.n1 = idsmapRev[hitids_y->at(i)];
282  e.n2 = spids->at(i);
283  edge3d[2].push_back(e);
284  }
285  }
286 
287  // Prepare inputs
288  auto x = torch::Dict<std::string, torch::Tensor>();
289  auto batch = torch::Dict<std::string, torch::Tensor>();
290  for (size_t p = 0; p < planes.size(); p++) {
291  vector<float> nodeft;
292  for (size_t i = 0; i < hit_plane->size(); ++i) {
293  if (size_t(hit_plane->at(i)) != p) continue;
294  nodeft.push_back((hit_wire->at(i) * pos_norm[0] - avgs[hit_plane->at(i)][0]) /
295  devs[hit_plane->at(i)][0]);
296  nodeft.push_back((hit_time->at(i) * pos_norm[1] - avgs[hit_plane->at(i)][1]) /
297  devs[hit_plane->at(i)][1]);
298  nodeft.push_back((hit_integral->at(i) - avgs[hit_plane->at(i)][2]) /
299  devs[hit_plane->at(i)][2]);
300  nodeft.push_back((hit_rms->at(i) - avgs[hit_plane->at(i)][3]) / devs[hit_plane->at(i)][3]);
301  }
302  long int dim = nodeft.size() / 4;
303  torch::Tensor ix = torch::zeros({dim, 4}, torch::dtype(torch::kFloat32));
304  if (debug) {
305  std::cout << "plane=" << p << std::endl;
306  std::cout << std::scientific;
307  for (size_t n = 0; n < nodeft.size(); n = n + 4) {
308  std::cout << nodeft[n] << " " << nodeft[n + 1] << " " << nodeft[n + 2] << " "
309  << nodeft[n + 3] << " " << std::endl;
310  }
311  }
312  for (size_t n = 0; n < nodeft.size(); n = n + 4) {
313  ix[n / 4][0] = nodeft[n];
314  ix[n / 4][1] = nodeft[n + 1];
315  ix[n / 4][2] = nodeft[n + 2];
316  ix[n / 4][3] = nodeft[n + 3];
317  }
318  x.insert(planes[p], ix);
319  torch::Tensor ib = torch::zeros({dim}, torch::dtype(torch::kInt64));
320  batch.insert(planes[p], ib);
321  }
322 
323  auto edge_index_plane = torch::Dict<std::string, torch::Tensor>();
324  for (size_t p = 0; p < planes.size(); p++) {
325  long int dim = edge2d[p].size();
326  torch::Tensor ix = torch::zeros({2, dim}, torch::dtype(torch::kInt64));
327  for (size_t n = 0; n < edge2d[p].size(); n++) {
328  ix[0][n] = int(edge2d[p][n].n1);
329  ix[1][n] = int(edge2d[p][n].n2);
330  }
331  edge_index_plane.insert(planes[p], ix);
332  if (debug) {
333  std::cout << "plane=" << p << std::endl;
334  std::cout << "2d edge size=" << edge2d[p].size() << std::endl;
335  for (size_t n = 0; n < edge2d[p].size(); n++) {
336  std::cout << edge2d[p][n].n1 << " ";
337  }
338  std::cout << std::endl;
339  for (size_t n = 0; n < edge2d[p].size(); n++) {
340  std::cout << edge2d[p][n].n2 << " ";
341  }
342  std::cout << std::endl;
343  }
344  }
345 
346  auto edge_index_nexus = torch::Dict<std::string, torch::Tensor>();
347  for (size_t p = 0; p < planes.size(); p++) {
348  long int dim = edge3d[p].size();
349  torch::Tensor ix = torch::zeros({2, dim}, torch::dtype(torch::kInt64));
350  for (size_t n = 0; n < edge3d[p].size(); n++) {
351  ix[0][n] = int(edge3d[p][n].n1);
352  ix[1][n] = int(edge3d[p][n].n2);
353  }
354  edge_index_nexus.insert(planes[p], ix);
355  if (debug) {
356  std::cout << "plane=" << p << std::endl;
357  std::cout << "3d edge size=" << edge3d[p].size() << std::endl;
358  for (size_t n = 0; n < edge3d[p].size(); n++) {
359  std::cout << edge3d[p][n].n1 << " ";
360  }
361  std::cout << std::endl;
362  for (size_t n = 0; n < edge3d[p].size(); n++) {
363  std::cout << edge3d[p][n].n2 << " ";
364  }
365  std::cout << std::endl;
366  }
367  }
368 
369  long int spdim = spids->size();
370  auto nexus = torch::empty({spdim, 0}, torch::dtype(torch::kFloat32));
371 
372  std::vector<torch::jit::IValue> inputs;
373  inputs.push_back(x);
374  inputs.push_back(edge_index_plane);
375  inputs.push_back(edge_index_nexus);
376  inputs.push_back(nexus);
377  inputs.push_back(batch);
378 
379  // Run inference
380  auto end_preprocess2 = std::chrono::high_resolution_clock::now();
381  std::chrono::duration<double> elapsed_preprocess2 = end_preprocess2 - start_preprocess2;
382  if (debug) std::cout << "FORWARD!" << std::endl;
383  auto start = std::chrono::high_resolution_clock::now();
384  auto outputs = model.forward(inputs).toGenericDict();
385  auto end = std::chrono::high_resolution_clock::now();
386  std::chrono::duration<double> elapsed = end - start;
387  if (debug) {
388  std::cout << "Time taken for inference: "
389  << elapsed_preprocess1.count() + elapsed_preprocess2.count() + elapsed.count()
390  << " seconds" << std::endl;
391  std::cout << "output =" << outputs << std::endl;
392  }
393 
394  //
395  // Get pointers to the result returned and write to the event
396  //
397  vector<NuGraphOutput> infer_output;
398  for (const auto& elem1 : outputs) {
399  if (elem1.value().isTensor()) {
400  torch::Tensor tensor = elem1.value().toTensor();
401  std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
402  infer_output.push_back(NuGraphOutput(elem1.key().to<std::string>(), vec));
403  }
404  else if (elem1.value().isGenericDict()) {
405  for (const auto& elem2 : elem1.value().toGenericDict()) {
406  torch::Tensor tensor = elem2.value().toTensor();
407  std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
408  infer_output.push_back(
409  NuGraphOutput(elem1.key().to<std::string>() + "_" + elem2.key().to<std::string>(), vec));
410  }
411  }
412  }
413 
414  // Write the outputs to the output root file
415  for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
416  _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
417  }
418 }
Float_t x
Definition: compare.C:6
torch::jit::script::Module model
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
vector< std::string > planes
boost::graph_traits< ModuleGraph >::edge_descriptor Edge
Definition: ModuleGraph.h:24
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
Definition: StdUtils.h:77
vector< vector< float > > devs
Float_t d
Definition: plot.C:235
vector< vector< float > > avgs
Char_t n[5]
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
Float_t e
Definition: plot.C:35
bool operator==(infinite_endcount_iterator< T > const &, count_iterator< T > const &)
Definition: counter.h:277
decltype(auto) constexpr empty(T &&obj)
ADL-aware version of std::empty.
Definition: StdUtils.h:109
std::unique_ptr< LoaderToolBase > _loaderTool
void art::Modifier::registerProducts ( ProductDescriptions &  productsToRegister)
inherited

Definition at line 16 of file Modifier.cc.

References art::ModuleBase::moduleDescription(), and art::ProductRegistryHelper::registerProducts().

17  {
18  ProductRegistryHelper::registerProducts(productsToRegister,
20  }
void registerProducts(ProductDescriptions &productsToRegister, ModuleDescription const &md)
ModuleDescription const & moduleDescription() const
Definition: ModuleBase.cc:13
void art::ModuleBase::setModuleDescription ( ModuleDescription const &  md)
inherited

Definition at line 31 of file ModuleBase.cc.

References art::ModuleBase::md_.

32  { // Store a copy of md in the std::optional md_; moduleDescription() returns it once set.
33  md_ = md;
34  }
std::optional< ModuleDescription > md_
Definition: ModuleBase.h:55
void art::ModuleBase::sortConsumables ( std::string const &  current_process_name)
inherited

Definition at line 49 of file ModuleBase.cc.

References art::ModuleBase::collector_, and art::ConsumesCollector::sortConsumables().

50  { // Delegates to ConsumesCollector::sortConsumables(current_process_name).
51  // Now that we know we have seen all the consumes declarations,
52  // sort the results for fast lookup later.
53  collector_.sortConsumables(current_process_name);
54  }
ConsumesCollector collector_
Definition: ModuleBase.h:56
void sortConsumables(std::string const &current_process_name)

Member Data Documentation

std::vector<std::unique_ptr<DecoderToolBase> > NuGraphInference::_decoderToolsVec
private

Definition at line 95 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

std::unique_ptr<LoaderToolBase> NuGraphInference::_loaderTool
private

Definition at line 93 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

vector<vector<float> > NuGraphInference::avgs
private

Definition at line 88 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

bool NuGraphInference::debug
private

Definition at line 87 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

vector<vector<float> > NuGraphInference::devs
private

Definition at line 89 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

size_t NuGraphInference::minHits
private

Definition at line 86 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

torch::jit::script::Module NuGraphInference::model
private

Definition at line 91 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

vector<std::string> NuGraphInference::planes
private

Definition at line 85 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().

vector<float> NuGraphInference::pos_norm
private

Definition at line 90 of file NuGraphInference_module.cc.

Referenced by NuGraphInference(), and produce().


The documentation for this class was generated from the following file: