LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
nnet::TfModelInterface Class Reference

#include "PointIdAlg.h"

Inheritance diagram for nnet::TfModelInterface:
nnet::ModelInterface

Public Member Functions

 TfModelInterface (const char *modelFileName)
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) override
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) override
 

Protected Member Functions

std::string findFile (const char *fileName) const
 

Private Attributes

std::unique_ptr< tf::Graph > g
 

Detailed Description

Definition at line 85 of file PointIdAlg.h.

Constructor & Destructor Documentation

nnet::TfModelInterface::TfModelInterface ( const char *  modelFileName)

Definition at line 108 of file PointIdAlg.cxx.

References tf::Graph::create(), nnet::ModelInterface::findFile(), and art::errors::Unknown.

109 {
110  g = tf::Graph::create(nnet::ModelInterface::findFile(modelFileName).c_str(),
111  {"cnn_output", "_netout"});
112  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
113 
114  mf::LogInfo("TfModelInterface") << "TF model loaded.";
115 }
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:29
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::string findFile(const char *fileName) const
Definition: PointIdAlg.cxx:67
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:94

Member Function Documentation

std::string nnet::ModelInterface::findFile ( const char *  fileName) const
protected, inherited

Definition at line 67 of file PointIdAlg.cxx.

References art::errors::NotFound.

Referenced by TfModelInterface().

68 {
69  std::string fname_out;
70  cet::search_path sp("FW_SEARCH_PATH");
71  if (!sp.find_file(fileName, fname_out)) {
72  struct stat buffer;
73  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
74  else {
75  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
76  }
77  }
78  return fname_out;
79 }
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::vector< std::vector< float > > nnet::TfModelInterface::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
)
override, virtual

Reimplemented from nnet::ModelInterface.

Definition at line 118 of file PointIdAlg.cxx.

References r, and lar::dump::vector().

121 {
122  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty())
123  return std::vector<std::vector<float>>();
124 
125  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
126 
127  long long int rows = inps.front().size(), cols = inps.front().front().size();
128 
129  std::vector<tensorflow::Tensor> _x;
130  _x.push_back(
131  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1})));
132  auto input_map = _x[0].tensor<float, 4>();
133  for (long long int s = 0; s < samples; ++s) {
134  const auto& sample = inps[s];
135  for (long long int r = 0; r < rows; ++r) {
136  const auto& row = sample[r];
137  for (long long int c = 0; c < cols; ++c) {
138  input_map(s, r, c, 0) = row[c];
139  }
140  }
141  }
142 
143  return g->runx(_x);
144 }
TRandom r
Definition: spectrum.C:23
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:94
std::vector< float > nnet::TfModelInterface::Run ( std::vector< std::vector< float >> const &  inp2d)
override, virtual

Implements nnet::ModelInterface.

Definition at line 147 of file PointIdAlg.cxx.

References r.

148 {
149  long long int rows = inp2d.size(), cols = inp2d.front().size();
150 
151  std::vector<tensorflow::Tensor> _x;
152  _x.push_back(
153  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1})));
154  auto input_map = _x[0].tensor<float, 4>();
155  for (long long int r = 0; r < rows; ++r) {
156  const auto& row = inp2d[r];
157  for (long long int c = 0; c < cols; ++c) {
158  input_map(0, r, c, 0) = row[c];
159  }
160  }
161 
162  auto out = g->runx(_x);
163  if (!out.empty())
164  return out.front();
165  else
166  return std::vector<float>();
167 }
TRandom r
Definition: spectrum.C:23
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:94

Member Data Documentation

std::unique_ptr<tf::Graph> nnet::TfModelInterface::g
private

Definition at line 94 of file PointIdAlg.h.


The documentation for this class was generated from the following files: