LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
TFLoaderMLP_tool.cc
// Class:       TFLoaderMLP
// Plugin Type: tool
// File:        TFLoaderMLP_tool.cc TFLoaderMLP.h
// Aug. 20, 2022 by Mu Wei (wmu@fnal.gov)

#include "larrecodnn/PhotonPropagation/TFLoaderTools/TFLoaderMLP.h"

#include "art/Utilities/ToolMacros.h"
#include "cetlib/search_path.h"
#include "cetlib_except/exception.h"

#include <algorithm>
#include <iostream>

namespace phot {
  //......................................................................
  TFLoaderMLP::TFLoaderMLP(fhicl::ParameterSet const& pset)
    : ModelName{pset.get<std::string>("ModelName")}
    , InputsName{pset.get<std::vector<std::string>>("InputsName")}
    , OutputName{pset.get<std::string>("OutputName")}
  {}
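
  // Example FHiCL configuration for this tool (a sketch, not from this file): the keys
  // ModelName, InputsName, and OutputName match the constructor above, while the block
  // label, model directory, and layer names below are assumptions for illustration.
  //
  //   PhotonPredictor: {
  //     tool_type:  "TFLoaderMLP"
  //     ModelName:  "my_mlp_savedmodel"            # SavedModel directory, resolved via FW_SEARCH_PATH
  //     InputsName: ["pos_x", "pos_y", "pos_z"]    # exactly three input layer names
  //     OutputName: "visibility"                   # single output layer name
  //   }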

  //......................................................................
  void TFLoaderMLP::Initialization()
  {
    int num_input = int(InputsName.size());
    if (num_input != 3) {
      std::cout << "Input name error! exit!" << std::endl;
      return;
    }
    std::string GraphFileWithPath;
    cet::search_path sp("FW_SEARCH_PATH");
    if (!sp.find_file(ModelName, GraphFileWithPath)) {
      throw cet::exception("TFLoaderMLP")
        << "In larrecodnn:phot::TFLoaderMLP: Failed to load SavedModel in: " << sp.to_string()
        << "\n";
    }
    std::cout << "larrecodnn:phot::TFLoaderMLP Loading TF Model from: " << GraphFileWithPath
              << ", Input Layer: ";
    for (int i = 0; i < num_input; ++i) {
      std::cout << InputsName[i] << " ";
    }
    std::cout << ", Output Layer: " << OutputName << "\n";

    // Load the SavedModel
    modelbundle = new tensorflow::SavedModelBundleLite();

    status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(),
                                        tensorflow::RunOptions(),
                                        GraphFileWithPath,
                                        {tensorflow::kSavedModelTagServe},
                                        modelbundle);

    // Initialize a tensorflow session
    // status = tensorflow::NewSession(tensorflow::SessionOptions(), &session);
    if (!status.ok()) {
      throw cet::exception("TFLoaderMLP")
        << "In larrecodnn:phot::TFLoaderMLP: Failed to load SavedModel, status: "
        << status.ToString() << std::endl;
    }

    // Read in the protobuf graph
    // tensorflow::GraphDef graph_def;
    // status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), ModelName, &graph_def);
    // if (!status.ok())
    // {
    //   std::cout << status.ToString() << std::endl;
    //   return;
    // }
    //
    // // Add the graph to the session
    // status = session->Create(graph_def);
    // if (!status.ok())
    // {
    //   std::cout << status.ToString() << std::endl;
    //   return;
    // }

    std::cout << "TF SavedModel loaded successfully." << std::endl;
    return;
  }
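
  // Note on model lookup (not in the original source): cet::search_path("FW_SEARCH_PATH")
  // walks the colon-separated directories of the FW_SEARCH_PATH environment variable, so a
  // ModelName such as "my_mlp_savedmodel" (an assumed example) is expected to name a
  // TensorFlow SavedModel directory (saved_model.pb plus a variables/ subdirectory) living
  // under one of those directories; find_file() then returns its absolute path.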

  //......................................................................
  void TFLoaderMLP::CloseSession()
  {
    if (status.ok()) {
      std::cout << "Close TF session." << std::endl;
      // session->Close();
    }

    delete modelbundle;

    // delete session;
    return;
  }

  //......................................................................
  void TFLoaderMLP::Predict(std::vector<double> pars)
  {
    // std::cout << "TFLoader MLP:: Predicting... " << std::endl;
    int num_input = int(pars.size());
    if (num_input != 3) {
      std::cout << "Input parameter error! exit!" << std::endl;
      return;
    }

    // Clear the previous prediction
    std::vector<double>().swap(prediction);

    // Define inputs
    tensorflow::Tensor pos_x(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 1}));
    tensorflow::Tensor pos_y(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 1}));
    tensorflow::Tensor pos_z(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 1}));
    auto dst_x = pos_x.flat<float>().data();
    auto dst_y = pos_y.flat<float>().data();
    auto dst_z = pos_z.flat<float>().data();
    std::copy_n(pars.begin(), 1, dst_x);
    std::copy_n(pars.begin() + 1, 1, dst_y);
    std::copy_n(pars.begin() + 2, 1, dst_z);
    std::vector<std::pair<std::string, tensorflow::Tensor>> inputs = {
      {InputsName[0], pos_x}, {InputsName[1], pos_y}, {InputsName[2], pos_z}};

    // Define outputs
    std::vector<tensorflow::Tensor> outputs;

    // Run the session
    status = modelbundle->GetSession()->Run(inputs, {OutputName}, {}, &outputs);
    // status = session->Run(inputs, {OutputName}, {}, &outputs);
    if (!status.ok()) {
      std::cout << status.ToString() << std::endl;
      return;
    }

    // Grab the outputs
    unsigned int pdr = outputs[0].shape().dim_size(1);
    // std::cout << "TFLoader MLP::Num of optical channels: " << pdr << std::endl;

    for (unsigned int i = 0; i < pdr; i++) {
      double value = outputs[0].flat<float>()(i);
      // std::cout << value << ", ";
      prediction.push_back(value);
    }
    // std::cout << std::endl;
    return;
  }
} // namespace phot

DEFINE_ART_CLASS_TOOL(phot::TFLoaderMLP)
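
// Usage sketch (illustrative only, not part of this file): a photon-propagation service
// could own the tool through its phot::TFLoader interface and drive it roughly as below.
// The configuration label "PhotonPredictor" and the GetPrediction() accessor name are
// assumptions; only Initialization(), Predict(), and CloseSession() appear in this source.
//
//   #include "art/Utilities/make_tool.h"
//
//   auto loader = art::make_tool<phot::TFLoader>(pset.get<fhicl::ParameterSet>("PhotonPredictor"));
//   loader->Initialization();                           // load the SavedModel once
//   loader->Predict({x, y, z});                         // one (x, y, z) position
//   std::vector<double> vis = loader->GetPrediction();  // one value per optical channel
//   loader->CloseSession();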