#include "PointIdAlg.h"
|
| TfModelInterface (const char *modelFileName) |
|
std::vector< std::vector< float > > | Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) override |
|
std::vector< float > | Run (std::vector< std::vector< float >> const &inp2d) override |
|
|
std::string | findFile (const char *fileName) const |
|
Definition at line 85 of file PointIdAlg.h.
nnet::TfModelInterface::TfModelInterface |
( |
const char * |
modelFileName | ) |
|
Definition at line 108 of file PointIdAlg.cxx.
References tf::Graph::create(), nnet::ModelInterface::findFile(), and art::errors::Unknown.
111 {
"cnn_output",
"_netout"});
114 mf::LogInfo(
"TfModelInterface") <<
"TF model loaded.";
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
std::string findFile(const char *fileName) const
std::unique_ptr< tf::Graph > g
std::string nnet::ModelInterface::findFile |
( |
const char * |
fileName | ) |
const |
|
protected, inherited |
Definition at line 67 of file PointIdAlg.cxx.
References art::errors::NotFound.
Referenced by TfModelInterface().
69 std::string fname_out;
70 cet::search_path sp(
"FW_SEARCH_PATH");
71 if (!sp.find_file(fileName, fname_out)) {
73 if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
std::vector< std::vector< float > > nnet::TfModelInterface::Run |
( |
std::vector< std::vector< std::vector< float >>> const & |
inps, |
|
|
int |
samples = -1 |
|
) |
| |
|
override, virtual
Reimplemented from nnet::ModelInterface.
Definition at line 118 of file PointIdAlg.cxx.
References r, and lar::dump::vector().
122 if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty())
125 if ((samples == -1) || (samples > (
long long int)inps.size())) { samples = inps.size(); }
127 long long int rows = inps.front().size(), cols = inps.front().front().size();
129 std::vector<tensorflow::Tensor> _x;
131 tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1})));
132 auto input_map = _x[0].tensor<float, 4>();
133 for (
long long int s = 0; s < samples; ++s) {
134 const auto& sample = inps[s];
135 for (
long long int r = 0;
r < rows; ++
r) {
136 const auto& row = sample[
r];
137 for (
long long int c = 0; c < cols; ++c) {
138 input_map(s,
r, c, 0) = row[c];
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
std::unique_ptr< tf::Graph > g
std::vector< float > nnet::TfModelInterface::Run |
( |
std::vector< std::vector< float >> const & |
inp2d | ) |
|
|
override, virtual
Implements nnet::ModelInterface.
Definition at line 147 of file PointIdAlg.cxx.
References r.
149 long long int rows = inp2d.size(), cols = inp2d.front().size();
151 std::vector<tensorflow::Tensor> _x;
153 tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1})));
154 auto input_map = _x[0].tensor<float, 4>();
155 for (
long long int r = 0;
r < rows; ++
r) {
156 const auto& row = inp2d[
r];
157 for (
long long int c = 0; c < cols; ++c) {
158 input_map(0,
r, c, 0) = row[c];
162 auto out =
g->runx(_x);
166 return std::vector<float>();
std::unique_ptr< tf::Graph > g
std::unique_ptr<tf::Graph> nnet::TfModelInterface::g |
|
private |
The documentation for this class was generated from the following files:
- larrecodnn/v09_23_00/source/larrecodnn/ImagePatternAlgs/Tensorflow/PointIdAlg/PointIdAlg.h
- larrecodnn/v09_23_00/source/larrecodnn/ImagePatternAlgs/Tensorflow/PointIdAlg/PointIdAlg.cxx