LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlgKeras_tool.cc
Go to the documentation of this file.
1 // Class: PointIdAlgKeras_tool (tool version of Keras model interface in PointIdAlg)
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 // D.Smith, from LArIAT, BU, 2017: real data dump
7 // M.Wang, from DUNE, FNAL, 2020: tool version
9 
11 
13 #include "larrecodnn/ImagePatternAlgs/ToolInterfaces/IPointIdAlg.h"
14 
15 #include <sys/stat.h>
16 
17 namespace PointIdAlgTools {
18 
19  class PointIdAlgKeras : public IPointIdAlg {
20  public:
21  explicit PointIdAlgKeras(const fhicl::ParameterSet& pset)
22  : PointIdAlgKeras(fhicl::Table<Config>(pset, {})())
23  {}
24  explicit PointIdAlgKeras(const Config& config);
25 
26  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const override;
27  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
28  int samples = -1) const override;
29 
30  private:
31  std::unique_ptr<keras::KerasModel> m;
32  std::string fNNetModelFilePath;
33  std::string findFile(const char* fileName) const;
34  };
35 
36  // ------------------------------------------------------
37  PointIdAlgKeras::PointIdAlgKeras(const Config& config) : img::DataProviderAlg(config)
38  {
39  // ... Get common config vars
40  fNNetOutputs = config.NNetOutputs();
41  fPatchSizeW = config.PatchSizeW();
42  fPatchSizeD = config.PatchSizeD();
43  fCurrentWireIdx = 99999;
44  fCurrentScaledDrift = 99999;
45 
46  // ... Get "optional" config vars specific to tf interface
47  std::string s_cfgvr;
48  if (config.NNetModelFile(s_cfgvr)) { fNNetModelFilePath = s_cfgvr; }
49  else {
50  fNNetModelFilePath = "mycnn";
51  }
52 
53  if ((fNNetModelFilePath.length() > 5) &&
54  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 5, 5, ".nnet") == 0)) {
55  m = std::make_unique<keras::KerasModel>(findFile(fNNetModelFilePath.c_str()).c_str());
56  mf::LogInfo("PointIdAlgKeras") << "Keras model loaded.";
57  }
58  else {
59  mf::LogError("PointIdAlgKeras") << "File name extension not supported.";
60  }
61 
62  resizePatch();
63  }
64 
65  // ------------------------------------------------------
66  std::string PointIdAlgKeras::findFile(const char* fileName) const
67  {
68  std::string fname_out;
69  cet::search_path sp("FW_SEARCH_PATH");
70  if (!sp.find_file(fileName, fname_out)) {
71  struct stat buffer;
72  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
73  else {
74  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
75  }
76  }
77  return fname_out;
78  }
79 
80  // ------------------------------------------------------
81  std::vector<float> PointIdAlgKeras::Run(std::vector<std::vector<float>> const& inp2d) const
82  {
83  std::vector<std::vector<std::vector<float>>> inp3d;
84  inp3d.push_back(inp2d); // lots of copy, should add 2D to keras...
85 
86  keras::DataChunk2D sample;
87  sample.set_data(inp3d);
88  return m->compute_output(&sample);
89  }
90 
91  // ------------------------------------------------------
92  std::vector<std::vector<float>> PointIdAlgKeras::Run(
93  std::vector<std::vector<std::vector<float>>> const& inps,
94  int samples) const
95  {
96 
97  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
98  return std::vector<std::vector<float>>();
99  }
100 
101  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
102 
103  std::vector<std::vector<float>> out;
104 
105  for (long long int s = 0; s < samples; ++s) {
106  std::vector<std::vector<std::vector<float>>> inp3d;
107  inp3d.push_back(inps[s]); // lots of copy, should add 2D to keras...
108 
109  keras::DataChunk* sample = new keras::DataChunk2D();
110  sample->set_data(inp3d); // and more copy...
111  out.push_back(m->compute_output(sample));
112  delete sample;
113  }
114 
115  return out;
116  }
117 
118 }
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
virtual void set_data(std::vector< std::vector< std::vector< float >>> const &)
Definition: keras_model.h:53
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::string findFile(const char *fileName) const
std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const override
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
parameter set interface
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
virtual void set_data(std::vector< std::vector< std::vector< float >>> const &d)
Definition: keras_model.h:70
std::unique_ptr< keras::KerasModel > m
PointIdAlgKeras(const fhicl::ParameterSet &pset)
map< int, array< map< int, double >, 2 >> Table
Definition: plot.C:18