LArSoft v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
nnet::KerasModelInterface Class Reference

#include "PointIdAlg.h"

nnet::KerasModelInterface inherits from nnet::ModelInterface.

Public Member Functions

 KerasModelInterface (const char *modelFileName)
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) override
 
virtual std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) (inherited from nnet::ModelInterface)
 

Protected Member Functions

std::string findFile (const char *fileName) const (inherited from nnet::ModelInterface)
 

Private Attributes

keras::KerasModel m
 

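Detailed Description

Wraps a keras::KerasModel and evaluates it through the nnet::ModelInterface API: the constructor loads the model from a file located with findFile(), and Run() passes each 2D input to the model as a single-channel 3D data chunk.

Definition at line 74 of file PointIdAlg.h.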

Constructor & Destructor Documentation

nnet::KerasModelInterface::KerasModelInterface(const char* modelFileName)

Definition at line 85 of file PointIdAlg.cxx.

  : m(nnet::ModelInterface::findFile(modelFileName).c_str())
{
  mf::LogInfo("KerasModelInterface") << "Keras model loaded.";
}
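The initializer passes findFile(modelFileName).c_str(), a pointer into a temporary std::string, to the keras::KerasModel constructor. This is safe because a temporary lives until the end of its full-expression, here the initialization of m, provided the model constructor copies or consumes the name rather than keeping the pointer (an assumption, not checked against keras_model.cc). A minimal sketch of the pattern, where ModelSketch, resolve() and InterfaceSketch are hypothetical stand-ins for keras::KerasModel, findFile() and the real class:

```cpp
#include <cassert>
#include <string>

// Hypothetical model type: like a file-loading constructor, it uses the name
// only during construction and keeps its own copy.
class ModelSketch {
public:
  explicit ModelSketch(const char* fileName) : loadedFrom_(fileName) {}
  const std::string& loadedFrom() const { return loadedFrom_; }

private:
  std::string loadedFrom_;
};

// Hypothetical resolver standing in for findFile(): returns a new std::string.
std::string resolve(const char* fileName) { return std::string("/models/") + fileName; }

class InterfaceSketch {
public:
  // resolve(...) returns a temporary; its c_str() stays valid until the end of
  // this member's initialization, which is all ModelSketch needs.
  explicit InterfaceSketch(const char* modelFileName) : m(resolve(modelFileName).c_str()) {}
  const ModelSketch& model() const { return m; }

private:
  ModelSketch m;
};
```

If the model type stored the raw pointer instead of copying it, the same initializer would leave a dangling pointer once construction finished.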

Member Function Documentation

std::string nnet::ModelInterface::findFile(const char* fileName) const
protected, inherited

Definition at line 67 of file PointIdAlg.cxx.

References art::errors::NotFound.

Referenced by nnet::KerasModelInterface::KerasModelInterface() and nnet::TfModelInterface::TfModelInterface().

{
  std::string fname_out;
  cet::search_path sp("FW_SEARCH_PATH");
  if (!sp.find_file(fileName, fname_out)) {
    struct stat buffer;
    if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
    else {
      throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
    }
  }
  return fname_out;
}
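The lookup order above (search path first, then the name as given, otherwise an error) can be approximated in standalone C++17. This is a sketch, not the cet::search_path implementation: it assumes the search path is a colon-separated list of directories held in an environment variable, uses std::filesystem::exists in place of stat(), and throws std::runtime_error in place of art::Exception(art::errors::NotFound). The function name find_model_file is hypothetical.

```cpp
#include <cassert>
#include <cstdlib>
#include <filesystem>
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>

// Hypothetical stand-in for nnet::ModelInterface::findFile.
std::string find_model_file(const std::string& fileName,
                            const char* envVar = "FW_SEARCH_PATH")
{
  namespace fs = std::filesystem;

  // 1. Try each directory listed in the environment variable, in order.
  if (const char* path = std::getenv(envVar)) {
    std::istringstream dirs(path);
    std::string dir;
    while (std::getline(dirs, dir, ':')) {
      if (dir.empty()) continue;
      const fs::path candidate = fs::path(dir) / fileName;
      std::error_code ec;  // non-throwing overload: unreadable dirs just don't match
      if (fs::exists(candidate, ec)) return candidate.string();
    }
  }

  // 2. Fall back to the name as given (absolute, or relative to the working
  //    directory), mirroring the stat() check in the original.
  std::error_code ec;
  if (fs::exists(fileName, ec)) return fileName;

  // 3. Nothing found: report it with the original's message.
  throw std::runtime_error("Could not find the model file " + fileName);
}
```

Making the variable name a parameter lets tests point the lookup at a scratch directory without touching FW_SEARCH_PATH.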

std::vector<std::vector<float>> nnet::ModelInterface::Run(std::vector<std::vector<std::vector<float>>> const& inps, int samples = -1)
virtual, inherited

Reimplemented in nnet::TfModelInterface.

Definition at line 51 of file PointIdAlg.cxx.

References nnet::ModelInterface::Run().

{
  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty())
    return std::vector<std::vector<float>>();

  if ((samples == -1) || (samples > (int)inps.size())) { samples = inps.size(); }

  std::vector<std::vector<float>> results;
  for (int i = 0; i < samples; ++i) {
    results.push_back(Run(inps[i]));
  }
  return results;
}
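The batching rules in this function are: samples == 0 or an empty input gives an empty result; samples == -1 (the default) or a value larger than the batch processes every input; any other negative value runs nothing; otherwise only the first samples inputs are run, each through the single-input Run(). The sketch below reproduces that logic, with hypothetical ModelInterfaceSketch and RowSumModel classes (a toy model returning the sum of each row) in place of nnet::ModelInterface and a real network.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Simplified, hypothetical stand-in for nnet::ModelInterface.
class ModelInterfaceSketch {
public:
  virtual ~ModelInterfaceSketch() = default;

  // Single 2D input -> one output vector (pure virtual, as in the original).
  virtual std::vector<float> Run(Matrix const& inp2d) = 0;

  // Batch of 2D inputs -> one output vector per processed input.
  virtual std::vector<std::vector<float>> Run(std::vector<Matrix> const& inps, int samples = -1)
  {
    if (samples == 0 || inps.empty() || inps.front().empty() || inps.front().front().empty())
      return {};
    if (samples == -1 || samples > static_cast<int>(inps.size()))
      samples = static_cast<int>(inps.size());

    std::vector<std::vector<float>> results;
    for (int i = 0; i < samples; ++i)
      results.push_back(Run(inps[i]));  // virtual call to the derived model
    return results;
  }
};

// Toy model: outputs the sum of each row of its input.
class RowSumModel : public ModelInterfaceSketch {
public:
  using ModelInterfaceSketch::Run;  // keep the batch overload visible here

  std::vector<float> Run(Matrix const& inp2d) override
  {
    std::vector<float> out;
    for (auto const& row : inp2d)
      out.push_back(std::accumulate(row.begin(), row.end(), 0.0f));
    return out;
  }
};
```

Overriding one Run() overload in a derived class hides the other in that class's scope, which is why RowSumModel needs the using-declaration; calling the batch overload through a base-class reference works either way.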

std::vector<float> nnet::KerasModelInterface::Run(std::vector<std::vector<float>> const& inp2d)
override, virtual

Implements nnet::ModelInterface.

Definition at line 92 of file PointIdAlg.cxx.

References keras::KerasModel::compute_output(), m, and keras::DataChunk::set_data().

{
  std::vector<std::vector<std::vector<float>>> inp3d;
  inp3d.push_back(inp2d); // lots of copy, should add 2D to keras...

  keras::DataChunk* sample = new keras::DataChunk2D();
  sample->set_data(inp3d); // and more copy...
  std::vector<float> out = m.compute_output(sample);
  delete sample;
  return out;
}
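This Run() wraps the 2D patch as a one-channel 3D input, copies it into a heap-allocated keras::DataChunk2D, and deletes the chunk after compute_output() returns; if compute_output() threw, the chunk would leak. The sketch below shows the same flow with std::unique_ptr owning the chunk. DataChunkSketch, DataChunk2DSketch and ToyModel are hypothetical stand-ins for the keras classes, and ToyModel only reports the shape of its input.

```cpp
#include <cassert>
#include <memory>
#include <vector>

using Matrix = std::vector<std::vector<float>>;
using Tensor3D = std::vector<Matrix>;  // [channel][row][column]

// Hypothetical stand-ins for keras::DataChunk and keras::DataChunk2D.
struct DataChunkSketch {
  virtual ~DataChunkSketch() = default;
  virtual void set_data(Tensor3D const& d) = 0;
  virtual Tensor3D const& get_3d() const = 0;
};

struct DataChunk2DSketch : DataChunkSketch {
  Tensor3D data;
  void set_data(Tensor3D const& d) override { data = d; }
  Tensor3D const& get_3d() const override { return data; }
};

// Hypothetical stand-in for keras::KerasModel: returns {channels, rows, columns}.
struct ToyModel {
  std::vector<float> compute_output(DataChunkSketch* dc)
  {
    auto const& t = dc->get_3d();
    return {static_cast<float>(t.size()),
            static_cast<float>(t.front().size()),
            static_cast<float>(t.front().front().size())};
  }
};

// Same steps as KerasModelInterface::Run, but the chunk is released even if
// compute_output throws.
std::vector<float> run_single(ToyModel& m, Matrix const& inp2d)
{
  Tensor3D inp3d(1, inp2d);  // wrap the 2D patch as a single channel
  std::unique_ptr<DataChunkSketch> sample = std::make_unique<DataChunk2DSketch>();
  sample->set_data(inp3d);
  return m.compute_output(sample.get());
}
```

Since keras::DataChunk2D is default-constructible (the original creates it with new keras::DataChunk2D()) and the original deletes it afterwards, so compute_output() does not take ownership, a stack object passed by address would avoid the allocation altogether.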

Member Data Documentation

keras::KerasModel nnet::KerasModelInterface::m
private

Definition at line 81 of file PointIdAlg.h.

Referenced by Run().


The documentation for this class was generated from the following files: PointIdAlg.h, PointIdAlg.cxx