3 #include "larrecodnn/ImagePatternAlgs/ToolInterfaces/IWireframeRecog.h" 5 #include "tensorflow/core/public/session.h" 19 std::unique_ptr<tf::Graph>
g;
31 pset.
get<std::vector<std::string>>(
"NNetOutputPattern", {
"cnn_output",
"dense_3"});
38 mf::LogInfo(
"WireframeRecogTf") <<
"TF model loaded.";
44 mf::LogInfo(
"WireframeRecogTf") <<
"TF model loaded.";
47 mf::LogError(
"WireframeRecogTf") <<
"File name extension not supported.";
50 setupWframeRecRoiParams(pset);
57 if (wireframes.empty() || wireframes.front().empty() || wireframes.front().front().empty()) {
58 return std::vector<std::vector<float>>();
62 long long int samples = wireframes.size(), rows = wireframes.front().size(),
63 cols = wireframes.front().front().size();
66 std::vector<tensorflow::Tensor> _x;
68 tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1})));
69 auto input_map = _x[0].tensor<float, 4>();
70 for (
long long int s = 0; s < samples; ++s) {
71 const auto& wframe = wireframes[s];
72 for (
long long int r = 0;
r < rows; ++
r) {
73 const auto& row = wframe[
r];
74 for (
long long int c = 0; c < cols; ++c) {
75 input_map(s,
r, c, 0) = float(row[c]);
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, bool use_bundle=false, int ninputs=1, int noutputs=1)
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
T get(std::string const &key) const
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception