LArSoft  v06_85_00
Liquid Argon Software toolkit - http://larsoft.org/
tf_graph.cc
Go to the documentation of this file.
1 // Class: Graph
3 // Authors: R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, Sept. 2017
4 // P.Plonski, from DUNE, WUT, Sept. 2017
5 //
// Interface to run a TensorFlow graph saved to a file. First attempts, quite functional.
7 //
9 
10 #include "tf_graph.h"
11 
12 #include "tensorflow/core/public/session.h"
13 #include "tensorflow/core/platform/env.h"
14 
15 // -------------------------------------------------------------------
16 tf::Graph::Graph(const char* graph_file_name, const std::vector<std::string> & outputs, bool & success)
17 {
18  success = false; // until all is done correctly
19 
20  auto status = tensorflow::NewSession(tensorflow::SessionOptions(), &fSession);
21  if (!status.ok())
22  {
23  std::cout << status.ToString() << std::endl;
24  return;
25  }
26 
27  tensorflow::GraphDef graph_def;
28  status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
29  if (!status.ok())
30  {
31  std::cout << status.ToString() << std::endl;
32  return;
33  }
34 
35  size_t ng = graph_def.node().size();
36  fInputName = graph_def.node()[0].name();
37 
38  // last node as output if no specific name provided
39  if (outputs.empty()) { fOutputNames.push_back(graph_def.node()[ng - 1].name()); }
40  else // or last nodes with names containing provided strings
41  {
42  std::string last, current, basename, name;
43  for (size_t n = 0; n < ng; ++n)
44  {
45  name = graph_def.node()[n].name();
46  auto pos = name.find("/");
47  if (pos != std::string::npos) { basename = name.substr(0, pos); }
48  else { continue; }
49 
50  bool found = false;
51  for (const auto & s : outputs)
52  {
53  if (name.find(s) != std::string::npos) { found = true; break; }
54  }
55  if (found)
56  {
57  if (!last.empty() && (basename != current))
58  {
59  fOutputNames.push_back(last);
60  }
61  current = basename;
62  last = name;
63  }
64  }
65  if (!last.empty()) { fOutputNames.push_back(last); }
66  }
67  if (fOutputNames.empty())
68  {
69  std::cout << "Output nodes not found in the graph." << std::endl;
70  return;
71  }
72 
73  status = fSession->Create(graph_def);
74  if (!status.ok())
75  {
76  std::cout << status.ToString() << std::endl;
77  return;
78  }
79 
80  success = true; // ok, graph loaded from the file
81 }
82 
84 {
85  fSession->Close();
86  delete fSession;
87 }
88 // -------------------------------------------------------------------
89 
90 std::vector<float> tf::Graph::run(const std::vector< std::vector<float> > & x)
91 {
92  if (x.empty() || x.front().empty()) { return std::vector<float>(); }
93 
94  long long int rows = x.size(), cols = x.front().size();
95 
96  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 1, rows, cols, 1 }));
97  auto input_map = _x.tensor<float, 4>();
98 
99  for (long long int r = 0; r < rows; ++r) {
100  const auto & row = x[r];
101  for (long long int c = 0; c < cols; ++c) {
102  input_map(0, r, c, 0) = row[c];
103  }
104  }
105 
106  auto result = run(_x);
107  if (!result.empty()) { return result.front(); }
108  else { return std::vector<float>(); }
109 }
110 // -------------------------------------------------------------------
111 
112 std::vector< std::vector<float> > tf::Graph::run(
113  const std::vector< std::vector< std::vector< std::vector<float> > > > & x,
114  long long int samples)
115 {
116  if ((samples == 0) || x.empty() || x.front().empty() || x.front().front().empty() || x.front().front().front().empty())
117  return std::vector< std::vector<float> >();
118 
119  if ((samples == -1) || (samples > (long long int)x.size())) { samples = x.size(); }
120 
121  long long int
122  rows = x.front().size(),
123  cols = x.front().front().size(),
124  depth = x.front().front().front().size();
125 
126  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, depth }));
127  auto input_map = _x.tensor<float, 4>();
128  for (long long int s = 0; s < samples; ++s) {
129  const auto & sample = x[s];
130  for (long long int r = 0; r < rows; ++r) {
131  const auto & row = sample[r];
132  for (long long int c = 0; c < cols; ++c) {
133  const auto & col = row[c];
134  for (long long int d = 0; d < depth; ++d) {
135  input_map(s, r, c, d) = col[d];
136  }
137  }
138  }
139  }
140 
141  return run(_x);
142 }
143 // -------------------------------------------------------------------
144 
145 std::vector< std::vector< float > > tf::Graph::run(const tensorflow::Tensor & x)
146 {
147  std::vector< std::pair<std::string, tensorflow::Tensor> > inputs = {
148  { fInputName, x }
149  };
150 
151  //std::cout << "run session" << std::endl;
152 
153  std::vector<tensorflow::Tensor> outputs;
154  auto status = fSession->Run(inputs, fOutputNames, {}, &outputs);
155 
156  //std::cout << "out size " << outputs.size() << std::endl;
157 
158  if (status.ok())
159  {
160  size_t samples = 0, nouts = 0;
161  for (size_t o = 0; o < outputs.size(); ++o)
162  {
163  if (o == 0) { samples = outputs[o].dim_size(0); }
164  else if ((int)samples != outputs[o].dim_size(0))
165  {
166  throw std::string("TF outputs size inconsistent.");
167  }
168  nouts += outputs[o].dim_size(1);
169  }
170  //std::cout << "samples " << samples << " nouts " << nouts << std::endl;
171 
172  std::vector< std::vector< float > > result;
173  result.resize(samples, std::vector< float >(nouts));
174 
175  size_t idx0 = 0;
176  for (size_t o = 0; o < outputs.size(); ++o)
177  {
178  auto output_map = outputs[o].tensor<float, 2>();
179 
180  size_t n = outputs[o].dim_size(1);
181  for (size_t s = 0; s < samples; ++s) {
182  std::vector< float > & vs = result[s];
183  for (size_t i = 0; i < n; ++i) {
184  vs[idx0 + i] = output_map(s, i);
185  }
186  }
187  idx0 += n;
188  }
189  return result;
190  }
191  else
192  {
193  std::cout << status.ToString() << std::endl;
194  return std::vector< std::vector< float > >();
195  }
196 }
197 // -------------------------------------------------------------------
198 
Float_t x
Definition: compare.C:6
Float_t s
Definition: plot.C:23
std::string fInputName
Definition: tf_graph.h:53
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition: tf_graph.cc:90
Graph(const char *graph_file_name, const std::vector< std::string > &outputs, bool &success)
Not-throwing constructor.
Definition: tf_graph.cc:16
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:265
Int_t col[ntarg]
Definition: Style.C:29
Float_t d
Definition: plot.C:237
std::vector< std::string > fOutputNames
Definition: tf_graph.h:54
tensorflow::Session * fSession
Definition: tf_graph.h:52
Char_t n[5]