LArSoft  v07_13_02
Liquid Argon Software toolkit - http://larsoft.org/
tf::Graph Class Reference

#include "tf_graph.h"

Public Member Functions

 ~Graph ()
 
std::vector< float > run (const std::vector< std::vector< float > > &x)
 
std::vector< std::vector< float > > run (const std::vector< std::vector< std::vector< std::vector< float > > > > &x, long long int samples=-1)
 
std::vector< std::vector< float > > run (const tensorflow::Tensor &x)
 

Static Public Member Functions

static std::unique_ptr< Graph > create (const char *graph_file_name, const std::vector< std::string > &outputs={})
 

Private Member Functions

 Graph (const char *graph_file_name, const std::vector< std::string > &outputs, bool &success)
 Not-throwing constructor. More...
 

Private Attributes

tensorflow::Session * fSession
 
std::string fInputName
 
std::vector< std::string > fOutputNames
 

Detailed Description

Definition at line 26 of file tf_graph.h.

Constructor & Destructor Documentation

tf::Graph::~Graph ( )

Definition at line 83 of file tf_graph.cc.

References fSession.

84 { // Destructor: closes and deletes the session that the constructor obtained from tensorflow::NewSession.
85  fSession->Close(); // NOTE(review): create() destroys a Graph whose constructor failed; if NewSession itself failed, fSession may be null/unset here. Consider a null check before Close() (TODO confirm NewSession's failure contract).
86  delete fSession; // raw owning pointer: this page lists no user-declared copy operations, so a copied Graph would delete the same session twice; consider deleting copy ctor/assignment or holding a std::unique_ptr.
87 }
tensorflow::Session * fSession
Definition: tf_graph.h:52
tf::Graph::Graph ( const char *  graph_file_name,
const std::vector< std::string > &  outputs,
bool &  success 
)
private

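Non-throwing constructor: instead of throwing, it reports failure through the `success` out-parameter, which is set to false on entry and to true only after the graph has been loaded into the session.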

Definition at line 16 of file tf_graph.cc.

References fInputName, fOutputNames, fSession, n, and s.

17 { // Creates the session, loads the GraphDef from graph_file_name, picks input/output node names, then attaches the graph; on any error prints the status to std::cout and returns with success == false.
18  success = false; // until all is done correctly
19 
20  auto status = tensorflow::NewSession(tensorflow::SessionOptions(), &fSession); // session is created first; if a later step fails, the session is released by ~Graph when create() discards this object
21  if (!status.ok())
22  {
23  std::cout << status.ToString() << std::endl;
24  return;
25  }
26 
27  tensorflow::GraphDef graph_def;
28  status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def); // read the serialized (binary protobuf) graph from the file
29  if (!status.ok())
30  {
31  std::cout << status.ToString() << std::endl;
32  return;
33  }
34 
35  size_t ng = graph_def.node().size();
36  fInputName = graph_def.node()[0].name(); // the first node is taken as the input. NOTE(review): node 0 is indexed without checking ng > 0; an empty GraphDef reads out of range here
37 
38  // last node as output if no specific name provided
39  if (outputs.empty()) { fOutputNames.push_back(graph_def.node()[ng - 1].name()); } // NOTE(review): ng - 1 wraps around (size_t) when ng == 0
40  else // or last nodes with names containing provided strings
41  {
42  std::string last, current, basename, name; // last: most recent matching node; current: the name prefix (text before the first '/') of that node
43  for (size_t n = 0; n < ng; ++n)
44  {
45  name = graph_def.node()[n].name();
46  auto pos = name.find("/");
47  if (pos != std::string::npos) { basename = name.substr(0, pos); }
48  else { continue; } // nodes whose name has no '/' are never selected as outputs, even if they match a requested string
49 
50  bool found = false;
51  for (const auto & s : outputs)
52  {
53  if (name.find(s) != std::string::npos) { found = true; break; } // substring match against any of the requested strings
54  }
55  if (found)
56  {
57  if (!last.empty() && (basename != current)) // prefix changed: keep the last match seen under the previous prefix
58  {
59  fOutputNames.push_back(last);
60  }
61  current = basename;
62  last = name;
63  }
64  }
65  if (!last.empty()) { fOutputNames.push_back(last); } // flush the last match of the final prefix group
66  }
67  if (fOutputNames.empty())
68  {
69  std::cout << "Output nodes not found in the graph." << std::endl;
70  return;
71  }
72 
73  status = fSession->Create(graph_def); // attach the loaded graph to the session
74  if (!status.ok())
75  {
76  std::cout << status.ToString() << std::endl;
77  return;
78  }
79 
80  success = true; // ok, graph loaded from the file
81 }
Float_t s
Definition: plot.C:23
std::string fInputName
Definition: tf_graph.h:53
std::vector< std::string > fOutputNames
Definition: tf_graph.h:54
tensorflow::Session * fSession
Definition: tf_graph.h:52
Char_t n[5]

Member Function Documentation

static std::unique_ptr<Graph> tf::Graph::create ( const char *  graph_file_name,
const std::vector< std::string > &  outputs = {} 
)
inline static

Definition at line 29 of file tf_graph.h.

References lar::dump::vector(), and x.

Referenced by nnet::TfModelInterface::TfModelInterface().

29  {}) // (listing begins at the tail of the signature on tf_graph.h line 29)
30  {
31  bool success; // not initialised here, but the constructor assigns false before doing anything else, so the read below is well-defined
32  std::unique_ptr<Graph> ptr(new Graph(graph_file_name, outputs, success));
33  if (success) { return ptr; }
34  else { return nullptr; } // the failed Graph is destroyed by ptr here, so ~Graph runs even on this path
35  }
Graph(const char *graph_file_name, const std::vector< std::string > &outputs, bool &success)
Not-throwing constructor.
Definition: tf_graph.cc:16
std::vector< float > tf::Graph::run ( const std::vector< std::vector< float > > &  x)

Definition at line 90 of file tf_graph.cc.

References x.

Referenced by run().

91 { // Wraps a single 2-D input as a {1, rows, cols, 1} float tensor, runs the graph, and returns the outputs of that one sample (empty on empty input or failure).
92  if (x.empty() || x.front().empty()) { return std::vector<float>(); }
93 
94  long long int rows = x.size(), cols = x.front().size(); // NOTE(review): cols is taken from the first row only; a shorter later row makes row[c] below read out of range
95 
96  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 1, rows, cols, 1 }));
97  auto input_map = _x.tensor<float, 4>();
98 
99  for (long long int r = 0; r < rows; ++r) {
100  const auto & row = x[r];
101  for (long long int c = 0; c < cols; ++c) {
102  input_map(0, r, c, 0) = row[c];
103  }
104  }
105 
106  auto result = run(_x); // delegates to run(const tensorflow::Tensor&)
107  if (!result.empty()) { return result.front(); } // only one sample was fed, so only the first result row is meaningful
108  else { return std::vector<float>(); }
109 }
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition: tf_graph.cc:90
std::vector< std::vector< float > > tf::Graph::run ( const std::vector< std::vector< std::vector< std::vector< float > > > > &  x,
long long int  samples = -1 
)

Definition at line 112 of file tf_graph.cc.

References col, d, run(), s, lar::dump::vector(), and x.

115 { // Packs the first `samples` entries of x into a {samples, rows, cols, depth} float tensor and runs the graph on it.
116  if ((samples == 0) || x.empty() || x.front().empty() || x.front().front().empty() || x.front().front().front().empty()) // nothing to run
117  return std::vector< std::vector<float> >();
118 
119  if ((samples == -1) || (samples > (long long int)x.size())) { samples = x.size(); } // -1 means "all samples"; larger values are clamped. NOTE(review): other negative values (< -1) pass through unchanged into the TensorShape below
120 
121  long long int // tensor dimensions are taken from the first sample only
122  rows = x.front().size(),
123  cols = x.front().front().size(),
124  depth = x.front().front().front().size();
125 
126  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, depth }));
127  auto input_map = _x.tensor<float, 4>();
128  for (long long int s = 0; s < samples; ++s) {
129  const auto & sample = x[s]; // NOTE(review): nothing checks that every sample/row/column matches the first one's sizes; a smaller one is indexed out of range
130  for (long long int r = 0; r < rows; ++r) {
131  const auto & row = sample[r];
132  for (long long int c = 0; c < cols; ++c) {
133  const auto & col = row[c];
134  for (long long int d = 0; d < depth; ++d) {
135  input_map(s, r, c, d) = col[d];
136  }
137  }
138  }
139  }
140 
141  return run(_x); // delegates to run(const tensorflow::Tensor&)
142 }
Float_t s
Definition: plot.C:23
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition: tf_graph.cc:90
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:265
Int_t col[ntarg]
Definition: Style.C:29
Float_t d
Definition: plot.C:237
std::vector< std::vector< float > > tf::Graph::run ( const tensorflow::Tensor &  x)

Definition at line 145 of file tf_graph.cc.

References fInputName, fOutputNames, fSession, n, and s.

146 { // Feeds x to the input node, fetches every node in fOutputNames, and concatenates their per-sample values into one row per sample; prints the status and returns empty on session failure.
147  std::vector< std::pair<std::string, tensorflow::Tensor> > inputs = {
148  { fInputName, x }
149  };
150 
151  //std::cout << "run session" << std::endl;
152 
153  std::vector<tensorflow::Tensor> outputs;
154  auto status = fSession->Run(inputs, fOutputNames, {}, &outputs);
155 
156  //std::cout << "out size " << outputs.size() << std::endl;
157 
158  if (status.ok())
159  {
160  size_t samples = 0, nouts = 0; // samples: first dimension, required to match across outputs; nouts: total width after concatenation
161  for (size_t o = 0; o < outputs.size(); ++o)
162  {
163  if (o == 0) { samples = outputs[o].dim_size(0); }
164  else if ((int)samples != outputs[o].dim_size(0)) // NOTE(review): the (int) cast narrows size_t before comparing; comparing in a common wide type would avoid truncation
165  {
166  throw std::string("TF outputs size inconsistent."); // NOTE(review): throws std::string, not a std::exception subclass, so handlers catching std::exception will not see it
167  }
168  nouts += outputs[o].dim_size(1); // each output is treated as rank 2: (samples, width)
169  }
170  //std::cout << "samples " << samples << " nouts " << nouts << std::endl;
171 
172  std::vector< std::vector< float > > result;
173  result.resize(samples, std::vector< float >(nouts));
174 
175  size_t idx0 = 0; // column offset of the current output inside each result row
176  for (size_t o = 0; o < outputs.size(); ++o)
177  {
178  auto output_map = outputs[o].tensor<float, 2>(); // presumably requires a float rank-2 tensor; TODO confirm TensorFlow's behaviour for other dtypes/ranks
179 
180  size_t n = outputs[o].dim_size(1);
181  for (size_t s = 0; s < samples; ++s) {
182  std::vector< float > & vs = result[s];
183  for (size_t i = 0; i < n; ++i) {
184  vs[idx0 + i] = output_map(s, i);
185  }
186  }
187  idx0 += n;
188  }
189  return result;
190  }
191  else
192  {
193  std::cout << status.ToString() << std::endl; // session errors are printed and reported as an empty result, not thrown
194  return std::vector< std::vector< float > >();
195  }
196 }
Float_t x
Definition: compare.C:6
Float_t s
Definition: plot.C:23
std::string fInputName
Definition: tf_graph.h:53
std::vector< std::string > fOutputNames
Definition: tf_graph.h:54
tensorflow::Session * fSession
Definition: tf_graph.h:52
Char_t n[5]

Member Data Documentation

std::string tf::Graph::fInputName
private

Definition at line 53 of file tf_graph.h.

Referenced by Graph(), and run().

std::vector< std::string > tf::Graph::fOutputNames
private

Definition at line 54 of file tf_graph.h.

Referenced by Graph(), and run().

tensorflow::Session* tf::Graph::fSession
private

Definition at line 52 of file tf_graph.h.

Referenced by Graph(), run(), and ~Graph().


The documentation for this class was generated from the following files: