LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlg.h
Go to the documentation of this file.
1 // Class: PointIdAlg
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 //
7 //
8 // Point Identification Algorithm
9 //
10 // Run CNN or MLP trained to classify a point in 2D projection. Various features can be
11 // recognized, depending on the net model/weights used.
12 //
14 
15 #ifndef PointIdAlg_h
16 #define PointIdAlg_h
17 
18 // LArSoft includes
19 namespace detinfo {
20  class DetectorClocksData;
21  class DetectorPropertiesData;
22 }
23 
28 
29 // Framework includes
32 #include "fhiclcpp/fwd.h"
33 #include "fhiclcpp/types/Atom.h"
34 #include "fhiclcpp/types/Comment.h"
35 #include "fhiclcpp/types/Name.h"
37 #include "fhiclcpp/types/Table.h"
38 
39 // ROOT & C++
40 #include <memory>
41 #include <string>
42 #include <unordered_map>
43 #include <vector>
44 
45 namespace nnet {
46  class ModelInterface;
47  class KerasModelInterface;
48  class TfModelInterface;
49  class PointIdAlg;
50  class TrainingDataAlg;
51 }
52 
57 public:
58  virtual ~ModelInterface() {}
59 
60  virtual std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) = 0;
61  virtual std::vector<std::vector<float>> Run(
62  std::vector<std::vector<std::vector<float>>> const& inps,
63  int samples = -1);
64 
65 protected:
66  std::string findFile(const char* fileName) const;
67 };
68 // ------------------------------------------------------
69 
71 public:
72  KerasModelInterface(const char* modelFileName);
73 
74  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) override;
75 
76 private:
77  keras::KerasModel m; // network model
78 };
79 // ------------------------------------------------------
80 
82 public:
83  TfModelInterface(const char* modelFileName);
84 
85  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
86  int samples = -1) override;
87  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) override;
88 
89 private:
90  std::unique_ptr<tf::Graph> g; // network graph
91 };
92 // ------------------------------------------------------
93 
95 public:
97  using Name = fhicl::Name;
99 
100  fhicl::Atom<std::string> NNetModelFile{Name("NNetModelFile"),
101  Comment("Neural net model to apply.")};
102  fhicl::Sequence<std::string> NNetOutputs{Name("NNetOutputs"),
103  Comment("Labels of the network outputs.")};
104  fhicl::Atom<unsigned int> PatchSizeW{Name("PatchSizeW"), Comment("How many wires in patch.")};
105 
106  fhicl::Atom<unsigned int> PatchSizeD{Name("PatchSizeD"),
107  Comment("How many downsampled ADC entries in patch")};
108  };
109 
110  PointIdAlg(const fhicl::ParameterSet& pset) : PointIdAlg(fhicl::Table<Config>(pset, {})()) {}
111 
112  PointIdAlg(const Config& config);
113 
114  ~PointIdAlg() override;
115 
117  std::vector<std::string> const& outputLabels() const { return fNNetOutputs; }
118 
120  float predictIdValue(unsigned int wire, float drift, size_t outIdx = 0) const;
121 
123  std::vector<float> predictIdVector(unsigned int wire, float drift) const;
124 
125  std::vector<std::vector<float>> predictIdVectors(
126  std::vector<std::pair<unsigned int, float>> points) const;
127 
128  static std::vector<float> flattenData2D(std::vector<std::vector<float>> const& patch);
129 
130  std::vector<std::vector<float>> const& patchData2D() const { return fWireDriftPatch; }
131  std::vector<float> patchData1D() const
132  {
133  return flattenData2D(fWireDriftPatch);
134  } // flat vector made of the patch data, wire after wire
135 
136  bool isInsideFiducialRegion(unsigned int wire, float drift) const;
137 
140  bool isCurrentPatch(unsigned int wire, float drift) const;
141 
143  bool isSamePatch(unsigned int wire1, float drift1, unsigned int wire2, float drift2) const;
144 
145 private:
146  std::string fNNetModelFilePath;
147  std::vector<std::string> fNNetOutputs;
149 
150  mutable std::vector<std::vector<float>> fWireDriftPatch; // patch data around the identified point
151  size_t fPatchSizeW, fPatchSizeD;
152 
153  mutable size_t fCurrentWireIdx, fCurrentScaledDrift;
154  bool bufferPatch(size_t wire, float drift, std::vector<std::vector<float>>& patch) const
155  {
156  if (fDownscaleFullView) {
157  size_t sd = (size_t)(drift / fDriftWindow);
158  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == sd))
159  return true; // still within the current position
160 
161  fCurrentWireIdx = wire;
162  fCurrentScaledDrift = sd;
163 
164  return patchFromDownsampledView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
165  }
166  else {
167  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == drift))
168  return true; // still within the current position
169 
170  fCurrentWireIdx = wire;
171  fCurrentScaledDrift = drift;
172 
173  return patchFromOriginalView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
174  }
175  }
176  bool bufferPatch(size_t wire, float drift) const
177  {
178  return bufferPatch(wire, drift, fWireDriftPatch);
179  }
180  void resizePatch();
181 
182  void deleteNNet()
183  {
184  if (fNNet) delete fNNet;
185  fNNet = 0;
186  }
187 };
188 // ------------------------------------------------------
189 // ------------------------------------------------------
190 // ------------------------------------------------------
191 
193 public:
194  enum EMask {
195  kNone = 0,
196  kPdgMask = 0x00000FFF, // pdg code mask
197  kTypeMask = 0x0000F000, // track type mask
198  kVtxMask = 0xFFFF0000 // vertex flags
199  };
200 
201  enum ETrkType {
202  kDelta = 0x1000, // delta electron
203  kMichel = 0x2000, // Michel electron
204  kPriEl = 0x4000, // primary electron
205  kPriMu = 0x8000 // primary muon
206  };
207 
208  enum EVtxId {
209  kNuNC = 0x0010000,
210  kNuCC = 0x0020000,
211  kNuPri = 0x0040000, // nu interaction type
212  kNuE = 0x0100000,
213  kNuMu = 0x0200000,
214  kNuTau = 0x0400000, // nu flavor
215  kHadr = 0x1000000, // hadronic inelastic scattering
216  kPi0 = 0x2000000, // pi0 produced in this vertex
217  kDecay = 0x4000000, // point of particle decay
218  kConv = 0x8000000, // gamma conversion
219  kElectronEnd = 0x10000000, // clear end of an electron
220  kElastic = 0x20000000, // Elastic scattering
221  kInelastic = 0x40000000 // Inelastic scattering
222  };
223 
225  using Name = fhicl::Name;
227 
228  fhicl::Atom<art::InputTag> WireLabel{Name("WireLabel"), Comment("Tag of recob::Wire.")};
229 
230  fhicl::Atom<art::InputTag> HitLabel{Name("HitLabel"), Comment("Tag of recob::Hit.")};
231 
232  fhicl::Atom<art::InputTag> TrackLabel{Name("TrackLabel"), Comment("Tag of recob::Track.")};
233 
234  fhicl::Atom<art::InputTag> SimulationLabel{Name("SimulationLabel"),
235  Comment("Tag of simulation producer.")};
236 
237  fhicl::Atom<art::InputTag> SimChannelLabel{Name("SimChannelLabel"),
238  Comment("Tag of sim::SimChannel producer.")};
239 
240  fhicl::Atom<bool> SaveVtxFlags{Name("SaveVtxFlags"),
241  Comment("Include (or not) vertex info in PDG map.")};
242 
244  Name("AdcDelayTicks"),
245  Comment("ADC pulse peak delay in ticks (non-zero for not deconvoluted waveforms).")};
246  };
247 
249  : TrainingDataAlg(fhicl::Table<Config>(pset, {})())
250  {}
251 
252  TrainingDataAlg(const Config& config);
253 
254  ~TrainingDataAlg() override;
255 
256  void reconfigure(const Config& config);
257 
258  bool saveSimInfo() const { return fSaveSimInfo; }
259 
260  bool setEventData(
261  const art::Event& event, // collect & downscale ADC's, charge deposits, pdg labels
262  detinfo::DetectorClocksData const& clockData,
263  detinfo::DetectorPropertiesData const& detProp,
264  unsigned int plane,
265  unsigned int tpc,
266  unsigned int cryo);
267 
268  bool setDataEventData(
269  const art::Event& event, // collect & downscale ADC's, charge deposits, pdg labels
270  detinfo::DetectorClocksData const& clockData,
271  detinfo::DetectorPropertiesData const& detProp,
272  unsigned int plane,
273  unsigned int tpc,
274  unsigned int cryo);
275 
276  bool findCrop(float max_e_cut,
277  unsigned int& w0,
278  unsigned int& w1,
279  unsigned int& d0,
280  unsigned int& d1) const;
281 
282  double getEdepTot() const { return fEdepTot; } // [GeV]
283  std::vector<float> const& wireEdep(size_t widx) const { return fWireDriftEdep[widx]; }
284  std::vector<int> const& wirePdg(size_t widx) const { return fWireDriftPdg[widx]; }
285 
286 protected:
287  img::DataProviderAlgView resizeView(detinfo::DetectorClocksData const& clock_data,
288  detinfo::DetectorPropertiesData const& det_prop,
289  size_t wires,
290  size_t drifts) override;
291 
292 private:
293  struct WireDrift // used to find MCParticle start/end 2D projections
294  {
295  size_t Wire;
296  int Drift;
297  unsigned int TPC;
298  unsigned int Cryo;
299  };
300 
301  WireDrift getProjection(detinfo::DetectorClocksData const& clockData,
302  detinfo::DetectorPropertiesData const& detProp,
303  const TLorentzVector& tvec,
304  unsigned int plane) const;
305 
306  bool setWireEdepsAndLabels(std::vector<float> const& edeps,
307  std::vector<int> const& pdgs,
308  size_t wireIdx);
309 
310  void collectVtxFlags(
311  std::unordered_map<size_t, std::unordered_map<int, int>>& wireToDriftToVtxFlags,
312  detinfo::DetectorClocksData const& clockData,
313  detinfo::DetectorPropertiesData const& detProp,
314  const std::unordered_map<int, const simb::MCParticle*>& particleMap,
315  unsigned int plane) const;
316 
317  static float particleRange2(const simb::MCParticle& particle)
318  {
319  float dx = particle.EndX() - particle.Vx();
320  float dy = particle.EndY() - particle.Vy();
321  float dz = particle.EndZ() - particle.Vz();
322  return dx * dx + dy * dy + dz * dz;
323  }
324  bool isElectronEnd(const simb::MCParticle& particle,
325  const std::unordered_map<int, const simb::MCParticle*>& particleMap) const;
326 
327  bool isMuonDecaying(const simb::MCParticle& particle,
328  const std::unordered_map<int, const simb::MCParticle*>& particleMap) const;
329 
330  double fEdepTot; // [GeV]
331  std::vector<std::vector<float>> fWireDriftEdep;
332  std::vector<std::vector<int>> fWireDriftPdg;
333 
341 
342  unsigned int fAdcDelay;
343 
344  std::vector<size_t> fEventsPerBin;
345 };
346 // ------------------------------------------------------
347 // ------------------------------------------------------
348 // ------------------------------------------------------
349 
350 #endif
std::vector< std::string > const & outputLabels() const
network output labels
Definition: PointIdAlg.h:117
std::vector< float > const & wireEdep(size_t widx) const
Definition: PointIdAlg.h:283
double EndZ() const
Definition: MCParticle.h:229
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:150
bool saveSimInfo() const
Definition: PointIdAlg.h:258
unsigned int fAdcDelay
Definition: PointIdAlg.h:342
bool bufferPatch(size_t wire, float drift) const
Definition: PointIdAlg.h:176
art::InputTag fTrackModuleLabel
Definition: PointIdAlg.h:336
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float >> &patch) const
Definition: PointIdAlg.h:154
static const int kNuTau
Particle class.
double EndY() const
Definition: MCParticle.h:228
no compression
Definition: RawTypes.h:9
double getEdepTot() const
Definition: PointIdAlg.h:282
std::vector< std::string > fNNetOutputs
Definition: PointIdAlg.h:147
std::vector< float > patchData1D() const
Definition: PointIdAlg.h:131
virtual ~ModelInterface()
Definition: PointIdAlg.h:58
art::InputTag fWireProducerLabel
Definition: PointIdAlg.h:334
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
std::vector< std::vector< float > > const & patchData2D() const
Definition: PointIdAlg.h:130
parameter set interface
Definition: EmTrack.h:40
std::vector< size_t > fEventsPerBin
Definition: PointIdAlg.h:344
art::InputTag fSimChannelProducerLabel
Definition: PointIdAlg.h:338
General LArSoft Utilities.
keras::KerasModel m
Definition: PointIdAlg.h:77
std::vector< std::vector< int > > fWireDriftPdg
Definition: PointIdAlg.h:332
double Vx(const int i=0) const
Definition: MCParticle.h:222
std::vector< std::vector< float > > fWireDriftEdep
Definition: PointIdAlg.h:331
Contains all timing reference information for the detector.
static float particleRange2(const simb::MCParticle &particle)
Definition: PointIdAlg.h:317
double Vz(const int i=0) const
Definition: MCParticle.h:224
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:148
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:90
size_t fPatchSizeW
Definition: PointIdAlg.h:151
art::InputTag fSimulationProducerLabel
Definition: PointIdAlg.h:337
std::string fNNetModelFilePath
Definition: PointIdAlg.h:146
std::vector< int > const & wirePdg(size_t widx) const
Definition: PointIdAlg.h:284
double EndX() const
Definition: MCParticle.h:227
art::InputTag fHitProducerLabel
Definition: PointIdAlg.h:335
double Vy(const int i=0) const
Definition: MCParticle.h:223
static const int kNuMu
Event finding and building.
TrainingDataAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:248
size_t fCurrentWireIdx
Definition: PointIdAlg.h:153
map< int, array< map< int, double >, 2 >> Table
Definition: plot.C:18
PointIdAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:110