LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlg.h
Go to the documentation of this file.
1 // Class: PointIdAlg
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 //
7 //
8 // Point Identification Algorithm
9 //
10 // Run CNN or MLP trained to classify a point in 2D projection. Various features can be
11 // recognized, depending on the net model/weights used.
12 //
14 
15 #ifndef PointIdAlg_h
16 #define PointIdAlg_h
17 
18 // LArSoft includes
19 namespace detinfo {
20  class DetectorClocksData;
21  class DetectorPropertiesData;
22 }
23 
28 
29 // Framework includes
30 namespace art {
31  class Event;
32 }
33 namespace fhicl {
34  class ParameterSet;
35 }
37 #include "fhiclcpp/types/Atom.h"
38 #include "fhiclcpp/types/Comment.h"
39 #include "fhiclcpp/types/Name.h"
41 #include "fhiclcpp/types/Table.h"
42 
43 // ROOT & C++
44 #include <memory>
45 #include <string>
46 #include <unordered_map>
47 #include <vector>
48 
49 namespace nnet {
50  class ModelInterface;
51  class KerasModelInterface;
52  class TfModelInterface;
53  class PointIdAlg;
54  class TrainingDataAlg;
55 }
56 
61 public:
62  virtual ~ModelInterface() {}
63 
64  virtual std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) = 0;
65  virtual std::vector<std::vector<float>> Run(
66  std::vector<std::vector<std::vector<float>>> const& inps,
67  int samples = -1);
68 
69 protected:
70  std::string findFile(const char* fileName) const;
71 };
72 // ------------------------------------------------------
73 
75 public:
76  KerasModelInterface(const char* modelFileName);
77 
78  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) override;
79 
80 private:
81  keras::KerasModel m; // network model
82 };
83 // ------------------------------------------------------
84 
86 public:
87  TfModelInterface(const char* modelFileName);
88 
89  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
90  int samples = -1) override;
91  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) override;
92 
93 private:
94  std::unique_ptr<tf::Graph> g; // network graph
95 };
96 // ------------------------------------------------------
97 
99 public:
101  using Name = fhicl::Name;
103 
104  fhicl::Atom<std::string> NNetModelFile{Name("NNetModelFile"),
105  Comment("Neural net model to apply.")};
106  fhicl::Sequence<std::string> NNetOutputs{Name("NNetOutputs"),
107  Comment("Labels of the network outputs.")};
108  fhicl::Atom<unsigned int> PatchSizeW{Name("PatchSizeW"), Comment("How many wires in patch.")};
109 
110  fhicl::Atom<unsigned int> PatchSizeD{Name("PatchSizeD"),
111  Comment("How many downsampled ADC entries in patch")};
112  };
113 
114  PointIdAlg(const fhicl::ParameterSet& pset) : PointIdAlg(fhicl::Table<Config>(pset, {})()) {}
115 
116  PointIdAlg(const Config& config);
117 
118  ~PointIdAlg() override;
119 
121  std::vector<std::string> const& outputLabels() const { return fNNetOutputs; }
122 
124  float predictIdValue(unsigned int wire, float drift, size_t outIdx = 0) const;
125 
127  std::vector<float> predictIdVector(unsigned int wire, float drift) const;
128 
129  std::vector<std::vector<float>> predictIdVectors(
130  std::vector<std::pair<unsigned int, float>> points) const;
131 
132  static std::vector<float> flattenData2D(std::vector<std::vector<float>> const& patch);
133 
134  std::vector<std::vector<float>> const& patchData2D() const { return fWireDriftPatch; }
135  std::vector<float> patchData1D() const
136  {
137  return flattenData2D(fWireDriftPatch);
138  } // flat vector made of the patch data, wire after wire
139 
140  bool isInsideFiducialRegion(unsigned int wire, float drift) const;
141 
144  bool isCurrentPatch(unsigned int wire, float drift) const;
145 
147  bool isSamePatch(unsigned int wire1, float drift1, unsigned int wire2, float drift2) const;
148 
149 private:
150  std::string fNNetModelFilePath;
151  std::vector<std::string> fNNetOutputs;
153 
154  mutable std::vector<std::vector<float>> fWireDriftPatch; // patch data around the identified point
155  size_t fPatchSizeW, fPatchSizeD;
156 
157  mutable size_t fCurrentWireIdx, fCurrentScaledDrift;
158  bool bufferPatch(size_t wire, float drift, std::vector<std::vector<float>>& patch) const
159  {
160  if (fDownscaleFullView) {
161  size_t sd = (size_t)(drift / fDriftWindow);
162  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == sd))
163  return true; // still within the current position
164 
165  fCurrentWireIdx = wire;
166  fCurrentScaledDrift = sd;
167 
168  return patchFromDownsampledView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
169  }
170  else {
171  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == drift))
172  return true; // still within the current position
173 
174  fCurrentWireIdx = wire;
175  fCurrentScaledDrift = drift;
176 
177  return patchFromOriginalView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
178  }
179  }
180  bool bufferPatch(size_t wire, float drift) const
181  {
182  return bufferPatch(wire, drift, fWireDriftPatch);
183  }
184  void resizePatch();
185 
186  void deleteNNet()
187  {
188  if (fNNet) delete fNNet;
189  fNNet = 0;
190  }
191 };
192 // ------------------------------------------------------
193 // ------------------------------------------------------
194 // ------------------------------------------------------
195 
197 public:
/// Bit fields of a PDG-map entry: the low 12 bits hold the PDG code, bits 12-15 the track
/// type (every ETrkType value lies inside kTypeMask), and the upper 16 bits the vertex
/// flags (every EVtxId value lies inside kVtxMask).
enum EMask {
  kNone = 0,
  kPdgMask = 0x00000FFF,  // pdg code mask
  kTypeMask = 0x0000F000, // track type mask
  kVtxMask = 0xFFFF0000   // vertex flags
};

/// Track-type bits (within kTypeMask).
enum ETrkType {
  kDelta = 0x00001000,  // delta electron
  kMichel = 0x00002000, // Michel electron
  kPriEl = 0x00004000,  // primary electron
  kPriMu = 0x00008000   // primary muon
};

/// Vertex-flag bits (within kVtxMask).
enum EVtxId {
  // nu interaction type
  kNuNC = 0x00010000,
  kNuCC = 0x00020000,
  kNuPri = 0x00040000,

  // nu flavor
  kNuE = 0x00100000,
  kNuMu = 0x00200000,
  kNuTau = 0x00400000,

  kHadr = 0x01000000,        // hadronic inelastic scattering
  kPi0 = 0x02000000,         // pi0 produced in this vertex
  kDecay = 0x04000000,       // point of particle decay
  kConv = 0x08000000,        // gamma conversion
  kElectronEnd = 0x10000000, // clear end of an electron
  kElastic = 0x20000000,     // Elastic scattering
  kInelastic = 0x40000000    // Inelastic scattering
};
227 
229  using Name = fhicl::Name;
231 
232  fhicl::Atom<art::InputTag> WireLabel{Name("WireLabel"), Comment("Tag of recob::Wire.")};
233 
234  fhicl::Atom<art::InputTag> HitLabel{Name("HitLabel"), Comment("Tag of recob::Hit.")};
235 
236  fhicl::Atom<art::InputTag> TrackLabel{Name("TrackLabel"), Comment("Tag of recob::Track.")};
237 
238  fhicl::Atom<art::InputTag> SimulationLabel{Name("SimulationLabel"),
239  Comment("Tag of simulation producer.")};
240 
241  fhicl::Atom<art::InputTag> SimChannelLabel{Name("SimChannelLabel"),
242  Comment("Tag of sim::SimChannel producer.")};
243 
244  fhicl::Atom<bool> SaveVtxFlags{Name("SaveVtxFlags"),
245  Comment("Include (or not) vertex info in PDG map.")};
246 
248  Name("AdcDelayTicks"),
249  Comment("ADC pulse peak delay in ticks (non-zero for not deconvoluted waveforms).")};
250  };
251 
253  : TrainingDataAlg(fhicl::Table<Config>(pset, {})())
254  {}
255 
256  TrainingDataAlg(const Config& config);
257 
258  ~TrainingDataAlg() override;
259 
260  void reconfigure(const Config& config);
261 
262  bool saveSimInfo() const { return fSaveSimInfo; }
263 
264  bool setEventData(
265  const art::Event& event, // collect & downscale ADC's, charge deposits, pdg labels
266  detinfo::DetectorClocksData const& clockData,
267  detinfo::DetectorPropertiesData const& detProp,
268  unsigned int plane,
269  unsigned int tpc,
270  unsigned int cryo);
271 
272  bool setDataEventData(
273  const art::Event& event, // collect & downscale ADC's, charge deposits, pdg labels
274  detinfo::DetectorClocksData const& clockData,
275  detinfo::DetectorPropertiesData const& detProp,
276  unsigned int plane,
277  unsigned int tpc,
278  unsigned int cryo);
279 
280  bool findCrop(float max_e_cut,
281  unsigned int& w0,
282  unsigned int& w1,
283  unsigned int& d0,
284  unsigned int& d1) const;
285 
286  double getEdepTot() const { return fEdepTot; } // [GeV]
287  std::vector<float> const& wireEdep(size_t widx) const { return fWireDriftEdep[widx]; }
288  std::vector<int> const& wirePdg(size_t widx) const { return fWireDriftPdg[widx]; }
289 
290 protected:
291  img::DataProviderAlgView resizeView(detinfo::DetectorClocksData const& clock_data,
292  detinfo::DetectorPropertiesData const& det_prop,
293  size_t wires,
294  size_t drifts) override;
295 
296 private:
/// Result of getProjection(): the wire/drift position of a projected point, together with
/// the TPC and cryostat indices. Used to find MCParticle start/end 2D projections.
/// Every member is value-initialised, so a default-constructed WireDrift never carries
/// indeterminate values. The struct stays an aggregate, so brace-initialisation still works.
struct WireDrift
{
  size_t Wire{};       // wire index
  int Drift{};         // drift position (signed)
  unsigned int TPC{};  // TPC index
  unsigned int Cryo{}; // cryostat index
};
304 
305  WireDrift getProjection(detinfo::DetectorClocksData const& clockData,
306  detinfo::DetectorPropertiesData const& detProp,
307  const TLorentzVector& tvec,
308  unsigned int plane) const;
309 
310  bool setWireEdepsAndLabels(std::vector<float> const& edeps,
311  std::vector<int> const& pdgs,
312  size_t wireIdx);
313 
314  void collectVtxFlags(
315  std::unordered_map<size_t, std::unordered_map<int, int>>& wireToDriftToVtxFlags,
316  detinfo::DetectorClocksData const& clockData,
317  detinfo::DetectorPropertiesData const& detProp,
318  const std::unordered_map<int, const simb::MCParticle*>& particleMap,
319  unsigned int plane) const;
320 
321  static float particleRange2(const simb::MCParticle& particle)
322  {
323  float dx = particle.EndX() - particle.Vx();
324  float dy = particle.EndY() - particle.Vy();
325  float dz = particle.EndZ() - particle.Vz();
326  return dx * dx + dy * dy + dz * dz;
327  }
328  bool isElectronEnd(const simb::MCParticle& particle,
329  const std::unordered_map<int, const simb::MCParticle*>& particleMap) const;
330 
331  bool isMuonDecaying(const simb::MCParticle& particle,
332  const std::unordered_map<int, const simb::MCParticle*>& particleMap) const;
333 
334  double fEdepTot; // [GeV]
335  std::vector<std::vector<float>> fWireDriftEdep;
336  std::vector<std::vector<int>> fWireDriftPdg;
337 
345 
346  unsigned int fAdcDelay;
347 
348  std::vector<size_t> fEventsPerBin;
349 };
350 // ------------------------------------------------------
351 // ------------------------------------------------------
352 // ------------------------------------------------------
353 
354 #endif
std::vector< std::string > const & outputLabels() const
network output labels
Definition: PointIdAlg.h:121
std::vector< float > const & wireEdep(size_t widx) const
Definition: PointIdAlg.h:287
double EndZ() const
Definition: MCParticle.h:229
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:154
bool saveSimInfo() const
Definition: PointIdAlg.h:262
unsigned int fAdcDelay
Definition: PointIdAlg.h:346
bool bufferPatch(size_t wire, float drift) const
Definition: PointIdAlg.h:180
art::InputTag fTrackModuleLabel
Definition: PointIdAlg.h:340
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float >> &patch) const
Definition: PointIdAlg.h:158
static const int kNuTau
Particle class.
double EndY() const
Definition: MCParticle.h:228
no compression
Definition: RawTypes.h:9
double getEdepTot() const
Definition: PointIdAlg.h:286
std::vector< std::string > fNNetOutputs
Definition: PointIdAlg.h:151
std::vector< float > patchData1D() const
Definition: PointIdAlg.h:135
virtual ~ModelInterface()
Definition: PointIdAlg.h:62
art::InputTag fWireProducerLabel
Definition: PointIdAlg.h:338
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
std::vector< std::vector< float > > const & patchData2D() const
Definition: PointIdAlg.h:134
parameter set interface
Definition: EmTrack.h:40
std::vector< size_t > fEventsPerBin
Definition: PointIdAlg.h:348
art::InputTag fSimChannelProducerLabel
Definition: PointIdAlg.h:342
General LArSoft Utilities.
keras::KerasModel m
Definition: PointIdAlg.h:81
std::vector< std::vector< int > > fWireDriftPdg
Definition: PointIdAlg.h:336
double Vx(const int i=0) const
Definition: MCParticle.h:222
std::vector< std::vector< float > > fWireDriftEdep
Definition: PointIdAlg.h:335
Contains all timing reference information for the detector.
static float particleRange2(const simb::MCParticle &particle)
Definition: PointIdAlg.h:321
double Vz(const int i=0) const
Definition: MCParticle.h:224
Definition: MVAAlg.h:12
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:152
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:94
size_t fPatchSizeW
Definition: PointIdAlg.h:155
art::InputTag fSimulationProducerLabel
Definition: PointIdAlg.h:341
std::string fNNetModelFilePath
Definition: PointIdAlg.h:150
std::vector< int > const & wirePdg(size_t widx) const
Definition: PointIdAlg.h:288
double EndX() const
Definition: MCParticle.h:227
art::InputTag fHitProducerLabel
Definition: PointIdAlg.h:339
double Vy(const int i=0) const
Definition: MCParticle.h:223
static const int kNuMu
Event finding and building.
TrainingDataAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:252
size_t fCurrentWireIdx
Definition: PointIdAlg.h:157
map< int, array< map< int, double >, 2 >> Table
Definition: plot.C:18
PointIdAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:114