13 #include "larrecodnn/ImagePatternAlgs/ToolInterfaces/IPointIdAlg.h" 14 #include "tensorflow/core/public/session.h" 24 std::vector<float>
Run(
std::vector<std::vector<float>>
const& inp2d)
const override;
26 int samples = -1)
const override;
29 std::string
findFile(
const char* fileName)
const;
32 std::unique_ptr<tf::Graph>
g;
41 fNNetOutputs = table().NNetOutputs();
42 fPatchSizeW = table().PatchSizeW();
43 fPatchSizeD = table().PatchSizeD();
44 fCurrentWireIdx = 99999;
45 fCurrentScaledDrift = 99999;
53 std::vector<std::string> vs_cfgvr;
66 mf::LogError(
"PointIdAlgTf") <<
"File name extension not supported.";
75 std::string fname_out;
76 cet::search_path sp(
"FW_SEARCH_PATH");
77 if (!sp.find_file(fileName, fname_out)) {
79 if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
90 long long int rows = inp2d.size(), cols = inp2d.front().size();
92 std::vector<tensorflow::Tensor> _x;
94 tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1})));
95 auto input_map = _x[0].tensor<float, 4>();
96 for (
long long int r = 0;
r < rows; ++
r) {
97 const auto& row = inp2d[
r];
98 for (
long long int c = 0; c < cols; ++c) {
99 input_map(0,
r, c, 0) = row[c];
103 auto out =
g->runx(_x);
107 return std::vector<float>();
116 if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
117 return std::vector<std::vector<float>>();
120 if ((samples == -1) || (samples > (
long long int)inps.size())) { samples = inps.size(); }
122 long long int rows = inps.front().size(), cols = inps.front().front().size();
124 std::vector<tensorflow::Tensor> _x;
126 tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1})));
127 auto input_map = _x[0].tensor<float, 4>();
128 for (
long long int s = 0; s < samples; ++s) {
129 const auto& sample = inps[s];
130 for (
long long int r = 0;
r < rows; ++
r) {
131 const auto& row = sample[
r];
132 for (
long long int c = 0; c < cols; ++c) {
133 input_map(s,
r, c, 0) = row[c];
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception