36 #include "grpc_client.h" 46 #define FAIL_IF_ERR(X, MSG) \ 48 tc::Error err = (X); \ 50 std::cerr << "error: " << (MSG) << ": " << err << std::endl; \ 54 namespace tc = triton::client;
91 ,
minHits(p.get<
size_t>(
"minHits"))
92 ,
debug(p.get<
bool>(
"debug"))
93 ,
planes(p.get<vector<std::string>>(
"planes"))
103 verbose = tritonPset.
get<
bool>(
"verbose",
"false");
113 for (
auto const& tool_pset_labels : tool_psets.get_pset_names()) {
114 std::cout <<
"decoder lablel: " << tool_pset_labels << std::endl;
128 vector<art::Ptr<Hit>> hitlist;
129 vector<vector<size_t>> idsmap;
130 vector<NuGraphInput> graphinputs;
131 _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
133 if (
debug) std::cout <<
"Hits size=" << hitlist.size() << std::endl;
134 if (hitlist.size() <
minHits) {
145 const vector<int32_t>* hit_table_hit_id_data =
nullptr;
146 const vector<int32_t>* hit_table_local_plane_data =
nullptr;
147 const vector<float>* hit_table_local_time_data =
nullptr;
148 const vector<int32_t>* hit_table_local_wire_data =
nullptr;
149 const vector<float>* hit_table_integral_data =
nullptr;
150 const vector<float>* hit_table_rms_data =
nullptr;
151 const vector<int32_t>* spacepoint_table_spacepoint_id_data =
nullptr;
152 const vector<int32_t>* spacepoint_table_hit_id_u_data =
nullptr;
153 const vector<int32_t>* spacepoint_table_hit_id_v_data =
nullptr;
154 const vector<int32_t>* spacepoint_table_hit_id_y_data =
nullptr;
155 for (
const auto& gi : graphinputs) {
156 if (gi.input_name ==
"hit_table_hit_id")
157 hit_table_hit_id_data = &gi.input_int32_vec;
158 else if (gi.input_name ==
"hit_table_local_plane")
159 hit_table_local_plane_data = &gi.input_int32_vec;
160 else if (gi.input_name ==
"hit_table_local_time")
161 hit_table_local_time_data = &gi.input_float_vec;
162 else if (gi.input_name ==
"hit_table_local_wire")
163 hit_table_local_wire_data = &gi.input_int32_vec;
164 else if (gi.input_name ==
"hit_table_integral")
165 hit_table_integral_data = &gi.input_float_vec;
166 else if (gi.input_name ==
"hit_table_rms")
167 hit_table_rms_data = &gi.input_float_vec;
168 else if (gi.input_name ==
"spacepoint_table_spacepoint_id")
169 spacepoint_table_spacepoint_id_data = &gi.input_int32_vec;
170 else if (gi.input_name ==
"spacepoint_table_hit_id_u")
171 spacepoint_table_hit_id_u_data = &gi.input_int32_vec;
172 else if (gi.input_name ==
"spacepoint_table_hit_id_v")
173 spacepoint_table_hit_id_v_data = &gi.input_int32_vec;
174 else if (gi.input_name ==
"spacepoint_table_hit_id_y")
175 spacepoint_table_hit_id_y_data = &gi.input_int32_vec;
179 tc::Headers http_headers;
180 grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
181 bool test_use_cached_channel =
false;
182 bool use_cached_channel =
true;
186 std::unique_ptr<tc::InferenceServerGrpcClient> client;
187 tc::SslOptions ssl_options = tc::SslOptions();
193 err =
"unable to create secure grpc client";
196 err =
"unable to create grpc client";
199 int numRuns = test_use_cached_channel ? 2 : 1;
200 for (
int i = 0; i < numRuns; ++i) {
201 FAIL_IF_ERR(tc::InferenceServerGrpcClient::Create(&client,
206 tc::KeepAliveOptions(),
210 std::vector<int64_t> hit_table_shape{int64_t(hit_table_hit_id_data->size())};
211 std::vector<int64_t> spacepoint_table_shape{
212 int64_t(spacepoint_table_spacepoint_id_data->size())};
215 tc::InferInput* hit_table_hit_id;
216 tc::InferInput* hit_table_local_plane;
217 tc::InferInput* hit_table_local_time;
218 tc::InferInput* hit_table_local_wire;
219 tc::InferInput* hit_table_integral;
220 tc::InferInput* hit_table_rms;
222 tc::InferInput* spacepoint_table_spacepoint_id;
223 tc::InferInput* spacepoint_table_hit_id_u;
224 tc::InferInput* spacepoint_table_hit_id_v;
225 tc::InferInput* spacepoint_table_hit_id_y;
228 tc::InferInput::Create(&hit_table_hit_id,
"hit_table_hit_id", hit_table_shape,
"INT32"),
229 "unable to get hit_table_hit_id");
230 std::shared_ptr<tc::InferInput> hit_table_hit_id_ptr;
231 hit_table_hit_id_ptr.reset(hit_table_hit_id);
234 &hit_table_local_plane,
"hit_table_local_plane", hit_table_shape,
"INT32"),
235 "unable to get hit_table_local_plane");
236 std::shared_ptr<tc::InferInput> hit_table_local_plane_ptr;
237 hit_table_local_plane_ptr.reset(hit_table_local_plane);
240 &hit_table_local_time,
"hit_table_local_time", hit_table_shape,
"FP32"),
241 "unable to get hit_table_local_time");
242 std::shared_ptr<tc::InferInput> hit_table_local_time_ptr;
243 hit_table_local_time_ptr.reset(hit_table_local_time);
246 &hit_table_local_wire,
"hit_table_local_wire", hit_table_shape,
"INT32"),
247 "unable to get hit_table_local_wire");
248 std::shared_ptr<tc::InferInput> hit_table_local_wire_ptr;
249 hit_table_local_wire_ptr.reset(hit_table_local_wire);
252 tc::InferInput::Create(&hit_table_integral,
"hit_table_integral", hit_table_shape,
"FP32"),
253 "unable to get hit_table_integral");
254 std::shared_ptr<tc::InferInput> hit_table_integral_ptr;
255 hit_table_integral_ptr.reset(hit_table_integral);
257 FAIL_IF_ERR(tc::InferInput::Create(&hit_table_rms,
"hit_table_rms", hit_table_shape,
"FP32"),
258 "unable to get hit_table_rms");
259 std::shared_ptr<tc::InferInput> hit_table_rms_ptr;
260 hit_table_rms_ptr.reset(hit_table_rms);
262 FAIL_IF_ERR(tc::InferInput::Create(&spacepoint_table_spacepoint_id,
263 "spacepoint_table_spacepoint_id",
264 spacepoint_table_shape,
266 "unable to get spacepoint_table_spacepoint_id");
267 std::shared_ptr<tc::InferInput> spacepoint_table_spacepoint_id_ptr;
268 spacepoint_table_spacepoint_id_ptr.reset(spacepoint_table_spacepoint_id);
271 tc::InferInput::Create(
272 &spacepoint_table_hit_id_u,
"spacepoint_table_hit_id_u", spacepoint_table_shape,
"INT32"),
273 "unable to get spacepoint_table_spacepoint_hit_id_u");
274 std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_u_ptr;
275 spacepoint_table_hit_id_u_ptr.reset(spacepoint_table_hit_id_u);
278 tc::InferInput::Create(
279 &spacepoint_table_hit_id_v,
"spacepoint_table_hit_id_v", spacepoint_table_shape,
"INT32"),
280 "unable to get spacepoint_table_spacepoint_hit_id_v");
281 std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_v_ptr;
282 spacepoint_table_hit_id_v_ptr.reset(spacepoint_table_hit_id_v);
285 tc::InferInput::Create(
286 &spacepoint_table_hit_id_y,
"spacepoint_table_hit_id_y", spacepoint_table_shape,
"INT32"),
287 "unable to get spacepoint_table_spacepoint_hit_id_y");
288 std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_y_ptr;
289 spacepoint_table_hit_id_y_ptr.reset(spacepoint_table_hit_id_y);
292 reinterpret_cast<const uint8_t*>(hit_table_hit_id_data->data()),
293 hit_table_hit_id_data->size() *
sizeof(int32_t)),
294 "unable to set data for hit_table_hit_id");
297 reinterpret_cast<const uint8_t*>(hit_table_local_plane_data->data()),
298 hit_table_local_plane_data->size() *
sizeof(int32_t)),
299 "unable to set data for hit_table_local_plane");
302 reinterpret_cast<const uint8_t*>(hit_table_local_time_data->data()),
303 hit_table_local_time_data->size() *
sizeof(float)),
304 "unable to set data for hit_table_local_time");
307 reinterpret_cast<const uint8_t*>(hit_table_local_wire_data->data()),
308 hit_table_local_wire_data->size() *
sizeof(int32_t)),
309 "unable to set data for hit_table_local_wire");
312 reinterpret_cast<const uint8_t*>(hit_table_integral_data->data()),
313 hit_table_integral_data->size() *
sizeof(float)),
314 "unable to set data for hit_table_integral");
317 hit_table_rms_ptr->AppendRaw(reinterpret_cast<const uint8_t*>(hit_table_rms_data->data()),
318 hit_table_rms_data->size() *
sizeof(float)),
319 "unable to set data for hit_table_rms");
321 FAIL_IF_ERR(spacepoint_table_spacepoint_id_ptr->AppendRaw(
322 reinterpret_cast<const uint8_t*>(spacepoint_table_spacepoint_id_data->data()),
323 spacepoint_table_spacepoint_id_data->size() *
sizeof(int32_t)),
324 "unable to set data for spacepoint_table_spacepoint_id");
326 FAIL_IF_ERR(spacepoint_table_hit_id_u_ptr->AppendRaw(
327 reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_u_data->data()),
328 spacepoint_table_hit_id_u_data->size() *
sizeof(int32_t)),
329 "unable to set data for spacepoint_table_hit_id_u");
331 FAIL_IF_ERR(spacepoint_table_hit_id_v_ptr->AppendRaw(
332 reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_v_data->data()),
333 spacepoint_table_hit_id_v_data->size() *
sizeof(int32_t)),
334 "unable to set data for spacepoint_table_hit_id_v");
336 FAIL_IF_ERR(spacepoint_table_hit_id_y_ptr->AppendRaw(
337 reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_y_data->data()),
338 spacepoint_table_hit_id_y_data->size() *
sizeof(int32_t)),
339 "unable to set data for spacepoint_table_hit_id_y");
342 tc::InferRequestedOutput* x_semantic_u;
343 tc::InferRequestedOutput* x_semantic_v;
344 tc::InferRequestedOutput* x_semantic_y;
345 tc::InferRequestedOutput* x_filter_u;
346 tc::InferRequestedOutput* x_filter_v;
347 tc::InferRequestedOutput* x_filter_y;
349 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_u,
"x_semantic_u"),
350 "unable to get 'x_semantic_u'");
351 std::shared_ptr<tc::InferRequestedOutput> x_semantic_u_ptr;
352 x_semantic_u_ptr.reset(x_semantic_u);
354 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_v,
"x_semantic_v"),
355 "unable to get 'x_semantic_v'");
356 std::shared_ptr<tc::InferRequestedOutput> x_semantic_v_ptr;
357 x_semantic_v_ptr.reset(x_semantic_v);
359 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_y,
"x_semantic_y"),
360 "unable to get 'x_semantic_y'");
361 std::shared_ptr<tc::InferRequestedOutput> x_semantic_y_ptr;
362 x_semantic_y_ptr.reset(x_semantic_y);
364 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_u,
"x_filter_u"),
365 "unable to get 'x_filter_u'");
366 std::shared_ptr<tc::InferRequestedOutput> x_filter_u_ptr;
367 x_filter_u_ptr.reset(x_filter_u);
369 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_v,
"x_filter_v"),
370 "unable to get 'x_filter_v'");
371 std::shared_ptr<tc::InferRequestedOutput> x_filter_v_ptr;
372 x_filter_v_ptr.reset(x_filter_v);
374 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_y,
"x_filter_y"),
375 "unable to get 'x_filter_y'");
376 std::shared_ptr<tc::InferRequestedOutput> x_filter_y_ptr;
377 x_filter_y_ptr.reset(x_filter_y);
384 std::vector<tc::InferInput*> inputs = {hit_table_hit_id_ptr.get(),
385 hit_table_local_plane_ptr.get(),
386 hit_table_local_time_ptr.get(),
387 hit_table_local_wire_ptr.get(),
388 hit_table_integral_ptr.get(),
389 hit_table_rms_ptr.get(),
390 spacepoint_table_spacepoint_id_ptr.get(),
391 spacepoint_table_hit_id_u_ptr.get(),
392 spacepoint_table_hit_id_v_ptr.get(),
393 spacepoint_table_hit_id_y_ptr.get()};
395 std::vector<const tc::InferRequestedOutput*> outputs = {x_semantic_u_ptr.get(),
396 x_semantic_v_ptr.get(),
397 x_semantic_y_ptr.get(),
398 x_filter_u_ptr.get(),
399 x_filter_v_ptr.get(),
400 x_filter_y_ptr.get()};
402 tc::InferResult* results;
403 auto start = std::chrono::high_resolution_clock::now();
405 client->Infer(&results, options, inputs, outputs, http_headers, compression_algorithm),
406 "unable to run model");
407 auto end = std::chrono::high_resolution_clock::now();
408 std::chrono::duration<double> elapsed = end - start;
409 std::cout <<
"Time taken for inference: " << elapsed.count() <<
" seconds" << std::endl;
410 std::shared_ptr<tc::InferResult> results_ptr;
411 results_ptr.reset(results);
416 vector<NuGraphOutput> infer_output;
417 vector<string> outnames = {
418 "x_semantic_u",
"x_semantic_v",
"x_semantic_y",
"x_filter_u",
"x_filter_v",
"x_filter_y"};
419 for (
const auto& name : outnames) {
422 FAIL_IF_ERR(results_ptr->RawData(name, (
const uint8_t**)&_data, &_byte_size),
423 "unable to get result data for " + name);
424 size_t n_elements = _byte_size /
sizeof(float);
425 std::vector<float> out_data(_data, _data + n_elements);
vector< std::string > planes
Declaration of signal hit object.
EDProducer(fhicl::ParameterSet const &pset)
std::string ssl_certificate_chain
#define FAIL_IF_ERR(X, MSG)
NuGraphInferenceTriton & operator=(NuGraphInferenceTriton const &)=delete
NuGraphInferenceTriton(fhicl::ParameterSet const &p)
std::string model_version
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
auto vector(Vector const &v)
Returns a manipulator which will print the specified array.
auto array(Array const &a)
Returns a manipulator which will print the specified array.
#define DEFINE_ART_MODULE(klass)
T get(std::string const &key) const
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
std::string inference_url
std::unique_ptr< LoaderToolBase > _loaderTool
ProducesCollector & producesCollector() noexcept
std::string ssl_root_certificates
2D representation of charge deposited in the TDC/wire plane
std::string inference_model_name
void produce(art::Event &e) override
std::string ssl_private_key