LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlgTools::PointIdAlgTf Class Reference
Inheritance diagram for PointIdAlgTools::PointIdAlgTf:

Public Member Functions

 PointIdAlgTf (fhicl::Table< Config > const &table)
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) const override
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) const override
 

Protected Member Functions

std::string findFile (const char *fileName) const
 

Private Attributes

std::unique_ptr< tf::Graph > g
 
std::vector< std::string > fNNetOutputPattern
 
std::string fNNetModelFilePath
 

Detailed Description

Definition at line 20 of file PointIdAlgTf_tool.cc.

Constructor & Destructor Documentation

/// @brief Configure the tool and load the TensorFlow graph.
///
/// Doxygen: definition at line 38 of file PointIdAlgTf_tool.cc (declared
/// `explicit` in the class). References tf::Graph::create(), findFile(),
/// fNNetModelFilePath, fNNetOutputPattern, g, and art::errors::Unknown.
///
/// Reads NNetOutputs, PatchSizeW and PatchSizeD from @p table, then the
/// optional TF-specific keys:
///   - NNetModelFile     -> fNNetModelFilePath (default "mycnn")
///   - NNetOutputPattern -> fNNetOutputPattern (default {"cnn_output", "_netout"})
/// A model path ending in ".pb" is resolved with findFile() and loaded with
/// tf::Graph::create(). Any other extension only logs an error and leaves g null.
///
/// @param table FHiCL configuration of the tool.
/// @throws art::Exception (art::errors::Unknown) if the ".pb" graph cannot be created.
/// @throws art::Exception (art::errors::NotFound), from findFile(), if the model file is not found.
PointIdAlgTools::PointIdAlgTf::PointIdAlgTf(fhicl::Table<Config> const& table)
  : img::DataProviderAlg(table())
{
  // ... Get common config vars
  fNNetOutputs = table().NNetOutputs();
  fPatchSizeW = table().PatchSizeW();
  fPatchSizeD = table().PatchSizeD();
  // NOTE(review): 99999 looks like a "no patch cached yet" sentinel; confirm.
  fCurrentWireIdx = 99999;
  fCurrentScaledDrift = 99999;

  // ... Get "optional" config vars specific to tf interface
  std::string s_cfgvr;
  if (table().NNetModelFile(s_cfgvr)) { fNNetModelFilePath = s_cfgvr; }
  else {
    fNNetModelFilePath = "mycnn";
  }
  std::vector<std::string> vs_cfgvr;
  if (table().NNetOutputPattern(vs_cfgvr)) { fNNetOutputPattern = vs_cfgvr; }
  else {
    fNNetOutputPattern = {"cnn_output", "_netout"};
  }

  if ((fNNetModelFilePath.length() > 3) &&
      (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0)) {
    // The graph must be created before it is checked. Without this assignment,
    // g is always null here and every ".pb" model is rejected.
    g = tf::Graph::create(findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern);
    if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
    mf::LogInfo("PointIdAlgTf") << "TF model loaded.";
  }
  else {
    mf::LogError("PointIdAlgTf") << "File name extension not supported.";
  }

  // NOTE(review): resizePatch() is inherited and not shown here. Presumably it
  // sizes the patch buffer from fPatchSizeW/fPatchSizeD; confirm.
  resizePatch();
}
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:29
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
std::string findFile(const char *fileName) const
std::vector< std::string > fNNetOutputPattern
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::unique_ptr< tf::Graph > g

Member Function Documentation

std::string PointIdAlgTools::PointIdAlgTf::findFile ( const char *  fileName) const
protected

Definition at line 73 of file PointIdAlgTf_tool.cc.

References art::errors::NotFound.

Referenced by PointIdAlgTf().

74  {
75  std::string fname_out;
76  cet::search_path sp("FW_SEARCH_PATH");
77  if (!sp.find_file(fileName, fname_out)) {
78  struct stat buffer;
79  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
80  else {
81  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
82  }
83  }
84  return fname_out;
85  }
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::vector< float > PointIdAlgTools::PointIdAlgTf::Run ( std::vector< std::vector< float >> const &  inp2d) const
override

Definition at line 88 of file PointIdAlgTf_tool.cc.

References g, and r.

89  {
90  long long int rows = inp2d.size(), cols = inp2d.front().size();
91 
92  std::vector<tensorflow::Tensor> _x;
93  _x.push_back(
94  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1})));
95  auto input_map = _x[0].tensor<float, 4>();
96  for (long long int r = 0; r < rows; ++r) {
97  const auto& row = inp2d[r];
98  for (long long int c = 0; c < cols; ++c) {
99  input_map(0, r, c, 0) = row[c];
100  }
101  }
102 
103  auto out = g->runx(_x);
104  if (!out.empty())
105  return out.front();
106  else
107  return std::vector<float>();
108  }
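TRandom r
Definition: spectrum.C:23 (spurious Doxygen cross-reference: the `r` used in Run() is its local row-loop index, not this global)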
std::unique_ptr< tf::Graph > g
std::vector< std::vector< float > > PointIdAlgTools::PointIdAlgTf::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
) const
override

Definition at line 111 of file PointIdAlgTf_tool.cc.

References DEFINE_ART_CLASS_TOOL, g, and r.

114  {
115 
116  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
117  return std::vector<std::vector<float>>();
118  }
119 
120  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
121 
122  long long int rows = inps.front().size(), cols = inps.front().front().size();
123 
124  std::vector<tensorflow::Tensor> _x;
125  _x.push_back(
126  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1})));
127  auto input_map = _x[0].tensor<float, 4>();
128  for (long long int s = 0; s < samples; ++s) {
129  const auto& sample = inps[s];
130  for (long long int r = 0; r < rows; ++r) {
131  const auto& row = sample[r];
132  for (long long int c = 0; c < cols; ++c) {
133  input_map(s, r, c, 0) = row[c];
134  }
135  }
136  }
137  return g->runx(_x);
138  }
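TRandom r
Definition: spectrum.C:23 (spurious Doxygen cross-reference: the `r` used in the batch Run() is its local row-loop index, not this global)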
std::unique_ptr< tf::Graph > g

Member Data Documentation

std::string PointIdAlgTools::PointIdAlgTf::fNNetModelFilePath
private

Definition at line 34 of file PointIdAlgTf_tool.cc.

Referenced by PointIdAlgTf().

std::vector<std::string> PointIdAlgTools::PointIdAlgTf::fNNetOutputPattern
private

Definition at line 33 of file PointIdAlgTf_tool.cc.

Referenced by PointIdAlgTf().

std::unique_ptr<tf::Graph> PointIdAlgTools::PointIdAlgTf::g
private

Definition at line 32 of file PointIdAlgTf_tool.cc.

Referenced by PointIdAlgTf(), and Run().


The documentation for this class was generated from the following file: