LArSoft v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
TritonClient.cc
#include "cetlib_except/exception.h"

#include "grpc_client.h"

#include <chrono>
#include <cmath>
#include <exception>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>

namespace nic = triton::client;

//based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
//and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc

namespace lartriton {

  TritonClient::TritonClient(const fhicl::ParameterSet& params)
    : allowedTries_(params.get<unsigned>("allowedTries", 0))
    , serverURL_(params.get<std::string>("serverURL"))
    , verbose_(params.get<bool>("verbose"))
    , options_(params.get<std::string>("modelName"))
  {
    //get appropriate server for this model
    if (verbose_) MF_LOG_INFO("TritonClient") << "Using server: " << serverURL_;

    //connect to the server
    //TODO: add SSL options
    triton_utils::throwIfError(nic::InferenceServerGrpcClient::Create(&client_, serverURL_, false),
                               "TritonClient(): unable to create inference context");

    //set options
    options_.model_version_ = params.get<std::string>("modelVersion");
    //convert seconds to microseconds
    options_.client_timeout_ = params.get<unsigned>("timeout") * 1e6;

    //config needed for batch size
    inference::ModelConfigResponse modelConfigResponse;
    triton_utils::throwIfError(
      client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
      "TritonClient(): unable to get model config");
    inference::ModelConfig modelConfig(modelConfigResponse.config());

    //check batch size limitations (after i/o setup)
    //triton uses max batch size = 0 to denote a model that does not support batching
    //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
    //so set the local max to 1 and keep track of "no batch" case
    maxBatchSize_ = modelConfig.max_batch_size();
    noBatch_ = maxBatchSize_ == 0;
    maxBatchSize_ = std::max(1u, maxBatchSize_);

    //get model info
    inference::ModelMetadataResponse modelMetadata;
    triton_utils::throwIfError(
      client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
      "TritonClient(): unable to get model metadata");

    //get input and output (which know their sizes)
    const auto& nicInputs = modelMetadata.inputs();
    const auto& nicOutputs = modelMetadata.outputs();

    //report all model errors at once
    std::ostringstream msg;
    std::string msg_str;

    //currently no use case is foreseen for a model with zero inputs or outputs
    if (nicInputs.empty()) msg << "Model on server appears malformed (zero inputs)\n";

    if (nicOutputs.empty()) msg << "Model on server appears malformed (zero outputs)\n";

    //stop if errors
    msg_str = msg.str();
    if (!msg_str.empty()) throw cet::exception("ModelErrors") << msg_str;

    //setup input map
    std::ostringstream io_msg;
    if (verbose_)
      io_msg << "Model inputs: "
             << "\n";
    inputsTriton_.reserve(nicInputs.size());
    for (const auto& nicInput : nicInputs) {
      const auto& iname = nicInput.name();
      auto [curr_itr, success] = input_.try_emplace(iname, iname, nicInput, noBatch_);
      auto& curr_input = curr_itr->second;
      inputsTriton_.push_back(curr_input.data());
      if (verbose_) {
        io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
               << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
      }
    }

    //allow selecting only some outputs from server
    const auto& v_outputs = params.get<std::vector<std::string>>("outputs");
    std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());

    //setup output map
    if (verbose_)
      io_msg << "Model outputs: "
             << "\n";
    outputsTriton_.reserve(nicOutputs.size());
    for (const auto& nicOutput : nicOutputs) {
      const auto& oname = nicOutput.name();
      if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end()) continue;
      auto [curr_itr, success] = output_.try_emplace(oname, oname, nicOutput, noBatch_);
      auto& curr_output = curr_itr->second;
      outputsTriton_.push_back(curr_output.data());
      if (verbose_) {
        io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
               << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
      }
      if (!s_outputs.empty()) s_outputs.erase(oname);
    }

    //check if any requested outputs were not available
    if (!s_outputs.empty())
      throw cet::exception("MissingOutput")
        << "Some requested outputs were not available on the server: "
        << triton_utils::printColl(s_outputs);

    //propagate batch size to inputs and outputs
    setBatchSize(1);

    //print model info
    if (verbose_) {
      std::ostringstream model_msg;
      model_msg << "Model name: " << options_.model_name_ << "\n"
                << "Model version: " << options_.model_version_ << "\n"
                << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
      MF_LOG_INFO("TritonClient") << model_msg.str() << io_msg.str();
    }
  }
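
  // Illustrative sketch (not part of this file): assembling a fhicl::ParameterSet with the
  // keys read by the constructor above. The key names and types match the get<> calls; the
  // server address, model name, and numeric values are hypothetical placeholders.
  //
  //   #include "fhiclcpp/ParameterSet.h"
  //
  //   fhicl::ParameterSet ps;
  //   ps.put("serverURL", std::string{"localhost:8001"});  // gRPC endpoint of the Triton server
  //   ps.put("modelName", std::string{"mymodel"});
  //   ps.put("modelVersion", std::string{""});             // empty string lets the server choose the version
  //   ps.put("timeout", 60u);                              // seconds; converted to microseconds above
  //   ps.put("allowedTries", 2u);                          // optional; defaults to 0 (no retries)
  //   ps.put("verbose", false);
  //   ps.put("outputs", std::vector<std::string>{});       // empty vector requests all model outputs
  //   lartriton::TritonClient client(ps);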

  bool TritonClient::setBatchSize(unsigned bsize)
  {
    if (bsize > maxBatchSize_) {
      MF_LOG_WARNING("TritonClient")
        << "Requested batch size " << bsize << " exceeds server-specified max batch size "
        << maxBatchSize_ << ". Batch size will remain at " << batchSize_;
      return false;
    }
    batchSize_ = bsize;
    //set for input and output
    for (auto& element : input_) {
      element.second.setBatchSize(bsize);
    }
    for (auto& element : output_) {
      element.second.setBatchSize(bsize);
    }
    return true;
  }
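
  // Illustrative sketch (not part of this file): callers are expected to check the return
  // value, since an oversized request leaves the previous batch size in place. The names
  // "client" and "nPatches" are hypothetical.
  //
  //   if (!client.setBatchSize(nPatches)) {
  //     // e.g. split the work into smaller chunks and run several inference calls
  //   }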

  void TritonClient::reset()
  {
    for (auto& element : input_) {
      element.second.reset();
    }
    for (auto& element : output_) {
      element.second.reset();
    }
  }

  bool TritonClient::getResults(std::shared_ptr<nic::InferResult> results)
  {
    for (auto& [oname, output] : output_) {
      //set shape here before output becomes const
      if (output.variableDims()) {
        std::vector<int64_t> tmp_shape;
        bool status =
          triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
                                    "getResults(): unable to get output shape for " + oname);
        if (!status) return status;
        output.setShape(tmp_shape, false);
      }
      //extend lifetime
      output.setResult(results);
    }

    return true;
  }

  void TritonClient::start()
  {
    tries_ = 0;
  }

  //default case for sync and pseudo async
  void TritonClient::evaluate()
  {
    //in case there is nothing to process
    if (batchSize_ == 0) {
      finish(true);
      return;
    }

    // Get the status of the server prior to the request being made.
    const auto& start_status = getServerSideStatus();

    //blocking call
    auto t1 = std::chrono::steady_clock::now();
    nic::InferResult* results;
    bool status =
      triton_utils::warnIfError(client_->Infer(&results, options_, inputsTriton_, outputsTriton_),
                                "evaluate(): unable to run and/or get result");
    if (!status) {
      finish(false);
      return;
    }

    auto t2 = std::chrono::steady_clock::now();
    MF_LOG_DEBUG("TritonClient")
      << "Remote time: " << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();

    const auto& end_status = getServerSideStatus();

    if (verbose()) {
      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    std::shared_ptr<nic::InferResult> results_ptr(results);
    status = getResults(results_ptr);

    finish(status);
  }

  void TritonClient::finish(bool success)
  {
    //retries are only allowed if no exception was raised
    if (!success) {
      ++tries_;
      //if max retries has not been exceeded, call evaluate again
      if (tries_ < allowedTries_) {
        evaluate();
        //avoid calling doneWaiting() twice
        return;
      }
      //prepare an exception if exceeded
      throw cet::exception("TritonClient")
        << "call failed after max " << tries_ << " tries" << std::endl;
    }
  }
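
  // Illustrative sketch (not part of this file): the call sequence for one synchronous
  // inference with this client. Filling the input tensors and reading the output tensors go
  // through the TritonData objects held in the input/output maps; those calls are not shown
  // here. The output name "softmax" is hypothetical.
  //
  //   client.setBatchSize(n);
  //   // ... copy the batch of input data into each entry of the input map ...
  //   client.start();     // reset the retry counter
  //   client.evaluate();  // blocking Infer(); failed calls are retried according to "allowedTries"
  //   const auto& result = client.output().at("softmax");  // per-output TritonData with shape and result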

  void TritonClient::reportServerSideStats(const ServerSideStats& stats) const
  {
    std::ostringstream msg;

    // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
    const uint64_t count = stats.success_count_;
    msg << " Inference count: " << stats.inference_count_ << "\n";
    msg << " Execution count: " << stats.execution_count_ << "\n";
    msg << " Successful request count: " << count << "\n";

    if (count > 0) {
      auto get_avg_us = [count](uint64_t tval) {
        constexpr uint64_t us_to_ns = 1000;
        return tval / us_to_ns / count;
      };

      const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
      const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
      const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
      const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
      const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
      const uint64_t compute_avg_us =
        compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
      const uint64_t overhead = (cumm_avg_us > queue_avg_us + compute_avg_us) ?
                                  (cumm_avg_us - queue_avg_us - compute_avg_us) :
                                  0;

      msg << " Avg request latency: " << cumm_avg_us << " usec"
          << "\n"
          << " (overhead " << overhead << " usec + "
          << "queue " << queue_avg_us << " usec + "
          << "compute input " << compute_input_avg_us << " usec + "
          << "compute infer " << compute_infer_avg_us << " usec + "
          << "compute output " << compute_output_avg_us << " usec)" << std::endl;
    }

    MF_LOG_DEBUG("TritonClient") << msg.str();
  }
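
  // Worked example of the averaging above (hypothetical numbers): with success_count_ = 4
  // and cumm_time_ns_ = 8,000,000 ns, get_avg_us gives 8,000,000 / 1000 / 4 = 2000 usec of
  // average request latency; the same division is applied to the queue and compute times.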

  TritonClient::ServerSideStats TritonClient::summarizeServerStats(
    const inference::ModelStatistics& start_status,
    const inference::ModelStatistics& end_status) const
  {
    TritonClient::ServerSideStats server_stats;

    server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
    server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
    server_stats.success_count_ = end_status.inference_stats().success().count() -
                                  start_status.inference_stats().success().count();
    server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
    server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
    server_stats.compute_input_time_ns_ = end_status.inference_stats().compute_input().ns() -
                                          start_status.inference_stats().compute_input().ns();
    server_stats.compute_infer_time_ns_ = end_status.inference_stats().compute_infer().ns() -
                                          start_status.inference_stats().compute_infer().ns();
    server_stats.compute_output_time_ns_ = end_status.inference_stats().compute_output().ns() -
                                           start_status.inference_stats().compute_output().ns();

    return server_stats;
  }
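
  // Note (illustrative): the ModelStatistics counters returned by the server are cumulative,
  // so differencing the "start" and "end" snapshots isolates the contribution of this request.
  // For example (hypothetical numbers), if inference_count() moves from 100 to 104 across the
  // call, server_stats.inference_count_ is 4, i.e. this request covered a batch of 4 inferences.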

  inference::ModelStatistics TritonClient::getServerSideStatus() const
  {
    if (verbose_) {
      inference::ModelStatisticsResponse resp;
      bool success = triton_utils::warnIfError(
        client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
        "getServerSideStatus(): unable to get model statistics");
      if (success) return *(resp.model_stats().begin());
    }
    return inference::ModelStatistics{};
  }

}