LArSoft v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
wavrec_tool::WaveformRecogTf Class Reference
Inheritance diagram for wavrec_tool::WaveformRecogTf:

Public Member Functions

 WaveformRecogTf (const fhicl::ParameterSet &pset)
 
std::vector< std::vector< float > > predictWaveformType (const std::vector< std::vector< float >> &) const override
 

Private Attributes

std::unique_ptr< tf::Graph > g
 
std::string fNNetModelFilePath
 
std::vector< std::string > fNNetOutputPattern
 
bool fUseBundle
 

Detailed Description

Definition at line 11 of file WaveformRecogTf_tool.cc.

Constructor & Destructor Documentation

wavrec_tool::WaveformRecogTf::WaveformRecogTf ( const fhicl::ParameterSet & pset)
explicit

Definition at line 26 of file WaveformRecogTf_tool.cc.

References tf::Graph::create(), fNNetModelFilePath, fNNetOutputPattern, fUseBundle, g, fhicl::ParameterSet::get(), and art::errors::Unknown.

27  {
28  fNNetModelFilePath = pset.get<std::string>("NNetModelFile", "mymodel.pb");
29  fUseBundle = pset.get<bool>("UseSavedModelBundle", false);
30  fNNetOutputPattern =
31  pset.get<std::vector<std::string>>("NNetOutputPattern", {"cnn_output", "dense_3"});
32  if ((fNNetModelFilePath.length() > 3) &&
33  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0) &&
34  !fUseBundle) {
35  g = tf::Graph::create(
36  findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern, fUseBundle);
37  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
38  mf::LogInfo("WaveformRecogTf") << "TF model loaded.";
39  }
40  else if ((fNNetModelFilePath.length() > 3) && fUseBundle) {
41  g = tf::Graph::create(
42  findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern, fUseBundle);
43  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
44  mf::LogInfo("WaveformRecogTf") << "TF model loaded.";
45  }
46  else {
47  mf::LogError("WaveformRecogTf") << "File name extension not supported.";
48  }
49 
50  setupWaveRecRoiParams(pset);
51  }
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:29
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
T get(std::string const &key) const
Definition: ParameterSet.h:314
std::vector< std::string > fNNetOutputPattern
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::unique_ptr< tf::Graph > g

Member Function Documentation

std::vector< std::vector< float > > wavrec_tool::WaveformRecogTf::predictWaveformType ( const std::vector< std::vector< float >> &  waveforms) const
override

Definition at line 54 of file WaveformRecogTf_tool.cc.

References DEFINE_ART_CLASS_TOOL, and g.

56  {
57  if (waveforms.empty() || waveforms.front().empty()) {
58  return std::vector<std::vector<float>>();
59  }
60 
61  long long int samples = waveforms.size(), numtcks = waveforms.front().size();
62 
63  //std::cout<<"Samples: "<<samples<<", Ticks: "<<numtcks<<std::endl;
64  std::vector<tensorflow::Tensor> _x;
65  _x.push_back(
66  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, numtcks, 1})));
67  auto input_map = _x[0].tensor<float, 3>();
68  for (long long int s = 0; s < samples; ++s) {
69  const auto& wvfrm = waveforms[s];
70  for (long long int t = 0; t < numtcks; ++t) {
71  input_map(s, t, 0) = wvfrm[t];
72  }
73  }
74 
75  return g->runx(_x);
76  }
std::unique_ptr< tf::Graph > g

Member Data Documentation

std::string wavrec_tool::WaveformRecogTf::fNNetModelFilePath
private

Definition at line 20 of file WaveformRecogTf_tool.cc.

Referenced by WaveformRecogTf().

std::vector<std::string> wavrec_tool::WaveformRecogTf::fNNetOutputPattern
private

Definition at line 21 of file WaveformRecogTf_tool.cc.

Referenced by WaveformRecogTf().

bool wavrec_tool::WaveformRecogTf::fUseBundle
private

Definition at line 22 of file WaveformRecogTf_tool.cc.

Referenced by WaveformRecogTf().

std::unique_ptr<tf::Graph> wavrec_tool::WaveformRecogTf::g
private

Definition at line 19 of file WaveformRecogTf_tool.cc.

Referenced by predictWaveformType(), and WaveformRecogTf().


The documentation for this class was generated from the following file:
- WaveformRecogTf_tool.cc