LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlgTf_tool.cc
Go to the documentation of this file.
1 // Class: PointIdAlgTf_tool (tool version of TensorFlow model interface in PointIdAlg)
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 // D.Smith, from LArIAT, BU, 2017: real data dump
7 // M.Wang, from DUNE, FNAL, 2020: tool version
9 
11 
13 #include "larrecodnn/ImagePatternAlgs/ToolInterfaces/IPointIdAlg.h"
14 #include "tensorflow/core/public/session.h"
15 
16 #include <sys/stat.h>
17 
18 namespace PointIdAlgTools {
19 
20  class PointIdAlgTf : public IPointIdAlg {
21  public:
22  explicit PointIdAlgTf(fhicl::Table<Config> const& table);
23 
24  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const override;
25  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
26  int samples = -1) const override;
27 
28  protected:
29  std::string findFile(const char* fileName) const;
30 
31  private:
32  std::unique_ptr<tf::Graph> g; // network graph
33  std::vector<std::string> fNNetOutputPattern;
34  std::string fNNetModelFilePath;
35  };
36 
37  // ------------------------------------------------------
38  PointIdAlgTf::PointIdAlgTf(fhicl::Table<Config> const& table) : img::DataProviderAlg(table())
39  {
40  // ... Get common config vars
41  fNNetOutputs = table().NNetOutputs();
42  fPatchSizeW = table().PatchSizeW();
43  fPatchSizeD = table().PatchSizeD();
44  fCurrentWireIdx = 99999;
45  fCurrentScaledDrift = 99999;
46 
47  // ... Get "optional" config vars specific to tf interface
48  std::string s_cfgvr;
49  if (table().NNetModelFile(s_cfgvr)) { fNNetModelFilePath = s_cfgvr; }
50  else {
51  fNNetModelFilePath = "mycnn";
52  }
53  std::vector<std::string> vs_cfgvr;
54  if (table().NNetOutputPattern(vs_cfgvr)) { fNNetOutputPattern = vs_cfgvr; }
55  else {
56  fNNetOutputPattern = {"cnn_output", "_netout"};
57  }
58 
59  if ((fNNetModelFilePath.length() > 3) &&
60  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0)) {
62  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
63  mf::LogInfo("PointIdAlgTf") << "TF model loaded.";
64  }
65  else {
66  mf::LogError("PointIdAlgTf") << "File name extension not supported.";
67  }
68 
69  resizePatch();
70  }
71 
72  // ------------------------------------------------------
73  std::string PointIdAlgTf::findFile(const char* fileName) const
74  {
75  std::string fname_out;
76  cet::search_path sp("FW_SEARCH_PATH");
77  if (!sp.find_file(fileName, fname_out)) {
78  struct stat buffer;
79  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
80  else {
81  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
82  }
83  }
84  return fname_out;
85  }
86 
87  // ------------------------------------------------------
88  std::vector<float> PointIdAlgTf::Run(std::vector<std::vector<float>> const& inp2d) const
89  {
90  long long int rows = inp2d.size(), cols = inp2d.front().size();
91 
92  std::vector<tensorflow::Tensor> _x;
93  _x.push_back(
94  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1})));
95  auto input_map = _x[0].tensor<float, 4>();
96  for (long long int r = 0; r < rows; ++r) {
97  const auto& row = inp2d[r];
98  for (long long int c = 0; c < cols; ++c) {
99  input_map(0, r, c, 0) = row[c];
100  }
101  }
102 
103  auto out = g->runx(_x);
104  if (!out.empty())
105  return out.front();
106  else
107  return std::vector<float>();
108  }
109 
110  // ------------------------------------------------------
111  std::vector<std::vector<float>> PointIdAlgTf::Run(
112  std::vector<std::vector<std::vector<float>>> const& inps,
113  int samples) const
114  {
115 
116  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
117  return std::vector<std::vector<float>>();
118  }
119 
120  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
121 
122  long long int rows = inps.front().size(), cols = inps.front().front().size();
123 
124  std::vector<tensorflow::Tensor> _x;
125  _x.push_back(
126  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1})));
127  auto input_map = _x[0].tensor<float, 4>();
128  for (long long int s = 0; s < samples; ++s) {
129  const auto& sample = inps[s];
130  for (long long int r = 0; r < rows; ++r) {
131  const auto& row = sample[r];
132  for (long long int c = 0; c < cols; ++c) {
133  input_map(s, r, c, 0) = row[c];
134  }
135  }
136  }
137  return g->runx(_x);
138  }
139 
140 }
TRandom r
Definition: spectrum.C:23
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:29
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
std::string findFile(const char *fileName) const
std::vector< std::string > fNNetOutputPattern
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const override
PointIdAlgTf(fhicl::Table< Config > const &table)
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::unique_ptr< tf::Graph > g