LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
tf::Graph Class Reference

#include "tf_graph.h"

Public Member Functions

 ~Graph ()
 
std::vector< std::vector< float > > run (const std::vector< std::vector< float >> &x)
 
std::vector< std::vector< std::vector< float > > > run (const std::vector< std::vector< std::vector< std::vector< float >>>> &x, long long int samples=-1)
 
std::vector< std::vector< std::vector< float > > > run (const std::vector< tensorflow::Tensor > &x)
 
std::vector< std::vector< float > > runx (const std::vector< tensorflow::Tensor > &x)
 
std::vector< std::vector< float > > runae (const std::vector< tensorflow::Tensor > &x)
 

Static Public Member Functions

static std::unique_ptr< Graph > create (const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
 

Private Member Functions

 Graph (const char *graph_file_name, const std::vector< std::string > &outputs, bool &success, bool use_bundle=false, int ninputs=1, int noutputs=1)
 Non-throwing constructor; reports failure through the success flag instead of throwing.
 

Private Attributes

int n_inputs
 
int n_outputs
 
tensorflow::Session * fSession
 
bool fUseBundle
 
tensorflow::SavedModelBundle * fBundle
 
std::vector< std::string > fInputNames
 
std::vector< std::string > fOutputNames
 

Detailed Description

Definition at line 27 of file tf_graph.h.

Constructor & Destructor Documentation

tf::Graph::~Graph ( )

Definition at line 160 of file tf_graph.cc.

References fBundle, fSession, and fUseBundle.

161 {
162  fSession->Close().IgnoreError();
163  delete fSession;
164  if (fUseBundle) { delete fBundle; }
165 }
tensorflow::SavedModelBundle * fBundle
Definition: tf_graph.h:71
bool fUseBundle
Definition: tf_graph.h:70
tensorflow::Session * fSession
Definition: tf_graph.h:69
tf::Graph::Graph ( const char *  graph_file_name,
const std::vector< std::string > &  outputs,
bool &  success,
bool  use_bundle = false,
int  ninputs = 1,
int  noutputs = 1 
)
private

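Non-throwing constructor. Instead of throwing on failure, it returns early and leaves success set to false. success becomes true only after the session and the graph (or SavedModelBundle) have been loaded successfully. The public factory create() returns nullptr in that case.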

Definition at line 21 of file tf_graph.cc.

References fBundle, fInputNames, fOutputNames, fSession, fUseBundle, n_inputs, and n_outputs.

27 {
28  fUseBundle = use_bundle;
29  success = false; // until all is done correctly
30 
31  n_inputs = ninputs;
32  n_outputs = noutputs;
33 
34  // Force tf to only use a single core so it doesn't eat batch farms
35  tensorflow::SessionOptions options;
36  tensorflow::ConfigProto& config = options.config;
37  config.set_inter_op_parallelism_threads(1);
38  config.set_intra_op_parallelism_threads(1);
39  config.set_use_per_session_threads(false);
40 
41  auto status = tensorflow::NewSession(options, &fSession);
42  if (!status.ok()) {
43  std::cout << status.ToString() << std::endl;
44  return;
45  }
46 
47  if (fUseBundle) {
48 
49  fBundle = new tensorflow::SavedModelBundle();
50  status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(),
51  tensorflow::RunOptions(),
52  graph_file_name,
53  {tensorflow::kSavedModelTagServe},
54  fBundle);
55  std::cout << "tf_graph loaded SavedModelBundle with status: " << status.ToString() << std::endl;
56  if (!status.ok()) return;
57 
58  auto sig_map = fBundle->meta_graph_def.signature_def();
59  std::string sig_def = "serving_default";
60  bool has_default_key = false;
61  std::vector<std::string> sig_map_keys;
62  for (auto const& p : sig_map) {
63  if (p.first == sig_def) has_default_key = true;
64  sig_map_keys.push_back(p.first);
65  }
66  auto model_def = sig_map.at((has_default_key) ? sig_def : sig_map_keys.back());
67 
68  // ... Get the input names
69  std::cout << "tf_graph inputs:" << std::endl;
70  for (auto const& p : model_def.inputs()) {
71  fInputNames.push_back(p.second.name());
72  std::cout << "tf_graph InputName: " << fInputNames.back() << std::endl;
73  std::cout << " key: " << p.first << " value: " << p.second.name() << std::endl;
74  }
75 
76  // ... Get the output names
77  // .. get all outputs if no specific name provided
78  if (outputs.empty()) {
79  std::cout << "tf_graph using all outputs:" << std::endl;
80  for (auto const& p : model_def.outputs()) {
81  fOutputNames.push_back(p.second.name());
82  std::cout << " key: " << p.first << " value: " << p.second.name() << std::endl;
83  }
84  }
85  // .. or use only the outputs whose keys are specified
86  else {
87  std::cout << "tf_graph using selected outputs:" << std::endl;
88  for (const auto& s : outputs) {
89  for (auto const& p : model_def.outputs()) {
90  if (p.first == s) {
91  fOutputNames.push_back(p.second.name());
92  std::cout << " key: " << p.first << " value: " << p.second.name() << std::endl;
93  }
94  }
95  }
96  }
97  if (fOutputNames.empty()) {
98  std::cout << "tf_graph did not find outputs in SaveModelBundle." << std::endl;
99  return;
100  }
101  }
102  else {
103 
104  tensorflow::GraphDef graph_def;
105  status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
106  std::cout << "tf_graph loaded ProtoBuf graph with status: " << status.ToString() << std::endl;
107  if (!status.ok()) return;
108 
109  size_t ng = graph_def.node().size();
110  for (int i = 0; i < n_inputs; ++i) {
111  fInputNames.push_back(graph_def.node()[i].name());
112  }
113 
114  // last node as output if no specific name provided
115  if (outputs.empty()) {
116  for (int i = n_outputs; i > 0; --i) {
117  fOutputNames.push_back(graph_def.node()[ng - i].name());
118  }
119  }
120  else // or last nodes with names containing provided strings
121  {
122  std::string last, current, basename, name;
123  for (size_t n = 0; n < ng; ++n) {
124  name = graph_def.node()[n].name();
125  auto pos = name.find("/");
126  if (pos != std::string::npos) { basename = name.substr(0, pos); }
127  else {
128  continue;
129  }
130 
131  bool found = false;
132  for (const auto& s : outputs) {
133  if (name.find(s) != std::string::npos) {
134  found = true;
135  break;
136  }
137  }
138  if (found) {
139  if (!last.empty() && (basename != current)) { fOutputNames.push_back(last); }
140  current = basename;
141  last = name;
142  }
143  }
144  if (!last.empty()) { fOutputNames.push_back(last); }
145  }
146  if (fOutputNames.empty()) {
147  std::cout << "Output nodes not found in the graph." << std::endl;
148  return;
149  }
150  status = fSession->Create(graph_def);
151  if (!status.ok()) {
152  std::cout << status.ToString() << std::endl;
153  return;
154  }
155  }
156 
157  success = true; // ok, graph loaded from the file
158 }
std::vector< std::string > fOutputNames
Definition: tf_graph.h:73
int n_outputs
Definition: tf_graph.h:60
tensorflow::SavedModelBundle * fBundle
Definition: tf_graph.h:71
bool fUseBundle
Definition: tf_graph.h:70
int n_inputs
Definition: tf_graph.h:59
tensorflow::Session * fSession
Definition: tf_graph.h:69
Char_t n[5]
std::vector< std::string > fInputNames
Definition: tf_graph.h:72

Member Function Documentation

static std::unique_ptr<Graph> tf::Graph::create ( const char *  graph_file_name,
const std::vector< std::string > &  outputs = {},
bool  use_bundle = false,
int  ninputs = 1,
int  noutputs = 1 
)
inline static

Definition at line 29 of file tf_graph.h.

References Graph().

Referenced by PointIdAlgTools::PointIdAlgTf::PointIdAlgTf(), nnet::TfModelInterface::TfModelInterface(), wavdenoise_tool::WaveformDenoiseTf::WaveformDenoiseTf(), wavrec_tool::WaveformRecogTf::WaveformRecogTf(), and wframerec_tool::WireframeRecogTf::WireframeRecogTf().

30  {},
31  bool use_bundle = false,
32  int ninputs = 1,
33  int noutputs = 1)
34  {
35  bool success;
36  std::unique_ptr<Graph> ptr(
37  new Graph(graph_file_name, outputs, success, use_bundle, ninputs, noutputs));
38  if (success) { return ptr; }
39  else {
40  return nullptr;
41  }
42  }
Graph(const char *graph_file_name, const std::vector< std::string > &outputs, bool &success, bool use_bundle=false, int ninputs=1, int noutputs=1)
Non-throwing constructor.
Definition: tf_graph.cc:21
std::vector< std::vector< float > > tf::Graph::run ( const std::vector< std::vector< float >> &  x)

Definition at line 168 of file tf_graph.cc.

References run() (the std::vector< tensorflow::Tensor > overload).

Referenced by run().

169 {
170  if (x.empty() || x.front().empty()) { return std::vector<std::vector<float>>(); }
171 
172  long long int rows = x.size(), cols = x.front().size();
173 
174  std::vector<tensorflow::Tensor> _x;
175  _x.push_back(
176  tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1})));
177  auto input_map = _x[0].tensor<float, 4>();
178 
179  for (long long int r = 0; r < rows; ++r) {
180  const auto& row = x[r];
181  for (long long int c = 0; c < cols; ++c) {
182  input_map(0, r, c, 0) = row[c];
183  }
184  }
185 
186  auto result = run(_x);
187  if (!result.empty()) { return result.front(); }
188  else {
189  return std::vector<std::vector<float>>();
190  }
191 }
TRandom r
Definition: spectrum.C:23
std::vector< std::vector< float > > run(const std::vector< std::vector< float >> &x)
Definition: tf_graph.cc:168
std::vector< std::vector< std::vector< float > > > tf::Graph::run ( const std::vector< std::vector< std::vector< std::vector< float >>>> &  x,
long long int  samples = -1 
)

Definition at line 194 of file tf_graph.cc.

References n_inputs and run() (the std::vector< tensorflow::Tensor > overload).

197 {
198  if ((samples == 0) || x.empty() || x.front().empty() || x.front().front().empty() ||
199  x.front().front().front().empty())
200  return std::vector<std::vector<std::vector<float>>>();
201 
202  if ((samples == -1) || (samples > (long long int)x.size())) { samples = x.size(); }
203 
204  long long int rows = x.front().size(), cols = x.front().front().size(),
205  depth = x.front().front().front().size();
206 
207  std::vector<tensorflow::Tensor> _x;
208 
209  // Single-input network
210  if (n_inputs == 1) {
211  _x.push_back(tensorflow::Tensor(tensorflow::DT_FLOAT,
212  tensorflow::TensorShape({samples, rows, cols, depth})));
213  auto input_map = _x[0].tensor<float, 4>();
214  for (long long int s = 0; s < samples; ++s) {
215  const auto& sample = x[s];
216  for (long long int r = 0; r < rows; ++r) {
217  const auto& row = sample[r];
218  for (long long int c = 0; c < cols; ++c) {
219  const auto& col = row[c];
220  for (long long int d = 0; d < depth; ++d) {
221  input_map(s, r, c, d) = col[d];
222  }
223  }
224  }
225  }
226  }
227  // Multi-input network
228  else {
229  for (int i = 0; i < depth; ++i) {
230  _x.push_back(tensorflow::Tensor(tensorflow::DT_FLOAT,
231  tensorflow::TensorShape({samples, rows, cols, 1})));
232  }
233 
234  for (int view = 0; view < depth; ++view) {
235  auto input_map = _x[view].tensor<float, 4>();
236  for (long long int s = 0; s < samples; ++s) {
237  const auto& sample = x[s];
238  for (long long int r = 0; r < rows; ++r) {
239  const auto& row = sample[r];
240  for (long long int c = 0; c < cols; ++c) {
241  const auto& col = row[c];
242  long long int d = view;
243  input_map(s, r, c, 0) = col[d];
244  }
245  }
246  }
247  }
248  }
249 
250  return run(_x);
251 }
TRandom r
Definition: spectrum.C:23
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
Int_t col[ntarg]
Definition: Style.C:29
Float_t d
Definition: plot.C:235
int n_inputs
Definition: tf_graph.h:59
std::vector< std::vector< float > > run(const std::vector< std::vector< float >> &x)
Definition: tf_graph.cc:168
std::vector< std::vector< std::vector< float > > > tf::Graph::run ( const std::vector< tensorflow::Tensor > &  x)

Definition at line 255 of file tf_graph.cc.

References fBundle, fInputNames, fOutputNames, fSession, fUseBundle, and n_inputs.

257 {
258  std::vector<std::pair<std::string, tensorflow::Tensor>> inputs;
259  for (int i = 0; i < n_inputs; ++i) {
260  inputs.push_back({fInputNames[i], x[i]});
261  }
262 
263  std::vector<tensorflow::Tensor> outputs;
264  std::vector<std::string> outputNames;
265  auto status = (fUseBundle) ?
266  fBundle->GetSession()->Run(inputs, fOutputNames, outputNames, &outputs) :
267  fSession->Run(inputs, fOutputNames, outputNames, &outputs);
268 
269  if (status.ok()) {
270  size_t samples = 0;
271 
272  for (size_t o = 0; o < outputs.size(); ++o) {
273  if (o == 0) { samples = outputs[o].dim_size(0); }
274  else if ((int)samples != outputs[o].dim_size(0)) {
275  throw std::string("TF outputs size inconsistent.");
276  }
277  }
278 
279  std::vector<std::vector<std::vector<float>>> result;
280  result.resize(samples, std::vector<std::vector<float>>(outputs.size()));
281 
282  for (size_t s = 0; s < samples; ++s) {
283  for (size_t o = 0; o < outputs.size(); ++o) {
284  size_t n = outputs[o].dim_size(1);
285  auto output_map = outputs[o].tensor<float, 2>();
286 
287  result[s][o].resize(outputs[o].dim_size(1));
288 
289  std::vector<float>& vs = result[s][o];
290  for (size_t i = 0; i < n; ++i) {
291  vs[i] = output_map(s, i);
292  }
293  }
294  }
295 
296  return result;
297  }
298  else {
299  std::cout << status.ToString() << std::endl;
300  return std::vector<std::vector<std::vector<float>>>();
301  }
302 }
std::vector< std::string > fOutputNames
Definition: tf_graph.h:73
tensorflow::SavedModelBundle * fBundle
Definition: tf_graph.h:71
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
bool fUseBundle
Definition: tf_graph.h:70
int n_inputs
Definition: tf_graph.h:59
tensorflow::Session * fSession
Definition: tf_graph.h:69
Char_t n[5]
std::vector< std::string > fInputNames
Definition: tf_graph.h:72
std::vector< std::vector< float > > tf::Graph::runae ( const std::vector< tensorflow::Tensor > &  x)

Definition at line 355 of file tf_graph.cc.

References fBundle, fInputNames, fOutputNames, fSession, fUseBundle, and n_inputs.

356 {
357  std::vector<std::pair<std::string, tensorflow::Tensor>> inputs;
358  for (int i = 0; i < n_inputs; ++i) {
359  inputs.push_back({fInputNames[i], x[i]});
360  }
361 
362  std::vector<tensorflow::Tensor> outputs;
363  std::vector<std::string> outputNames;
364  auto status = (fUseBundle) ?
365  fBundle->GetSession()->Run(inputs, fOutputNames, outputNames, &outputs) :
366  fSession->Run(inputs, fOutputNames, outputNames, &outputs);
367 
368  if (status.ok()) {
369  size_t samples = 0, npoints = 0;
370 
371  if (outputs.size() > 1) { throw std::string("TF runae: detected more than one output."); }
372 
373  samples = outputs[0].dim_size(0);
374  npoints = outputs[0].dim_size(1);
375 
376  std::vector<std::vector<float>> result;
377  result.resize(samples, std::vector<float>(npoints));
378 
379  auto output_map = outputs[0].tensor<float, 3>();
380 
381  for (size_t s = 0; s < samples; ++s) {
382  std::vector<float>& vs = result[s];
383  for (size_t i = 0; i < npoints; ++i) {
384  vs[i] = output_map(s, i, 0);
385  }
386  }
387 
388  return result;
389  }
390  else {
391  std::cout << status.ToString() << std::endl;
392  return std::vector<std::vector<float>>();
393  }
394 }
std::vector< std::string > fOutputNames
Definition: tf_graph.h:73
tensorflow::SavedModelBundle * fBundle
Definition: tf_graph.h:71
bool fUseBundle
Definition: tf_graph.h:70
int n_inputs
Definition: tf_graph.h:59
tensorflow::Session * fSession
Definition: tf_graph.h:69
std::vector< std::string > fInputNames
Definition: tf_graph.h:72
std::vector< std::vector< float > > tf::Graph::runx ( const std::vector< tensorflow::Tensor > &  x)

Definition at line 305 of file tf_graph.cc.

References fBundle, fInputNames, fOutputNames, fSession, fUseBundle, and n_inputs.

306 {
307  std::vector<std::pair<std::string, tensorflow::Tensor>> inputs;
308  for (int i = 0; i < n_inputs; ++i) {
309  inputs.push_back({fInputNames[i], x[i]});
310  }
311 
312  std::vector<tensorflow::Tensor> outputs;
313  std::vector<std::string> outputNames;
314  auto status = (fUseBundle) ?
315  fBundle->GetSession()->Run(inputs, fOutputNames, outputNames, &outputs) :
316  fSession->Run(inputs, fOutputNames, outputNames, &outputs);
317 
318  if (status.ok()) {
319  size_t samples = 0, nouts = 0;
320 
321  for (size_t o = 0; o < outputs.size(); ++o) {
322  if (o == 0) { samples = outputs[o].dim_size(0); }
323  else if ((int)samples != outputs[o].dim_size(0)) {
324  throw std::string("TF outputs size inconsistent.");
325  }
326  nouts += outputs[o].dim_size(1);
327  }
328 
329  std::vector<std::vector<float>> result;
330  result.resize(samples, std::vector<float>(nouts));
331 
332  size_t idx0 = 0;
333  for (size_t o = 0; o < outputs.size(); ++o) {
334  auto output_map = outputs[o].tensor<float, 2>();
335 
336  size_t n = outputs[o].dim_size(1);
337  for (size_t s = 0; s < samples; ++s) {
338  std::vector<float>& vs = result[s];
339  for (size_t i = 0; i < n; ++i) {
340  vs[idx0 + i] = output_map(s, i);
341  }
342  }
343  idx0 += n;
344  }
345 
346  return result;
347  }
348  else {
349  std::cout << status.ToString() << std::endl;
350  return std::vector<std::vector<float>>();
351  }
352 }
std::vector< std::string > fOutputNames
Definition: tf_graph.h:73
tensorflow::SavedModelBundle * fBundle
Definition: tf_graph.h:71
bool fUseBundle
Definition: tf_graph.h:70
int n_inputs
Definition: tf_graph.h:59
tensorflow::Session * fSession
Definition: tf_graph.h:69
Char_t n[5]
std::vector< std::string > fInputNames
Definition: tf_graph.h:72

Member Data Documentation

tensorflow::SavedModelBundle* tf::Graph::fBundle
private

Definition at line 71 of file tf_graph.h.

Referenced by Graph(), run(), runae(), runx(), and ~Graph().

std::vector<std::string> tf::Graph::fInputNames
private

Definition at line 72 of file tf_graph.h.

Referenced by Graph(), run(), runae(), and runx().

std::vector<std::string> tf::Graph::fOutputNames
private

Definition at line 73 of file tf_graph.h.

Referenced by Graph(), run(), runae(), and runx().

tensorflow::Session* tf::Graph::fSession
private

Definition at line 69 of file tf_graph.h.

Referenced by Graph(), run(), runae(), runx(), and ~Graph().

bool tf::Graph::fUseBundle
private

Definition at line 70 of file tf_graph.h.

Referenced by Graph(), run(), runae(), runx(), and ~Graph().

int tf::Graph::n_inputs
private

Definition at line 59 of file tf_graph.h.

Referenced by Graph(), run(), runae(), and runx().

int tf::Graph::n_outputs
private

Definition at line 60 of file tf_graph.h.

Referenced by Graph().


The documentation for this class was generated from the following files: