LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
WireframeRecogTf_tool.cc
Go to the documentation of this file.
3 #include "larrecodnn/ImagePatternAlgs/ToolInterfaces/IWireframeRecog.h"
5 #include "tensorflow/core/public/session.h"
6 
7 #include <sys/stat.h>
8 
9 namespace wframerec_tool {
10 
11  class WireframeRecogTf : public IWireframeRecog {
12  public:
13  explicit WireframeRecogTf(const fhicl::ParameterSet& pset);
14 
15  std::vector<std::vector<float>> predictWireframeType(
16  const std::vector<std::vector<std::vector<short>>>&) const override;
17 
18  private:
19  std::unique_ptr<tf::Graph> g; // network graph
20  std::string fNNetModelFilePath;
21  std::vector<std::string> fNNetOutputPattern;
22  bool fUseBundle;
23  };
24 
25  // ------------------------------------------------------
27  {
28  fNNetModelFilePath = pset.get<std::string>("NNetModelFile", "mymodel.pb");
29  fUseBundle = pset.get<bool>("UseSavedModelBundle", false);
31  pset.get<std::vector<std::string>>("NNetOutputPattern", {"cnn_output", "dense_3"});
32  if ((fNNetModelFilePath.length() > 3) &&
33  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0) &&
34  !fUseBundle) {
36  findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern, fUseBundle);
37  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
38  mf::LogInfo("WireframeRecogTf") << "TF model loaded.";
39  }
40  else if ((fNNetModelFilePath.length() > 3) && fUseBundle) {
42  findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern, fUseBundle);
43  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
44  mf::LogInfo("WireframeRecogTf") << "TF model loaded.";
45  }
46  else {
47  mf::LogError("WireframeRecogTf") << "File name extension not supported.";
48  }
49 
50  setupWframeRecRoiParams(pset);
51  }
52 
53  // ------------------------------------------------------
54  std::vector<std::vector<float>> WireframeRecogTf::predictWireframeType(
55  const std::vector<std::vector<std::vector<short>>>& wireframes) const
56  {
57  if (wireframes.empty() || wireframes.front().empty() || wireframes.front().front().empty()) {
58  return std::vector<std::vector<float>>();
59  //return std::vector<std::vector<float>>(samples,std::vector<float>(2,0.));
60  }
61 
62  long long int samples = wireframes.size(), rows = wireframes.front().size(),
63  cols = wireframes.front().front().size();
64  //std::cout << " !!!! predictWireframeType: samples = " << samples << ", rows = " << rows << ", cols = " << cols << std::endl;
65 
66  std::vector<tensorflow::Tensor> _x;
67  _x.push_back(
68  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1})));
69  auto input_map = _x[0].tensor<float, 4>();
70  for (long long int s = 0; s < samples; ++s) {
71  const auto& wframe = wireframes[s];
72  for (long long int r = 0; r < rows; ++r) {
73  const auto& row = wframe[r];
74  for (long long int c = 0; c < cols; ++c) {
75  input_map(s, r, c, 0) = float(row[c]);
76  }
77  }
78  }
79  return g->runx(_x);
80  }
81 
82 }
TRandom r
Definition: spectrum.C:23
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:29
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::vector< std::vector< float > > predictWireframeType(const std::vector< std::vector< std::vector< short >>> &) const override
std::unique_ptr< tf::Graph > g
WireframeRecogTf(const fhicl::ParameterSet &pset)
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
T get(std::string const &key) const
Definition: ParameterSet.h:314
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::vector< std::string > fNNetOutputPattern