LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
tf_graph.h
Go to the documentation of this file.
1 
12 #ifndef Graph_h
13 #define Graph_h
14 
15 #include <memory>
16 #include <string>
17 #include <vector>
18 
// Forward declarations only: this header never needs the full TensorFlow types.
namespace tensorflow {
  class Session;
  class Tensor;
  struct SavedModelBundle;
}

namespace tf {

  /// Wrapper around a TensorFlow model that runs inference on float inputs.
  /// Instances can only be obtained through create(); the constructor is private.
  class Graph {
  public:
    /// @brief Construct a Graph and report construction failure as a null pointer.
    /// @param graph_file_name model location forwarded to the constructor
    ///        (NOTE(review): presumably a frozen graph file, or a SavedModel
    ///        location when use_bundle is true — confirm in the .cc file)
    /// @param outputs    output names forwarded to the constructor (may be empty)
    /// @param use_bundle forwarded to the constructor
    /// @param ninputs    forwarded to the constructor
    /// @param noutputs   forwarded to the constructor
    /// @return the new Graph, or nullptr if the constructor did not set its
    ///         success flag to true
    static std::unique_ptr<Graph> create(const char* graph_file_name,
                                         const std::vector<std::string>& outputs = {},
                                         bool use_bundle = false,
                                         int ninputs = 1,
                                         int noutputs = 1)
    {
      // Start pessimistic: if the constructor returns without touching the flag,
      // this is reported as a failure instead of reading an indeterminate bool.
      bool success = false;
      // std::make_unique cannot reach the private constructor, hence the explicit new.
      std::unique_ptr<Graph> ptr(
        new Graph(graph_file_name, outputs, success, use_bundle, ninputs, noutputs));
      if (!success) {
        return nullptr; // ptr is destroyed here, deleting the unusable Graph
      }
      return ptr;
    }

    ~Graph();

    // fSession / fBundle are raw pointers and a user-declared destructor exists
    // (presumably releasing them); an implicit copy would duplicate the pointers
    // and lead to a double release, so copying (and therefore moving) is disabled.
    Graph(const Graph&) = delete;
    Graph& operator=(const Graph&) = delete;

    /// Run the model on a batch of flat float inputs; returns one float vector
    /// per entry (NOTE(review): exact shape contract lives in the .cc file).
    std::vector<std::vector<float>> run(const std::vector<std::vector<float>>& x);

    /// Process a vector of multi-dimensional inputs and return a vector of outputs.
    /// All samples are used if samples == -1, otherwise only the first `samples`
    /// ones. Can deal with multiple inputs.
    std::vector<std::vector<std::vector<float>>> run(
      const std::vector<std::vector<std::vector<std::vector<float>>>>& x,
      long long int samples = -1);

    /// Run the model on pre-built TensorFlow tensors.
    std::vector<std::vector<std::vector<float>>> run(const std::vector<tensorflow::Tensor>& x);

    /// Variant of run() on pre-built tensors returning a 2D result
    /// (NOTE(review): semantics defined in the .cc file — confirm before relying on it).
    std::vector<std::vector<float>> runx(const std::vector<tensorflow::Tensor>& x);

    /// Variant of run() on pre-built tensors returning a 2D result
    /// (NOTE(review): semantics defined in the .cc file — confirm before relying on it).
    std::vector<std::vector<float>> runae(const std::vector<tensorflow::Tensor>& x);

  private:
    int n_inputs{1};
    int n_outputs{1};

    /// Sets `success` to report whether the model could be loaded; use create().
    Graph(const char* graph_file_name,
          const std::vector<std::string>& outputs,
          bool& success,
          bool use_bundle = false,
          int ninputs = 1,
          int noutputs = 1);

    // Defaulted so these never hold indeterminate values, even if the
    // constructor exits before assigning them.
    tensorflow::Session* fSession{nullptr};
    bool fUseBundle{false};
    tensorflow::SavedModelBundle* fBundle{nullptr};
    std::vector<std::string> fInputNames;
    std::vector<std::string> fOutputNames;
  };

} // namespace tf
77 
78 #endif
Float_t x
Definition: compare.C:6
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:29
boost::adjacency_list< boost::vecS, boost::vecS, boost::bidirectionalS, vertex_property, edge_property, graph_property > Graph
Definition: ModuleGraph.h:22
std::vector< std::string > fOutputNames
Definition: tf_graph.h:73
Definition: tf_graph.h:25
int n_outputs
Definition: tf_graph.h:60
tensorflow::SavedModelBundle * fBundle
Definition: tf_graph.h:71
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
bool fUseBundle
Definition: tf_graph.h:70
int n_inputs
Definition: tf_graph.h:59
tensorflow::Session * fSession
Definition: tf_graph.h:69
std::vector< std::string > fInputNames
Definition: tf_graph.h:72