LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlgTools::PointIdAlgKeras Class Reference
Inheritance diagram for PointIdAlgTools::PointIdAlgKeras: (diagram image not included in this text export)

Public Member Functions

 PointIdAlgKeras (const fhicl::ParameterSet &pset)
 
 PointIdAlgKeras (const Config &config)
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) const override
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) const override
 

Private Member Functions

std::string findFile (const char *fileName) const
 

Private Attributes

std::unique_ptr< keras::KerasModel > m
 
std::string fNNetModelFilePath
 

Detailed Description

Definition at line 19 of file PointIdAlgKeras_tool.cc.

Constructor & Destructor Documentation

PointIdAlgTools::PointIdAlgKeras::PointIdAlgKeras ( const fhicl::ParameterSet &  pset)
inline explicit

Definition at line 21 of file PointIdAlgKeras_tool.cc.

References Run(), and lar::dump::vector().

23  {}
PointIdAlgKeras(const fhicl::ParameterSet &pset)
PointIdAlgTools::PointIdAlgKeras::PointIdAlgKeras ( const Config &  config)
explicit

Definition at line 37 of file PointIdAlgKeras_tool.cc.

References findFile(), fNNetModelFilePath, and m.

37  : img::DataProviderAlg(config)
38  {
39  // ... Get common config vars
40  fNNetOutputs = config.NNetOutputs();
41  fPatchSizeW = config.PatchSizeW();
42  fPatchSizeD = config.PatchSizeD();
43  fCurrentWireIdx = 99999;
44  fCurrentScaledDrift = 99999;
45 
46  // ... Get "optional" config vars specific to tf interface
47  std::string s_cfgvr;
48  if (config.NNetModelFile(s_cfgvr)) { fNNetModelFilePath = s_cfgvr; }
49  else {
50  fNNetModelFilePath = "mycnn";
51  }
52 
53  if ((fNNetModelFilePath.length() > 5) &&
54  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 5, 5, ".nnet") == 0)) {
55  m = std::make_unique<keras::KerasModel>(findFile(fNNetModelFilePath.c_str()).c_str());
56  mf::LogInfo("PointIdAlgKeras") << "Keras model loaded.";
57  }
58  else {
59  mf::LogError("PointIdAlgKeras") << "File name extension not supported.";
60  }
61 
62  resizePatch();
63  }
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::string findFile(const char *fileName) const
std::unique_ptr< keras::KerasModel > m

Member Function Documentation

std::string PointIdAlgTools::PointIdAlgKeras::findFile ( const char *  fileName) const
private

Definition at line 66 of file PointIdAlgKeras_tool.cc.

References art::errors::NotFound.

Referenced by PointIdAlgKeras().

67  {
68  std::string fname_out;
69  cet::search_path sp("FW_SEARCH_PATH");
70  if (!sp.find_file(fileName, fname_out)) {
71  struct stat buffer;
72  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
73  else {
74  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
75  }
76  }
77  return fname_out;
78  }
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::vector< float > PointIdAlgTools::PointIdAlgKeras::Run ( std::vector< std::vector< float >> const &  inp2d) const
override

Definition at line 81 of file PointIdAlgKeras_tool.cc.

References m, and keras::DataChunk2D::set_data().

Referenced by PointIdAlgKeras().

82  {
83  std::vector<std::vector<std::vector<float>>> inp3d;
84  inp3d.push_back(inp2d); // lots of copy, should add 2D to keras...
85 
86  keras::DataChunk2D sample;
87  sample.set_data(inp3d);
88  return m->compute_output(&sample);
89  }
virtual void set_data(std::vector< std::vector< std::vector< float >>> const &d)
Definition: keras_model.h:70
std::unique_ptr< keras::KerasModel > m
std::vector< std::vector< float > > PointIdAlgTools::PointIdAlgKeras::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
) const
override

Definition at line 92 of file PointIdAlgKeras_tool.cc.

References DEFINE_ART_CLASS_TOOL, m, and keras::DataChunk::set_data().

95  {
96 
97  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
98  return std::vector<std::vector<float>>();
99  }
100 
101  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
102 
103  std::vector<std::vector<float>> out;
104 
105  for (long long int s = 0; s < samples; ++s) {
106  std::vector<std::vector<std::vector<float>>> inp3d;
107  inp3d.push_back(inps[s]); // lots of copy, should add 2D to keras...
108 
109  keras::DataChunk* sample = new keras::DataChunk2D();
110  sample->set_data(inp3d); // and more copy...
111  out.push_back(m->compute_output(sample));
112  delete sample;
113  }
114 
115  return out;
116  }
virtual void set_data(std::vector< std::vector< std::vector< float >>> const &)
Definition: keras_model.h:53
std::unique_ptr< keras::KerasModel > m

Member Data Documentation

std::string PointIdAlgTools::PointIdAlgKeras::fNNetModelFilePath
private

Definition at line 32 of file PointIdAlgKeras_tool.cc.

Referenced by PointIdAlgKeras().

std::unique_ptr<keras::KerasModel> PointIdAlgTools::PointIdAlgKeras::m
private

Definition at line 31 of file PointIdAlgKeras_tool.cc.

Referenced by PointIdAlgKeras(), and Run().


The documentation for this class was generated from the following file: PointIdAlgKeras_tool.cc