LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
WaveformRecogTf_tool.cc
Go to the documentation of this file.
3 #include "larrecodnn/ImagePatternAlgs/ToolInterfaces/IWaveformRecog.h"
5 #include "tensorflow/core/public/session.h"
6 
7 #include <sys/stat.h>
8 
9 namespace wavrec_tool {
10 
11  class WaveformRecogTf : public IWaveformRecog {
12  public:
13  explicit WaveformRecogTf(const fhicl::ParameterSet& pset);
14 
15  std::vector<std::vector<float>> predictWaveformType(
16  const std::vector<std::vector<float>>&) const override;
17 
18  private:
19  std::unique_ptr<tf::Graph> g; // network graph
20  std::string fNNetModelFilePath;
21  std::vector<std::string> fNNetOutputPattern;
22  bool fUseBundle;
23  };
24 
25  // ------------------------------------------------------
27  {
28  fNNetModelFilePath = pset.get<std::string>("NNetModelFile", "mymodel.pb");
29  fUseBundle = pset.get<bool>("UseSavedModelBundle", false);
31  pset.get<std::vector<std::string>>("NNetOutputPattern", {"cnn_output", "dense_3"});
32  if ((fNNetModelFilePath.length() > 3) &&
33  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0) &&
34  !fUseBundle) {
36  findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern, fUseBundle);
37  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
38  mf::LogInfo("WaveformRecogTf") << "TF model loaded.";
39  }
40  else if ((fNNetModelFilePath.length() > 3) && fUseBundle) {
42  findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern, fUseBundle);
43  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
44  mf::LogInfo("WaveformRecogTf") << "TF model loaded.";
45  }
46  else {
47  mf::LogError("WaveformRecogTf") << "File name extension not supported.";
48  }
49 
50  setupWaveRecRoiParams(pset);
51  }
52 
53  // ------------------------------------------------------
54  std::vector<std::vector<float>> WaveformRecogTf::predictWaveformType(
55  const std::vector<std::vector<float>>& waveforms) const
56  {
57  if (waveforms.empty() || waveforms.front().empty()) {
58  return std::vector<std::vector<float>>();
59  }
60 
61  long long int samples = waveforms.size(), numtcks = waveforms.front().size();
62 
63  //std::cout<<"Samples: "<<samples<<", Ticks: "<<numtcks<<std::endl;
64  std::vector<tensorflow::Tensor> _x;
65  _x.push_back(
66  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, numtcks, 1})));
67  auto input_map = _x[0].tensor<float, 3>();
68  for (long long int s = 0; s < samples; ++s) {
69  const auto& wvfrm = waveforms[s];
70  for (long long int t = 0; t < numtcks; ++t) {
71  input_map(s, t, 0) = wvfrm[t];
72  }
73  }
74 
75  return g->runx(_x);
76  }
77 
78 }
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:29
WaveformRecogTf(const fhicl::ParameterSet &pset)
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
T get(std::string const &key) const
Definition: ParameterSet.h:314
std::vector< std::vector< float > > predictWaveformType(const std::vector< std::vector< float >> &) const override
std::vector< std::string > fNNetOutputPattern
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::unique_ptr< tf::Graph > g