LArSoft v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
phot::TFLoaderMLP Class Reference

#include "TFLoaderMLP.h"

Inheritance diagram for phot::TFLoaderMLP:
phot::TFLoader

Public Member Functions

 TFLoaderMLP (fhicl::ParameterSet const &pset)
 
void Initialization ()
 
void CloseSession ()
 
void Predict (std::vector< double > pars)
 
std::vector< double > GetPrediction () const
 

Protected Attributes

std::vector< double > prediction
 

Private Attributes

std::string ModelName
 
std::vector< std::string > InputsName
 
std::string OutputName
 
tensorflow::SavedModelBundleLite * modelbundle
 
tensorflow::Status status
 

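Detailed Description

Loads a TensorFlow SavedModel and evaluates it on exactly three scalar inputs. The model is located on FW_SEARCH_PATH under the name given by the ModelName parameter. The values of the OutputName tensor are stored in prediction.

Definition at line 16 of file TFLoaderMLP.h.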

Constructor & Destructor Documentation

phot::TFLoaderMLP::TFLoaderMLP ( fhicl::ParameterSet const &  pset)
explicit

Definition at line 11 of file TFLoaderMLP_tool.cc.

References fhicl::ParameterSet::get(), InputsName, ModelName, and OutputName.

12  : ModelName{pset.get<std::string>("ModelName")}
13  , InputsName{pset.get<std::vector<std::string>>("InputsName")}
14  , OutputName{pset.get<std::string>("OutputName")}
15  {}
std::string ModelName
Definition: TFLoaderMLP.h:24
std::string OutputName
Definition: TFLoaderMLP.h:26
std::vector< std::string > InputsName
Definition: TFLoaderMLP.h:25

Member Function Documentation

void phot::TFLoaderMLP::CloseSession ( )
virtual

Implements phot::TFLoader.

Definition at line 78 of file TFLoaderMLP_tool.cc.

References modelbundle and status.

79  {
80  if (status.ok()) {
81  std::cout << "Close TF session." << std::endl;
82  // session->Close();
83  }
84 
85  delete modelbundle;
86 
87  // delete session;
88  return;
89  }
tensorflow::Status status
Definition: TFLoaderMLP.h:31
tensorflow::SavedModelBundleLite * modelbundle
Definition: TFLoaderMLP.h:28
std::vector<double> phot::TFLoader::GetPrediction ( ) const
inline inherited

Definition at line 36 of file TFLoader.h.

References phot::TFLoader::prediction.

36 { return prediction; }
std::vector< double > prediction
Definition: TFLoader.h:39
void phot::TFLoaderMLP::Initialization ( )
virtual

Implements phot::TFLoader.

Definition at line 18 of file TFLoaderMLP_tool.cc.

References InputsName, modelbundle, ModelName, OutputName, and status.

19  {
20  int num_input = int(InputsName.size());
21  if (num_input != 3) {
22  std::cout << "Input name error! exit!" << std::endl;
23  return;
24  }
25  std::string GraphFileWithPath;
26  cet::search_path sp("FW_SEARCH_PATH");
27  if (!sp.find_file(ModelName, GraphFileWithPath)) {
28  throw cet::exception("TFLoaderMLP")
29  << "In larrecodnn:phot::TFLoaderMLP: Failed to load SavedModel in : " << sp.to_string()
30  << "\n";
31  }
32  std::cout << "larrecodnn:phot::TFLoaderMLP Loading TF Model from: " << GraphFileWithPath
33  << ", Input Layer: ";
34  for (int i = 0; i < num_input; ++i) {
35  std::cout << InputsName[i] << " ";
36  }
37  std::cout << ", Output Layer: " << OutputName << "\n";
38 
39  //Load SavedModel
40  modelbundle = new tensorflow::SavedModelBundleLite();
41 
42  status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(),
43  tensorflow::RunOptions(),
44  GraphFileWithPath,
45  {tensorflow::kSavedModelTagServe},
46  modelbundle);
47 
48  //Initialize a tensorflow session
49  // status = tensorflow::NewSession(tensorflow::SessionOptions(), &session);
50  if (!status.ok()) {
51  throw cet::exception("TFLoaderMLP")
52  << "In larrecodnn:phot::TFLoaderMLP: Failed to load SavedModel, status: "
53  << status.ToString() << std::endl;
54  }
55 
56  //Read in the protobuf graph
57  // tensorflow::GraphDef graph_def;
58  // status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), ModelName, &graph_def);
59  // if (!status.ok())
60  // {
61  // std::cout << status.ToString() << std::endl;
62  // return;
63  // }
64  //
65  // //Add the graph to the session
66  // status = session->Create(graph_def);
67  // if (!status.ok())
68  // {
69  // std::cout << status.ToString() << std::endl;
70  // return;
71  // }
72 
73  std::cout << "TF SavedModel loaded successfully." << std::endl;
74  return;
75  }
std::string ModelName
Definition: TFLoaderMLP.h:24
std::string OutputName
Definition: TFLoaderMLP.h:26
std::vector< std::string > InputsName
Definition: TFLoaderMLP.h:25
tensorflow::Status status
Definition: TFLoaderMLP.h:31
tensorflow::SavedModelBundleLite * modelbundle
Definition: TFLoaderMLP.h:28
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
void phot::TFLoaderMLP::Predict ( std::vector< double >  pars)
virtual

Implements phot::TFLoader.

Definition at line 92 of file TFLoaderMLP_tool.cc.

References InputsName, modelbundle, OutputName, phot::TFLoader::prediction, and status.

93  {
94  //std::cout << "TFLoader MLP:: Predicting... " << std::endl;
95  int num_input = int(pars.size());
96  if (num_input != 3) {
97  std::cout << "Input parameter error! exit!" << std::endl;
98  return;
99  }
100  //Clean prediction
101  std::vector<double>().swap(prediction);
102 
103  //Define inputs
104  tensorflow::Tensor pos_x(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 1}));
105  tensorflow::Tensor pos_y(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 1}));
106  tensorflow::Tensor pos_z(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 1}));
107  auto dst_x = pos_x.flat<float>().data();
108  auto dst_y = pos_y.flat<float>().data();
109  auto dst_z = pos_z.flat<float>().data();
110  copy_n(pars.begin(), 1, dst_x);
111  copy_n(pars.begin() + 1, 1, dst_y);
112  copy_n(pars.begin() + 2, 1, dst_z);
113  std::vector<std::pair<std::string, tensorflow::Tensor>> inputs = {
114  {InputsName[0], pos_x}, {InputsName[1], pos_y}, {InputsName[2], pos_z}};
115  //Define outps
116  std::vector<tensorflow::Tensor> outputs;
117 
118  //Run the session
119  status = modelbundle->GetSession()->Run(inputs, {OutputName}, {}, &outputs);
120  // status = session->Run(inputs, {OutputName}, {}, &outputs);
121  if (!status.ok()) {
122  std::cout << status.ToString() << std::endl;
123  return;
124  }
125 
126  //Grab the outputs
127  unsigned int pdr = outputs[0].shape().dim_size(1);
128  //std::cout << "TFLoader MLP::Num of optical channels: " << pdr << std::endl;
129 
130  for (unsigned int i = 0; i < pdr; i++) {
131  double value = outputs[0].flat<float>()(i);
132  //std::cout << value << ", ";
133  prediction.push_back(value);
134  }
135  //std::cout << std::endl;
136  return;
137  }
std::string OutputName
Definition: TFLoaderMLP.h:26
std::vector< double > prediction
Definition: TFLoader.h:39
double value
Definition: spectrum.C:18
std::vector< std::string > InputsName
Definition: TFLoaderMLP.h:25
void swap(lar::deep_const_fwd_iterator_nested< CITER, INNERCONTEXTRACT > &a, lar::deep_const_fwd_iterator_nested< CITER, INNERCONTEXTRACT > &b)
tensorflow::Status status
Definition: TFLoaderMLP.h:31
tensorflow::SavedModelBundleLite * modelbundle
Definition: TFLoaderMLP.h:28

Member Data Documentation

std::vector<std::string> phot::TFLoaderMLP::InputsName
private

Definition at line 25 of file TFLoaderMLP.h.

Referenced by Initialization(), Predict(), and TFLoaderMLP().

tensorflow::SavedModelBundleLite* phot::TFLoaderMLP::modelbundle
private

Definition at line 28 of file TFLoaderMLP.h.

Referenced by CloseSession(), Initialization(), and Predict().

std::string phot::TFLoaderMLP::ModelName
private

Definition at line 24 of file TFLoaderMLP.h.

Referenced by Initialization() and TFLoaderMLP().

std::string phot::TFLoaderMLP::OutputName
private

Definition at line 26 of file TFLoaderMLP.h.

Referenced by Initialization(), Predict(), and TFLoaderMLP().

std::vector<double> phot::TFLoader::prediction
protected inherited

Definition at line 39 of file TFLoader.h.

Referenced by phot::TFLoader::GetPrediction() and Predict().

tensorflow::Status phot::TFLoaderMLP::status
private

Definition at line 31 of file TFLoaderMLP.h.

Referenced by CloseSession(), Initialization(), and Predict().


The documentation for this class was generated from the following files: