LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
PointIdAlgSonicTriton_tool.cc
Go to the documentation of this file.
1 // Class: PointIdAlgSonicTriton_tool
3 // Authors: M.Wang
5 
8 #include "fhiclcpp/types/Table.h"
10 #include "larrecodnn/ImagePatternAlgs/ToolInterfaces/IPointIdAlg.h"
12 
13 #include <memory>
14 #include <string>
15 #include <vector>
16 
17 namespace PointIdAlgTools {
18 
19  class PointIdAlgSonicTriton : public IPointIdAlg {
20  public:
21  explicit PointIdAlgSonicTriton(fhicl::Table<Config> const& table);
22 
23  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const override;
24  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
25  int samples = -1) const override;
26 
27  private:
28  std::string fTritonModelName;
29  std::string fTritonURL;
31  std::string fTritonModelVersion;
32  unsigned fTritonTimeout;
34 
35  std::unique_ptr<lartriton::TritonClient> triton_client;
36  };
37 
38  // ------------------------------------------------------
40  : img::DataProviderAlg(table())
41  {
42  // ... Get common config vars
43  fNNetOutputs = table().NNetOutputs();
44  fPatchSizeW = table().PatchSizeW();
45  fPatchSizeD = table().PatchSizeD();
46  fCurrentWireIdx = 99999;
47  fCurrentScaledDrift = 99999;
48 
49  // ... Get "optional" config vars specific to tRTis interface
50  fTritonModelName = table().TritonModelName();
51  fTritonURL = table().TritonURL();
52  fTritonVerbose = table().TritonVerbose();
53  fTritonModelVersion = table().TritonModelVersion();
54  fTritonAllowedTries = table().TritonAllowedTries();
55 
56  // ... Create parameter set for Triton inference client
57  fhicl::ParameterSet TritonPset;
58  TritonPset.put("serverURL", fTritonURL);
59  TritonPset.put("verbose", fTritonVerbose);
60  TritonPset.put("modelName", fTritonModelName);
61  TritonPset.put("modelVersion", fTritonModelVersion);
62  TritonPset.put("timeout", fTritonTimeout);
63  TritonPset.put("allowedTries", fTritonAllowedTries);
64  TritonPset.put("outputs", "[]");
65 
66  // ... Create the Triton inference client
67  triton_client = std::make_unique<lartriton::TritonClient>(TritonPset);
68 
69  mf::LogInfo("PointIdAlgSonicTriton") << "url: " << fTritonURL;
70  mf::LogInfo("PointIdAlgSonicTriton") << "model name: " << fTritonModelName;
71  mf::LogInfo("PointIdAlgSonicTriton") << "model version: " << fTritonModelVersion;
72  mf::LogInfo("PointIdAlgSonicTriton") << "verbose: " << fTritonVerbose;
73 
74  mf::LogInfo("PointIdAlgSonicTriton") << "tensorRT inference context created.";
75 
76  resizePatch();
77  }
78 
79  // ------------------------------------------------------
80  std::vector<float> PointIdAlgSonicTriton::Run(std::vector<std::vector<float>> const& inp2d) const
81  {
82  size_t nrows = inp2d.size();
83 
84  triton_client->setBatchSize(1); // set batch size
85 
86  // ~~~~ Initialize the inputs
87  auto& triton_input = triton_client->input().begin()->second;
88 
89  auto data1 = std::make_shared<lartriton::TritonInput<float>>();
90  data1->reserve(1);
91 
92  // ~~~~ Prepare image for sending to server
93  auto& img = data1->emplace_back();
94  // ..first flatten the 2d array into contiguous 1d block
95  for (size_t ir = 0; ir < nrows; ++ir) {
96  img.insert(img.end(), inp2d[ir].begin(), inp2d[ir].end());
97  }
98 
99  triton_input.toServer(data1); // convert to server format
100 
101  // ~~~~ Send inference request
102  triton_client->dispatch();
103 
104  // ~~~~ Retrieve inference results
105  const auto& triton_output0 = triton_client->output().at("em_trk_none_netout/Softmax");
106  const auto& prob0 = triton_output0.fromServer<float>();
107  auto ncat0 = triton_output0.sizeDims();
108 
109  const auto& triton_output1 = triton_client->output().at("michel_netout/Sigmoid");
110  const auto& prob1 = triton_output1.fromServer<float>();
111  auto ncat1 = triton_output1.sizeDims();
112 
113  std::vector<float> out;
114  out.reserve(ncat0 + ncat1);
115  out.insert(out.end(), prob0[0].begin(), prob0[0].end());
116  out.insert(out.end(), prob1[0].begin(), prob1[0].end());
117 
118  triton_client->reset();
119 
120  return out;
121  }
122 
123  // ------------------------------------------------------
124  std::vector<std::vector<float>> PointIdAlgSonicTriton::Run(
125  std::vector<std::vector<std::vector<float>>> const& inps,
126  int samples) const
127  {
128  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
129  return std::vector<std::vector<float>>();
130  }
131 
132  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
133 
134  size_t usamples = samples;
135  size_t nrows = inps.front().size();
136 
137  triton_client->setBatchSize(usamples); // set batch size
138 
139  // ~~~~ Initialize the inputs
140  auto& triton_input = triton_client->input().begin()->second;
141 
142  auto data1 = std::make_shared<lartriton::TritonInput<float>>();
143  data1->reserve(usamples);
144 
145  // ~~~~ For each sample, prepare images for sending to server
146  for (size_t idx = 0; idx < usamples; ++idx) {
147  auto& img = data1->emplace_back();
148  // ..first flatten the 2d array into contiguous 1d block
149  for (size_t ir = 0; ir < nrows; ++ir) {
150  img.insert(img.end(), inps[idx][ir].begin(), inps[idx][ir].end());
151  }
152  }
153  triton_input.toServer(data1); // convert to server format
154 
155  // ~~~~ Send inference request
156  triton_client->dispatch();
157 
158  // ~~~~ Retrieve inference results
159  const auto& triton_output0 = triton_client->output().at("em_trk_none_netout/Softmax");
160  const auto& prob0 = triton_output0.fromServer<float>();
161  auto ncat0 = triton_output0.sizeDims();
162 
163  const auto& triton_output1 = triton_client->output().at("michel_netout/Sigmoid");
164  const auto& prob1 = triton_output1.fromServer<float>();
165  auto ncat1 = triton_output1.sizeDims();
166 
167  std::vector<std::vector<float>> out;
168  out.reserve(usamples);
169  for (unsigned i = 0; i < usamples; i++) {
170  out.emplace_back();
171  auto& img = out.back();
172  img.reserve(ncat0 + ncat1);
173  img.insert(img.end(), prob0[i].begin(), prob0[i].end());
174  img.insert(img.end(), prob1[i].begin(), prob1[i].end());
175  }
176 
177  triton_client->reset();
178 
179  return out;
180  }
181 
182 }
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
PointIdAlgSonicTriton(fhicl::Table< Config > const &table)
std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const override
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
Definition: DumpUtils.h:289
std::unique_ptr< lartriton::TritonClient > triton_client
void put(std::string const &key)