LArSoft  v10_06_00
Liquid Argon Software toolkit - https://larsoft.org/
SemanticDecoder_tool.cc
Go to the documentation of this file.
1 #include "DecoderToolBase.h"
2 
4 
6 #include <torch/torch.h>
7 
10 
11 // fixme: this only works for 5 categories and should be extended to different sizes. This may require making the class templated.
13 
14 public:
21 
25  virtual ~SemanticDecoder() noexcept = default;
26 
32  void declareProducts(art::ProducesCollector& collector) override
33  {
34  collector.produces<vector<FeatureVector<5>>>(instancename);
36  }
37 
43  void writeEmptyToEvent(art::Event& e, const vector<vector<size_t>>& idsmap) override;
44 
50  void writeToEvent(art::Event& e,
51  const vector<vector<size_t>>& idsmap,
52  const vector<NuGraphOutput>& infer_output) override;
53 
54 private:
55  std::vector<std::string> categories;
57 };
58 
60  : DecoderToolBase(p)
61  , categories{p.get<std::vector<std::string>>("categories")}
62  , hitInput{p.get<art::InputTag>("hitInput")}
63 {}
64 
65 void SemanticDecoder::writeEmptyToEvent(art::Event& e, const vector<vector<size_t>>& idsmap)
66 {
67  //
68  auto semtdes = std::make_unique<MVADescription<5>>(hitInput.label(), instancename, categories);
69  e.put(std::move(semtdes), instancename);
70  //
71  size_t size = 0;
72  for (auto& v : idsmap)
73  size += v.size();
74  std::array<float, 5> arr;
75  std::fill(arr.begin(), arr.end(), -1.);
76  auto semtcol = std::make_unique<vector<FeatureVector<5>>>(size, FeatureVector<5>(arr));
77  e.put(std::move(semtcol), instancename);
78  //
79 }
80 
82  const vector<vector<size_t>>& idsmap,
83  const vector<NuGraphOutput>& infer_output)
84 {
85  //
86  auto semtdes = std::make_unique<MVADescription<5>>(hitInput.label(), instancename, categories);
87  e.put(std::move(semtdes), instancename);
88  //
89  size_t size = 0;
90  for (auto& v : idsmap)
91  size += v.size();
92  std::array<float, 5> arr;
93  std::fill(arr.begin(), arr.end(), -1.);
94  auto semtcol = std::make_unique<vector<FeatureVector<5>>>(size, FeatureVector<5>(arr));
95 
96  size_t n_cols = categories.size();
97  for (size_t p = 0; p < planes.size(); p++) {
98  //
99  const std::vector<float>* x_semantic_data = 0;
100  for (auto& io : infer_output) {
101  if (io.output_name == outputname + planes[p]) x_semantic_data = &io.output_vec;
102  }
103  if (debug) {
104  std::cout << outputname + planes[p] << std::endl;
105  printVector(*x_semantic_data);
106  }
107 
108  torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32);
109  size_t n_rows = x_semantic_data->size() / n_cols;
110  const torch::Tensor s =
111  torch::from_blob(const_cast<float*>(x_semantic_data->data()),
112  {static_cast<int64_t>(n_rows), static_cast<int64_t>(n_cols)},
113  options);
114 
115  for (int i = 0; i < s.sizes()[0]; ++i) {
116  size_t idx = idsmap[p][i];
117  std::array<float, 5> input;
118  for (size_t j = 0; j < n_cols; ++j)
119  input[j] = s[i][j].item<float>();
120  softmax(input);
121  FeatureVector<5> semt = FeatureVector<5>(input);
122  (*semtcol)[idx] = semt;
123  }
124  }
125  e.put(std::move(semtcol), instancename);
126 }
127 
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
vector< std::string > planes
art::InputTag hitInput
std::string outputname
void writeToEvent(art::Event &e, const vector< vector< size_t >> &idsmap, const vector< NuGraphOutput > &infer_output) override
Decoder function.
PutHandle< PROD > put(std::unique_ptr< PROD > &&edp, std::string const &instance={})
Definition: Event.h:77
decltype(auto) constexpr size(T &&obj)
ADL-aware version of std::size.
Definition: StdUtils.h:101
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
std::string const & label() const noexcept
Definition: InputTag.cc:79
void produces(std::string const &instanceName={}, Persistable const persistable=Persistable::Yes)
T get(std::string const &key) const
Definition: ParameterSet.h:314
SemanticDecoder(const fhicl::ParameterSet &pset)
Constructor.
std::vector< std::string > categories
void fill(const art::PtrVector< recob::Hit > &hits, int only_plane)
void printVector(const std::vector< float > &vec)
void writeEmptyToEvent(art::Event &e, const vector< vector< size_t >> &idsmap) override
writeEmptyToEvent function
virtual ~SemanticDecoder() noexcept=default
Virtual Destructor.
std::string instancename
void softmax(std::array< T, N > &arr)
Float_t e
Definition: plot.C:35
void declareProducts(art::ProducesCollector &collector) override
declareProducts function