LArSoft v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
TritonClient.cc

#include "cetlib_except/exception.h"
#include "fhiclcpp/ParameterSet.h"
#include "messagefacility/MessageLogger/MessageLogger.h"

#include "grpc_client.h"

#include <chrono>
#include <cmath>
#include <exception>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

namespace nic = triton::client;

//based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
//and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc

namespace lartriton {
  TritonClient::TritonClient(const fhicl::ParameterSet& params)
    : allowedTries_(params.get<unsigned>("allowedTries", 0))
    , serverURL_(params.get<std::string>("serverURL"))
    , verbose_(params.get<bool>("verbose"))
    , ssl_(params.get<bool>("ssl", false))
    , sslRootCertificates_(params.get<std::string>("sslRootCertificates", ""))
    , sslPrivateKey_(params.get<std::string>("sslPrivateKey", ""))
    , sslCertificateChain_(params.get<std::string>("sslCertificateChain", ""))
    , options_(params.get<std::string>("modelName"))
  {
    //get appropriate server for this model
    if (verbose_) MF_LOG_INFO("TritonClient") << "Using server: " << serverURL_;

    //connect to the server
    if (ssl_) {
      nic::SslOptions ssl_options = nic::SslOptions();
      ssl_options.root_certificates = sslRootCertificates_;
      ssl_options.private_key = sslPrivateKey_;
      ssl_options.certificate_chain = sslCertificateChain_;
      triton_utils::throwIfError(
        nic::InferenceServerGrpcClient::Create(
          &client_, serverURL_, verbose_, true, ssl_options, nic::KeepAliveOptions(), true),
        "TritonClient(): unable to create inference context");
    }
    else {
      triton_utils::throwIfError(
        nic::InferenceServerGrpcClient::Create(&client_, serverURL_, verbose_, false),
        "TritonClient(): unable to create inference context");
    }

    //set options
    options_.model_version_ = params.get<std::string>("modelVersion");
    //convert seconds to microseconds
    options_.client_timeout_ = params.get<unsigned>("timeout") * 1e6;

    //config needed for batch size
    inference::ModelConfigResponse modelConfigResponse;
    triton_utils::throwIfError(
      client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
      "TritonClient(): unable to get model config");
    inference::ModelConfig modelConfig(modelConfigResponse.config());

    //check batch size limitations (after i/o setup)
    //triton uses max batch size = 0 to denote a model that does not support batching
    //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
    //so set the local max to 1 and keep track of "no batch" case
    maxBatchSize_ = modelConfig.max_batch_size();
    noBatch_ = maxBatchSize_ == 0;
    maxBatchSize_ = std::max(1u, maxBatchSize_);
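    //for example (illustrative values, not from any specific model): a server config with
    //max_batch_size = 0 yields noBatch_ = true and a local maxBatchSize_ of 1, while
    //max_batch_size = 8 yields noBatch_ = false and maxBatchSize_ = 8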

    //get model info
    inference::ModelMetadataResponse modelMetadata;
    triton_utils::throwIfError(
      client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
      "TritonClient(): unable to get model metadata");

    //get input and output (which know their sizes)
    const auto& nicInputs = modelMetadata.inputs();
    const auto& nicOutputs = modelMetadata.outputs();

    //report all model errors at once
    std::ostringstream msg;
    std::string msg_str;

    //currently no use case is foreseen for a model with zero inputs or outputs
    if (nicInputs.empty()) msg << "Model on server appears malformed (zero inputs)\n";

    if (nicOutputs.empty()) msg << "Model on server appears malformed (zero outputs)\n";

    //stop if errors
    msg_str = msg.str();
    if (!msg_str.empty()) throw cet::exception("ModelErrors") << msg_str;

    //setup input map
    std::ostringstream io_msg;
    if (verbose_)
      io_msg << "Model inputs: "
             << "\n";
    inputsTriton_.reserve(nicInputs.size());
    for (const auto& nicInput : nicInputs) {
      const auto& iname = nicInput.name();
      auto [curr_itr, success] = input_.try_emplace(iname, iname, nicInput, noBatch_);
      auto& curr_input = curr_itr->second;
      inputsTriton_.push_back(curr_input.data());
      if (verbose_) {
        io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
               << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
      }
    }

    //allow selecting only some outputs from server
    const auto& v_outputs = params.get<std::vector<std::string>>("outputs");
    std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());

    //setup output map
    if (verbose_)
      io_msg << "Model outputs: "
             << "\n";
    outputsTriton_.reserve(nicOutputs.size());
    for (const auto& nicOutput : nicOutputs) {
      const auto& oname = nicOutput.name();
      if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end()) continue;
      auto [curr_itr, success] = output_.try_emplace(oname, oname, nicOutput, noBatch_);
      auto& curr_output = curr_itr->second;
      outputsTriton_.push_back(curr_output.data());
      if (verbose_) {
        io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
               << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
      }
      if (!s_outputs.empty()) s_outputs.erase(oname);
    }

    //check if any requested outputs were not available
    if (!s_outputs.empty())
      throw cet::exception("MissingOutput")
        << "Some requested outputs were not available on the server: "
        << triton_utils::printColl(s_outputs);

    //propagate batch size to inputs and outputs
    setBatchSize(1);

    //print model info
    if (verbose_) {
      std::ostringstream model_msg;
      model_msg << "Model name: " << options_.model_name_ << "\n"
                << "Model version: " << options_.model_version_ << "\n"
                << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
      MF_LOG_INFO("TritonClient") << model_msg.str() << io_msg.str();
    }
  }
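  // Illustrative FHiCL configuration consistent with the parameters read above; the table
  // name and all values are placeholders (not defaults shipped with LArSoft). An empty
  // "outputs" sequence selects every output advertised by the model.
  //
  //   TritonConfig: {
  //     serverURL:    "localhost:8001"
  //     verbose:      false
  //     ssl:          false
  //     modelName:    "mymodel"
  //     modelVersion: "1"
  //     timeout:      10      # seconds, converted to microseconds above
  //     allowedTries: 0
  //     outputs:      []
  //   }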

  bool TritonClient::setBatchSize(unsigned bsize)
  {
    if (bsize > maxBatchSize_) {
      MF_LOG_WARNING("TritonClient")
        << "Requested batch size " << bsize << " exceeds server-specified max batch size "
        << maxBatchSize_ << ". Batch size will remain at " << batchSize_;
      return false;
    }
    batchSize_ = bsize;
    //set for input and output
    for (auto& element : input_) {
      element.second.setBatchSize(bsize);
    }
    for (auto& element : output_) {
      element.second.setBatchSize(bsize);
    }
    return true;
  }
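  // Usage sketch (hypothetical caller, not part of this file): request a larger batch and
  // fall back to the current size if the server's limit is lower.
  //
  //   if (!client.setBatchSize(4)) {
  //     // request rejected: 4 exceeds maxBatchSize_, so batchSize_ is unchanged
  //   }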

  void TritonClient::reset()
  {
    for (auto& element : input_) {
      element.second.reset();
    }
    for (auto& element : output_) {
      element.second.reset();
    }
  }

  bool TritonClient::getResults(std::shared_ptr<nic::InferResult> results)
  {
    for (auto& [oname, output] : output_) {
      //set shape here before output becomes const
      if (output.variableDims()) {
        std::vector<int64_t> tmp_shape;
        bool status =
          triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
                                    "getResults(): unable to get output shape for " + oname);
        if (!status) return status;
        output.setShape(tmp_shape, false);
      }
      //extend lifetime
      output.setResult(results);
    }

    return true;
  }

  void TritonClient::start()
  {
    tries_ = 0;
  }

  //default case for sync and pseudo async
  void TritonClient::evaluate()
  {
    //in case there is nothing to process
    if (batchSize_ == 0) {
      finish(true);
      return;
    }

    // Get the status of the server prior to the request being made.
    const auto& start_status = getServerSideStatus();

    //blocking call
    auto t1 = std::chrono::steady_clock::now();
    nic::InferResult* results;

    nic::Headers http_headers;
    grpc_compression_algorithm compression_algorithm =
      grpc_compression_algorithm::GRPC_COMPRESS_NONE;

    bool status = triton_utils::warnIfError(
      client_->Infer(
        &results, options_, inputsTriton_, outputsTriton_, http_headers, compression_algorithm),
      "evaluate(): unable to run and/or get result");
    if (!status) {
      finish(false);
      return;
    }

    auto t2 = std::chrono::steady_clock::now();
    MF_LOG_DEBUG("TritonClient")
      << "Remote time: " << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();

    const auto& end_status = getServerSideStatus();

    if (verbose()) {
      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    std::shared_ptr<nic::InferResult> results_ptr(results);
    status = getResults(results_ptr);

    finish(status);
  }

  void TritonClient::finish(bool success)
  {
    //retries are only allowed if no exception was raised
    if (!success) {
      ++tries_;
      //if max retries has not been exceeded, call evaluate again
      if (tries_ < allowedTries_) {
        evaluate();
        //avoid calling doneWaiting() twice
        return;
      }
      //throw an exception once the retry limit has been exceeded
      throw cet::exception("TritonClient")
        << "call failed after max " << tries_ << " tries" << std::endl;
    }
  }

  void TritonClient::reportServerSideStats(const ServerSideStats& stats) const
  {
    std::ostringstream msg;

    // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
    const uint64_t count = stats.success_count_;
    msg << "  Inference count: " << stats.inference_count_ << "\n";
    msg << "  Execution count: " << stats.execution_count_ << "\n";
    msg << "  Successful request count: " << count << "\n";

    if (count > 0) {
      auto get_avg_us = [count](uint64_t tval) {
        constexpr uint64_t us_to_ns = 1000;
        return tval / us_to_ns / count;
      };

      const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
      const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
      const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
      const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
      const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
      const uint64_t compute_avg_us =
        compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
      const uint64_t overhead = (cumm_avg_us > queue_avg_us + compute_avg_us) ?
                                  (cumm_avg_us - queue_avg_us - compute_avg_us) :
                                  0;

      msg << "  Avg request latency: " << cumm_avg_us << " usec"
          << "\n"
          << "  (overhead " << overhead << " usec + "
          << "queue " << queue_avg_us << " usec + "
          << "compute input " << compute_input_avg_us << " usec + "
          << "compute infer " << compute_infer_avg_us << " usec + "
          << "compute output " << compute_output_avg_us << " usec)" << std::endl;
    }

    MF_LOG_DEBUG("TritonClient") << msg.str();
  }
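  // Worked example of the latency breakdown above (illustrative numbers): with an average
  // request latency of 1200 usec, queue 100 usec, and compute input/infer/output of
  // 50/900/50 usec, the reported overhead is 1200 - 100 - (50 + 900 + 50) = 100 usec.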

  TritonClient::ServerSideStats TritonClient::summarizeServerStats(
    const inference::ModelStatistics& start_status,
    const inference::ModelStatistics& end_status) const
  {
    TritonClient::ServerSideStats server_stats;

    server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
    server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
    server_stats.success_count_ = end_status.inference_stats().success().count() -
                                  start_status.inference_stats().success().count();
    server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
    server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
    server_stats.compute_input_time_ns_ = end_status.inference_stats().compute_input().ns() -
                                          start_status.inference_stats().compute_input().ns();
    server_stats.compute_infer_time_ns_ = end_status.inference_stats().compute_infer().ns() -
                                          start_status.inference_stats().compute_infer().ns();
    server_stats.compute_output_time_ns_ = end_status.inference_stats().compute_output().ns() -
                                           start_status.inference_stats().compute_output().ns();

    return server_stats;
  }

  inference::ModelStatistics TritonClient::getServerSideStatus() const
  {
    if (verbose_) {
      inference::ModelStatisticsResponse resp;
      bool success = triton_utils::warnIfError(
        client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
        "getServerSideStatus(): unable to get model statistics");
      if (success) return *(resp.model_stats().begin());
    }
    return inference::ModelStatistics{};
  }

}