LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
SemanticDecoder_tool.cc
Go to the documentation of this file.
1 #include "DecoderToolBase.h"
2 
4 #include <torch/torch.h>
5 
8 
9 // fixme: this only works for 5 categories and should be extended to different sizes. This may require making the class templated.
11 
12 public:
19 
23  virtual ~SemanticDecoder() noexcept = default;
24 
30  void declareProducts(art::ProducesCollector& collector) override
31  {
32  collector.produces<vector<FeatureVector<5>>>(instancename);
34  }
35 
41  void writeEmptyToEvent(art::Event& e, const vector<vector<size_t>>& idsmap) override;
42 
48  void writeToEvent(art::Event& e,
49  const vector<vector<size_t>>& idsmap,
50  const vector<NuGraphOutput>& infer_output) override;
51 
52 private:
53  std::vector<std::string> categories;
55 };
56 
58  : DecoderToolBase(p)
59  , categories{p.get<std::vector<std::string>>("categories")}
60  , hitInput{p.get<art::InputTag>("hitInput")}
61 {}
62 
63 void SemanticDecoder::writeEmptyToEvent(art::Event& e, const vector<vector<size_t>>& idsmap)
64 {
65  //
66  auto semtdes = std::make_unique<MVADescription<5>>(hitInput.label(), instancename, categories);
67  e.put(std::move(semtdes), instancename);
68  //
69  size_t size = 0;
70  for (auto& v : idsmap)
71  size += v.size();
72  std::array<float, 5> arr;
73  std::fill(arr.begin(), arr.end(), -1.);
74  auto semtcol = std::make_unique<vector<FeatureVector<5>>>(size, FeatureVector<5>(arr));
75  e.put(std::move(semtcol), instancename);
76  //
77 }
78 
80  const vector<vector<size_t>>& idsmap,
81  const vector<NuGraphOutput>& infer_output)
82 {
83  //
84  auto semtdes = std::make_unique<MVADescription<5>>(hitInput.label(), instancename, categories);
85  e.put(std::move(semtdes), instancename);
86  //
87  size_t size = 0;
88  for (auto& v : idsmap)
89  size += v.size();
90  std::array<float, 5> arr;
91  std::fill(arr.begin(), arr.end(), -1.);
92  auto semtcol = std::make_unique<vector<FeatureVector<5>>>(size, FeatureVector<5>(arr));
93 
94  size_t n_cols = categories.size();
95  for (size_t p = 0; p < planes.size(); p++) {
96  //
97  const std::vector<float>* x_semantic_data = 0;
98  for (auto& io : infer_output) {
99  if (io.output_name == outputname + planes[p]) x_semantic_data = &io.output_vec;
100  }
101  if (debug) {
102  std::cout << outputname + planes[p] << std::endl;
103  printVector(*x_semantic_data);
104  }
105 
106  torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32);
107  size_t n_rows = x_semantic_data->size() / n_cols;
108  const torch::Tensor s =
109  torch::from_blob(const_cast<float*>(x_semantic_data->data()),
110  {static_cast<int64_t>(n_rows), static_cast<int64_t>(n_cols)},
111  options);
112 
113  for (int i = 0; i < s.sizes()[0]; ++i) {
114  size_t idx = idsmap[p][i];
115  std::array<float, 5> input;
116  for (size_t j = 0; j < n_cols; ++j)
117  input[j] = s[i][j].item<float>();
118  softmax(input);
119  FeatureVector<5> semt = FeatureVector<5>(input);
120  (*semtcol)[idx] = semt;
121  }
122  }
123  e.put(std::move(semtcol), instancename);
124 }
125 
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
vector< std::string > planes
art::InputTag hitInput
std::string outputname
void writeToEvent(art::Event &e, const vector< vector< size_t >> &idsmap, const vector< NuGraphOutput > &infer_output) override
Decoder function.
PutHandle< PROD > put(std::unique_ptr< PROD > &&edp, std::string const &instance={})
Definition: Event.h:77
decltype(auto) constexpr size(T &&obj)
ADL-aware version of std::size.
Definition: StdUtils.h:101
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
std::string const & label() const noexcept
Definition: InputTag.cc:79
void produces(std::string const &instanceName={}, Persistable const persistable=Persistable::Yes)
T get(std::string const &key) const
Definition: ParameterSet.h:314
SemanticDecoder(const fhicl::ParameterSet &pset)
Constructor.
std::vector< std::string > categories
void fill(const art::PtrVector< recob::Hit > &hits, int only_plane)
void printVector(const std::vector< float > &vec)
void writeEmptyToEvent(art::Event &e, const vector< vector< size_t >> &idsmap) override
writeEmptyToEvent function
virtual ~SemanticDecoder() noexcept=default
Virtual Destructor.
std::string instancename
void softmax(std::array< T, N > &arr)
Float_t e
Definition: plot.C:35
void declareProducts(art::ProducesCollector &collector) override
declareProducts function