127 vector<art::Ptr<Hit>> hitlist;
128 vector<vector<size_t>> idsmap;
129 vector<NuGraphInput> graphinputs;
130 _loaderTool->loadData(e, hitlist, graphinputs, idsmap);
132 if (
debug) std::cout <<
"Hits size=" << hitlist.size() << std::endl;
133 if (hitlist.size() <
minHits) {
144 const vector<int32_t>* hit_table_hit_id_data =
nullptr;
145 const vector<int32_t>* hit_table_local_plane_data =
nullptr;
146 const vector<float>* hit_table_local_time_data =
nullptr;
147 const vector<int32_t>* hit_table_local_wire_data =
nullptr;
148 const vector<float>* hit_table_integral_data =
nullptr;
149 const vector<float>* hit_table_rms_data =
nullptr;
150 const vector<int32_t>* spacepoint_table_spacepoint_id_data =
nullptr;
151 const vector<int32_t>* spacepoint_table_hit_id_u_data =
nullptr;
152 const vector<int32_t>* spacepoint_table_hit_id_v_data =
nullptr;
153 const vector<int32_t>* spacepoint_table_hit_id_y_data =
nullptr;
154 for (
const auto& gi : graphinputs) {
155 if (gi.input_name ==
"hit_table_hit_id")
156 hit_table_hit_id_data = &gi.input_int32_vec;
157 else if (gi.input_name ==
"hit_table_local_plane")
158 hit_table_local_plane_data = &gi.input_int32_vec;
159 else if (gi.input_name ==
"hit_table_local_time")
160 hit_table_local_time_data = &gi.input_float_vec;
161 else if (gi.input_name ==
"hit_table_local_wire")
162 hit_table_local_wire_data = &gi.input_int32_vec;
163 else if (gi.input_name ==
"hit_table_integral")
164 hit_table_integral_data = &gi.input_float_vec;
165 else if (gi.input_name ==
"hit_table_rms")
166 hit_table_rms_data = &gi.input_float_vec;
167 else if (gi.input_name ==
"spacepoint_table_spacepoint_id")
168 spacepoint_table_spacepoint_id_data = &gi.input_int32_vec;
169 else if (gi.input_name ==
"spacepoint_table_hit_id_u")
170 spacepoint_table_hit_id_u_data = &gi.input_int32_vec;
171 else if (gi.input_name ==
"spacepoint_table_hit_id_v")
172 spacepoint_table_hit_id_v_data = &gi.input_int32_vec;
173 else if (gi.input_name ==
"spacepoint_table_hit_id_y")
174 spacepoint_table_hit_id_y_data = &gi.input_int32_vec;
178 tc::Headers http_headers;
179 grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
180 bool test_use_cached_channel =
false;
181 bool use_cached_channel =
true;
185 std::unique_ptr<tc::InferenceServerGrpcClient> client;
186 tc::SslOptions ssl_options = tc::SslOptions();
192 err =
"unable to create secure grpc client";
195 err =
"unable to create grpc client";
198 int numRuns = test_use_cached_channel ? 2 : 1;
199 for (
int i = 0; i < numRuns; ++i) {
200 FAIL_IF_ERR(tc::InferenceServerGrpcClient::Create(&client,
205 tc::KeepAliveOptions(),
209 std::vector<int64_t> hit_table_shape{int64_t(hit_table_hit_id_data->size())};
210 std::vector<int64_t> spacepoint_table_shape{
211 int64_t(spacepoint_table_spacepoint_id_data->size())};
214 tc::InferInput* hit_table_hit_id;
215 tc::InferInput* hit_table_local_plane;
216 tc::InferInput* hit_table_local_time;
217 tc::InferInput* hit_table_local_wire;
218 tc::InferInput* hit_table_integral;
219 tc::InferInput* hit_table_rms;
221 tc::InferInput* spacepoint_table_spacepoint_id;
222 tc::InferInput* spacepoint_table_hit_id_u;
223 tc::InferInput* spacepoint_table_hit_id_v;
224 tc::InferInput* spacepoint_table_hit_id_y;
227 tc::InferInput::Create(&hit_table_hit_id,
"hit_table_hit_id", hit_table_shape,
"INT32"),
228 "unable to get hit_table_hit_id");
229 std::shared_ptr<tc::InferInput> hit_table_hit_id_ptr;
230 hit_table_hit_id_ptr.reset(hit_table_hit_id);
233 &hit_table_local_plane,
"hit_table_local_plane", hit_table_shape,
"INT32"),
234 "unable to get hit_table_local_plane");
235 std::shared_ptr<tc::InferInput> hit_table_local_plane_ptr;
236 hit_table_local_plane_ptr.reset(hit_table_local_plane);
239 &hit_table_local_time,
"hit_table_local_time", hit_table_shape,
"FP32"),
240 "unable to get hit_table_local_time");
241 std::shared_ptr<tc::InferInput> hit_table_local_time_ptr;
242 hit_table_local_time_ptr.reset(hit_table_local_time);
245 &hit_table_local_wire,
"hit_table_local_wire", hit_table_shape,
"INT32"),
246 "unable to get hit_table_local_wire");
247 std::shared_ptr<tc::InferInput> hit_table_local_wire_ptr;
248 hit_table_local_wire_ptr.reset(hit_table_local_wire);
251 tc::InferInput::Create(&hit_table_integral,
"hit_table_integral", hit_table_shape,
"FP32"),
252 "unable to get hit_table_integral");
253 std::shared_ptr<tc::InferInput> hit_table_integral_ptr;
254 hit_table_integral_ptr.reset(hit_table_integral);
256 FAIL_IF_ERR(tc::InferInput::Create(&hit_table_rms,
"hit_table_rms", hit_table_shape,
"FP32"),
257 "unable to get hit_table_rms");
258 std::shared_ptr<tc::InferInput> hit_table_rms_ptr;
259 hit_table_rms_ptr.reset(hit_table_rms);
261 FAIL_IF_ERR(tc::InferInput::Create(&spacepoint_table_spacepoint_id,
262 "spacepoint_table_spacepoint_id",
263 spacepoint_table_shape,
265 "unable to get spacepoint_table_spacepoint_id");
266 std::shared_ptr<tc::InferInput> spacepoint_table_spacepoint_id_ptr;
267 spacepoint_table_spacepoint_id_ptr.reset(spacepoint_table_spacepoint_id);
270 tc::InferInput::Create(
271 &spacepoint_table_hit_id_u,
"spacepoint_table_hit_id_u", spacepoint_table_shape,
"INT32"),
272 "unable to get spacepoint_table_spacepoint_hit_id_u");
273 std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_u_ptr;
274 spacepoint_table_hit_id_u_ptr.reset(spacepoint_table_hit_id_u);
277 tc::InferInput::Create(
278 &spacepoint_table_hit_id_v,
"spacepoint_table_hit_id_v", spacepoint_table_shape,
"INT32"),
279 "unable to get spacepoint_table_spacepoint_hit_id_v");
280 std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_v_ptr;
281 spacepoint_table_hit_id_v_ptr.reset(spacepoint_table_hit_id_v);
284 tc::InferInput::Create(
285 &spacepoint_table_hit_id_y,
"spacepoint_table_hit_id_y", spacepoint_table_shape,
"INT32"),
286 "unable to get spacepoint_table_spacepoint_hit_id_y");
287 std::shared_ptr<tc::InferInput> spacepoint_table_hit_id_y_ptr;
288 spacepoint_table_hit_id_y_ptr.reset(spacepoint_table_hit_id_y);
291 reinterpret_cast<const uint8_t*>(hit_table_hit_id_data->data()),
292 hit_table_hit_id_data->size() *
sizeof(int32_t)),
293 "unable to set data for hit_table_hit_id");
296 reinterpret_cast<const uint8_t*>(hit_table_local_plane_data->data()),
297 hit_table_local_plane_data->size() *
sizeof(int32_t)),
298 "unable to set data for hit_table_local_plane");
301 reinterpret_cast<const uint8_t*>(hit_table_local_time_data->data()),
302 hit_table_local_time_data->size() *
sizeof(float)),
303 "unable to set data for hit_table_local_time");
306 reinterpret_cast<const uint8_t*>(hit_table_local_wire_data->data()),
307 hit_table_local_wire_data->size() *
sizeof(int32_t)),
308 "unable to set data for hit_table_local_wire");
311 reinterpret_cast<const uint8_t*>(hit_table_integral_data->data()),
312 hit_table_integral_data->size() *
sizeof(float)),
313 "unable to set data for hit_table_integral");
316 hit_table_rms_ptr->AppendRaw(reinterpret_cast<const uint8_t*>(hit_table_rms_data->data()),
317 hit_table_rms_data->size() *
sizeof(float)),
318 "unable to set data for hit_table_rms");
320 FAIL_IF_ERR(spacepoint_table_spacepoint_id_ptr->AppendRaw(
321 reinterpret_cast<const uint8_t*>(spacepoint_table_spacepoint_id_data->data()),
322 spacepoint_table_spacepoint_id_data->size() *
sizeof(int32_t)),
323 "unable to set data for spacepoint_table_spacepoint_id");
325 FAIL_IF_ERR(spacepoint_table_hit_id_u_ptr->AppendRaw(
326 reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_u_data->data()),
327 spacepoint_table_hit_id_u_data->size() *
sizeof(int32_t)),
328 "unable to set data for spacepoint_table_hit_id_u");
330 FAIL_IF_ERR(spacepoint_table_hit_id_v_ptr->AppendRaw(
331 reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_v_data->data()),
332 spacepoint_table_hit_id_v_data->size() *
sizeof(int32_t)),
333 "unable to set data for spacepoint_table_hit_id_v");
335 FAIL_IF_ERR(spacepoint_table_hit_id_y_ptr->AppendRaw(
336 reinterpret_cast<const uint8_t*>(spacepoint_table_hit_id_y_data->data()),
337 spacepoint_table_hit_id_y_data->size() *
sizeof(int32_t)),
338 "unable to set data for spacepoint_table_hit_id_y");
341 tc::InferRequestedOutput* x_semantic_u;
342 tc::InferRequestedOutput* x_semantic_v;
343 tc::InferRequestedOutput* x_semantic_y;
344 tc::InferRequestedOutput* x_filter_u;
345 tc::InferRequestedOutput* x_filter_v;
346 tc::InferRequestedOutput* x_filter_y;
348 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_u,
"x_semantic_u"),
349 "unable to get 'x_semantic_u'");
350 std::shared_ptr<tc::InferRequestedOutput> x_semantic_u_ptr;
351 x_semantic_u_ptr.reset(x_semantic_u);
353 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_v,
"x_semantic_v"),
354 "unable to get 'x_semantic_v'");
355 std::shared_ptr<tc::InferRequestedOutput> x_semantic_v_ptr;
356 x_semantic_v_ptr.reset(x_semantic_v);
358 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_semantic_y,
"x_semantic_y"),
359 "unable to get 'x_semantic_y'");
360 std::shared_ptr<tc::InferRequestedOutput> x_semantic_y_ptr;
361 x_semantic_y_ptr.reset(x_semantic_y);
363 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_u,
"x_filter_u"),
364 "unable to get 'x_filter_u'");
365 std::shared_ptr<tc::InferRequestedOutput> x_filter_u_ptr;
366 x_filter_u_ptr.reset(x_filter_u);
368 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_v,
"x_filter_v"),
369 "unable to get 'x_filter_v'");
370 std::shared_ptr<tc::InferRequestedOutput> x_filter_v_ptr;
371 x_filter_v_ptr.reset(x_filter_v);
373 FAIL_IF_ERR(tc::InferRequestedOutput::Create(&x_filter_y,
"x_filter_y"),
374 "unable to get 'x_filter_y'");
375 std::shared_ptr<tc::InferRequestedOutput> x_filter_y_ptr;
376 x_filter_y_ptr.reset(x_filter_y);
383 std::vector<tc::InferInput*> inputs = {hit_table_hit_id_ptr.get(),
384 hit_table_local_plane_ptr.get(),
385 hit_table_local_time_ptr.get(),
386 hit_table_local_wire_ptr.get(),
387 hit_table_integral_ptr.get(),
388 hit_table_rms_ptr.get(),
389 spacepoint_table_spacepoint_id_ptr.get(),
390 spacepoint_table_hit_id_u_ptr.get(),
391 spacepoint_table_hit_id_v_ptr.get(),
392 spacepoint_table_hit_id_y_ptr.get()};
394 std::vector<const tc::InferRequestedOutput*> outputs = {x_semantic_u_ptr.get(),
395 x_semantic_v_ptr.get(),
396 x_semantic_y_ptr.get(),
397 x_filter_u_ptr.get(),
398 x_filter_v_ptr.get(),
399 x_filter_y_ptr.get()};
401 tc::InferResult* results;
402 auto start = std::chrono::high_resolution_clock::now();
404 client->Infer(&results, options, inputs, outputs, http_headers, compression_algorithm),
405 "unable to run model");
406 auto end = std::chrono::high_resolution_clock::now();
407 std::chrono::duration<double> elapsed = end - start;
408 std::cout <<
"Time taken for inference: " << elapsed.count() <<
" seconds" << std::endl;
409 std::shared_ptr<tc::InferResult> results_ptr;
410 results_ptr.reset(results);
415 vector<NuGraphOutput> infer_output;
416 vector<string> outnames = {
417 "x_semantic_u",
"x_semantic_v",
"x_semantic_y",
"x_filter_u",
"x_filter_v",
"x_filter_y"};
418 for (
const auto& name : outnames) {
421 FAIL_IF_ERR(results_ptr->RawData(name, (
const uint8_t**)&_data, &_byte_size),
422 "unable to get result data for " + name);
423 size_t n_elements = _byte_size /
sizeof(float);
424 std::vector<float> out_data(_data, _data + n_elements);
std::string ssl_certificate_chain
#define FAIL_IF_ERR(X, MSG)
std::string model_version
decltype(auto) constexpr end(T &&obj)
ADL-aware version of std::end.
std::vector< std::unique_ptr< DecoderToolBase > > _decoderToolsVec
std::string inference_url
std::unique_ptr< LoaderToolBase > _loaderTool
std::string ssl_root_certificates
std::string inference_model_name
std::string ssl_private_key