LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
wframerec_tool::WireframeRecogTf Class Reference
Inheritance diagram for wframerec_tool::WireframeRecogTf:

Public Member Functions

 WireframeRecogTf (const fhicl::ParameterSet &pset)
 
std::vector< std::vector< float > > predictWireframeType (const std::vector< std::vector< std::vector< short >>> &) const override
 

Private Attributes

std::unique_ptr< tf::Graph > g
 
std::string fNNetModelFilePath
 
std::vector< std::string > fNNetOutputPattern
 
bool fUseBundle
 

Detailed Description

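Tool that loads a TensorFlow graph (tf::Graph) from the model file named by the NNetModelFile parameter and evaluates it on a batch of wireframes in predictWireframeType().

Definition at line 11 of file WireframeRecogTf_tool.cc.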

Constructor & Destructor Documentation

wframerec_tool::WireframeRecogTf::WireframeRecogTf ( const fhicl::ParameterSet & pset)
explicit

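Definition at line 26 of file WireframeRecogTf_tool.cc.

Reads NNetModelFile (default "mymodel.pb"), UseSavedModelBundle (default false) and NNetOutputPattern (default {"cnn_output", "dense_3"}) from pset.

The graph is created only if the model path is longer than 3 characters and either:
- ends in ".pb" (when UseSavedModelBundle is false), or
- UseSavedModelBundle is true.

An art::Exception (errors::Unknown, "TF model failed.") is thrown if graph creation returns nothing. In every other case the constructor only logs "File name extension not supported." and leaves g empty. A later call to predictWireframeType() would then dereference a null pointer.

Note: source lines 30, 35 and 41 are missing from the listing below. Judging from the References list and the continuation lines, they presumably hold the assignment to fNNetOutputPattern and the two `g = tf::Graph::create(` calls.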

References tf::Graph::create(), fNNetModelFilePath, fNNetOutputPattern, fUseBundle, g, fhicl::ParameterSet::get(), and art::errors::Unknown.

27  {
28  fNNetModelFilePath = pset.get<std::string>("NNetModelFile", "mymodel.pb");
29  fUseBundle = pset.get<bool>("UseSavedModelBundle", false);
31  pset.get<std::vector<std::string>>("NNetOutputPattern", {"cnn_output", "dense_3"});
32  if ((fNNetModelFilePath.length() > 3) &&
33  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0) &&
34  !fUseBundle) {
36  findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern, fUseBundle);
37  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
38  mf::LogInfo("WireframeRecogTf") << "TF model loaded.";
39  }
40  else if ((fNNetModelFilePath.length() > 3) && fUseBundle) {
42  findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern, fUseBundle);
43  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
44  mf::LogInfo("WireframeRecogTf") << "TF model loaded.";
45  }
46  else {
47  mf::LogError("WireframeRecogTf") << "File name extension not supported.";
48  }
49 
50  setupWframeRecRoiParams(pset);
51  }
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:29
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::unique_ptr< tf::Graph > g
T get(std::string const &key) const
Definition: ParameterSet.h:314
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::vector< std::string > fNNetOutputPattern

Member Function Documentation

std::vector< std::vector< float > > wframerec_tool::WireframeRecogTf::predictWireframeType ( const std::vector< std::vector< std::vector< short >>> &  wireframes) const
override

Definition at line 54 of file WireframeRecogTf_tool.cc.

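References g. (Doxygen also lists DEFINE_ART_CLASS_TOOL and r here, but these cross-references are spurious. DEFINE_ART_CLASS_TOOL does not appear in the function body, and r is the function's local row-index loop variable, not the TRandom r defined in spectrum.C.)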

56  {
57  if (wireframes.empty() || wireframes.front().empty() || wireframes.front().front().empty()) {
58  return std::vector<std::vector<float>>();
59  //return std::vector<std::vector<float>>(samples,std::vector<float>(2,0.));
60  }
61 
62  long long int samples = wireframes.size(), rows = wireframes.front().size(),
63  cols = wireframes.front().front().size();
64  //std::cout << " !!!! predictWireframeType: samples = " << samples << ", rows = " << rows << ", cols = " << cols << std::endl;
65 
66  std::vector<tensorflow::Tensor> _x;
67  _x.push_back(
68  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1})));
69  auto input_map = _x[0].tensor<float, 4>();
70  for (long long int s = 0; s < samples; ++s) {
71  const auto& wframe = wireframes[s];
72  for (long long int r = 0; r < rows; ++r) {
73  const auto& row = wframe[r];
74  for (long long int c = 0; c < cols; ++c) {
75  input_map(s, r, c, 0) = float(row[c]);
76  }
77  }
78  }
79  return g->runx(_x);
80  }
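TRandom r
Definition: spectrum.C:23
(Spurious cross-reference: the r used in predictWireframeType() is a local loop index, not this TRandom.)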
std::unique_ptr< tf::Graph > g

Member Data Documentation

std::string wframerec_tool::WireframeRecogTf::fNNetModelFilePath
private

Definition at line 20 of file WireframeRecogTf_tool.cc.

Referenced by WireframeRecogTf().

std::vector<std::string> wframerec_tool::WireframeRecogTf::fNNetOutputPattern
private

Definition at line 21 of file WireframeRecogTf_tool.cc.

Referenced by WireframeRecogTf().

bool wframerec_tool::WireframeRecogTf::fUseBundle
private

Definition at line 22 of file WireframeRecogTf_tool.cc.

Referenced by WireframeRecogTf().

std::unique_ptr<tf::Graph> wframerec_tool::WireframeRecogTf::g
private

Definition at line 19 of file WireframeRecogTf_tool.cc.

Referenced by predictWireframeType(), and WireframeRecogTf().


The documentation for this class was generated from the following file: