LArSoft v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
lartriton::TritonClient Class Reference

#include "TritonClient.h"

Classes

struct  ServerSideStats
 

Public Member Functions

 TritonClient (const fhicl::ParameterSet &params)
 
TritonInputMap & input ()
 
const TritonOutputMap & output () const
 
unsigned batchSize () const
 
bool verbose () const
 
bool setBatchSize (unsigned bsize)
 
void dispatch ()
 
void reset ()
 

Protected Member Functions

bool getResults (std::shared_ptr< nic::InferResult > results)
 
void start ()
 
void evaluate ()
 
void finish (bool success)
 
void reportServerSideStats (const ServerSideStats &stats) const
 
ServerSideStats summarizeServerStats (const inference::ModelStatistics &start_status, const inference::ModelStatistics &end_status) const
 
inference::ModelStatistics getServerSideStatus () const
 

Protected Attributes

TritonInputMap input_
 
TritonOutputMap output_
 
unsigned allowedTries_
 
unsigned tries_
 
std::string serverURL_
 
unsigned maxBatchSize_
 
unsigned batchSize_
 
bool noBatch_
 
bool verbose_
 
bool ssl_
 
std::string sslRootCertificates_
 
std::string sslPrivateKey_
 
std::string sslCertificateChain_
 
std::vector< nic::InferInput * > inputsTriton_
 
std::vector< const nic::InferRequestedOutput * > outputsTriton_
 
std::unique_ptr< nic::InferenceServerGrpcClient > client_
 
nic::InferOptions options_
 

Detailed Description

Definition at line 20 of file TritonClient.h.

Constructor & Destructor Documentation

lartriton::TritonClient::TritonClient ( const fhicl::ParameterSet & params)

Definition at line 25 of file TritonClient.cc.

References client_, fhicl::ParameterSet::get(), input_, inputsTriton_, maxBatchSize_, MF_LOG_INFO, noBatch_, options_, output_, outputsTriton_, triton_utils::printColl(), serverURL_, setBatchSize(), ssl_, sslCertificateChain_, sslPrivateKey_, sslRootCertificates_, triton_utils::throwIfError(), and verbose_.

26  : allowedTries_(params.get<unsigned>("allowedTries", 0))
27  , serverURL_(params.get<std::string>("serverURL"))
28  , verbose_(params.get<bool>("verbose"))
29  , ssl_(params.get<bool>("ssl", false))
30  , sslRootCertificates_(params.get<std::string>("sslRootCertificates", ""))
31  , sslPrivateKey_(params.get<std::string>("sslPrivateKey", ""))
32  , sslCertificateChain_(params.get<std::string>("sslCertificateChain", ""))
33  , options_(params.get<std::string>("modelName"))
34  {
35  //get appropriate server for this model
36  if (verbose_) MF_LOG_INFO("TritonClient") << "Using server: " << serverURL_;
37 
38  //connect to the server
39  if (ssl_) {
40  nic::SslOptions ssl_options = nic::SslOptions();
41  ssl_options.root_certificates = sslRootCertificates_;
42  ssl_options.private_key = sslPrivateKey_;
43  ssl_options.certificate_chain = sslCertificateChain_;
44  triton_utils::throwIfError(
45  nic::InferenceServerGrpcClient::Create(
46  &client_, serverURL_, verbose_, true, ssl_options, nic::KeepAliveOptions(), true),
47  "TritonClient(): unable to create inference context");
48  }
49  else {
50  triton_utils::throwIfError(
51  nic::InferenceServerGrpcClient::Create(&client_, serverURL_, verbose_, false),
52  "TritonClient(): unable to create inference context");
53  }
54 
55  //set options
56  options_.model_version_ = params.get<std::string>("modelVersion");
57  //convert seconds to microseconds
58  options_.client_timeout_ = params.get<unsigned>("timeout") * 1e6;
59 
60  //config needed for batch size
61  inference::ModelConfigResponse modelConfigResponse;
62  triton_utils::throwIfError(
63  client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
64  "TritonClient(): unable to get model config");
65  inference::ModelConfig modelConfig(modelConfigResponse.config());
66 
67  //check batch size limitations (after i/o setup)
68  //triton uses max batch size = 0 to denote a model that does not support batching
69  //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
70  //so set the local max to 1 and keep track of "no batch" case
71  maxBatchSize_ = modelConfig.max_batch_size();
72  noBatch_ = maxBatchSize_ == 0;
73  maxBatchSize_ = std::max(1u, maxBatchSize_);
74 
75  //get model info
76  inference::ModelMetadataResponse modelMetadata;
77  triton_utils::throwIfError(
78  client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
79  "TritonClient(): unable to get model metadata");
80 
81  //get input and output (which know their sizes)
82  const auto& nicInputs = modelMetadata.inputs();
83  const auto& nicOutputs = modelMetadata.outputs();
84 
85  //report all model errors at once
86  std::ostringstream msg;
87  std::string msg_str;
88 
89  //currently no use case is foreseen for a model with zero inputs or outputs
90  if (nicInputs.empty()) msg << "Model on server appears malformed (zero inputs)\n";
91 
92  if (nicOutputs.empty()) msg << "Model on server appears malformed (zero outputs)\n";
93 
94  //stop if errors
95  msg_str = msg.str();
96  if (!msg_str.empty()) throw cet::exception("ModelErrors") << msg_str;
97 
98  //setup input map
99  std::ostringstream io_msg;
100  if (verbose_)
101  io_msg << "Model inputs: "
102  << "\n";
103  inputsTriton_.reserve(nicInputs.size());
104  for (const auto& nicInput : nicInputs) {
105  const auto& iname = nicInput.name();
106  auto [curr_itr, success] = input_.try_emplace(iname, iname, nicInput, noBatch_);
107  auto& curr_input = curr_itr->second;
108  inputsTriton_.push_back(curr_input.data());
109  if (verbose_) {
110  io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
111  << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
112  }
113  }
114 
115  //allow selecting only some outputs from server
116  const auto& v_outputs = params.get<std::vector<std::string>>("outputs");
117  std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());
118 
119  //setup output map
120  if (verbose_)
121  io_msg << "Model outputs: "
122  << "\n";
123  outputsTriton_.reserve(nicOutputs.size());
124  for (const auto& nicOutput : nicOutputs) {
125  const auto& oname = nicOutput.name();
126  if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end()) continue;
127  auto [curr_itr, success] = output_.try_emplace(oname, oname, nicOutput, noBatch_);
128  auto& curr_output = curr_itr->second;
129  outputsTriton_.push_back(curr_output.data());
130  if (verbose_) {
131  io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
132  << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
133  }
134  if (!s_outputs.empty()) s_outputs.erase(oname);
135  }
136 
137  //check if any requested outputs were not available
138  if (!s_outputs.empty())
139  throw cet::exception("MissingOutput")
140  << "Some requested outputs were not available on the server: "
141  << triton_utils::printColl(s_outputs);
142 
143  //propagate batch size to inputs and outputs
144  setBatchSize(1);
145 
146  //print model info
147  if (verbose_) {
148  std::ostringstream model_msg;
149  model_msg << "Model name: " << options_.model_name_ << "\n"
150  << "Model version: " << options_.model_version_ << "\n"
151  << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
152  MF_LOG_INFO("TritonClient") << model_msg.str() << io_msg.str();
153  }
154  }
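The constructor is driven entirely by the module's FHiCL parameter set. Below is a minimal caller-side sketch, under the assumption that a suitable fhicl::ParameterSet is already available; the makeClient function is hypothetical, but the parameter names listed in the comments are those read by the constructor above.

// Minimal sketch (hypothetical caller code): construct a client from a
// module-level FHiCL parameter set. Parameter names follow the constructor above.
#include "TritonClient.h"
#include "fhiclcpp/ParameterSet.h"
#include <memory>

std::unique_ptr<lartriton::TritonClient> makeClient(fhicl::ParameterSet const& pset)
{
  // pset is expected to provide at least:
  //   serverURL    (string)         - gRPC endpoint of the Triton server
  //   verbose      (bool)           - enable informational logging
  //   modelName    (string)         - model to run
  //   modelVersion (string)         - model version ("" = latest)
  //   timeout      (unsigned)       - client timeout in seconds (converted to microseconds)
  //   outputs      (vector<string>) - subset of outputs to request ({} = all)
  // Optional keys: allowedTries, ssl, sslRootCertificates, sslPrivateKey, sslCertificateChain.
  return std::make_unique<lartriton::TritonClient>(pset);
}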

Member Function Documentation

unsigned lartriton::TritonClient::batchSize ( ) const
inline

Definition at line 39 of file TritonClient.h.

39 { return batchSize_; }
void lartriton::TritonClient::dispatch ( )
inline

Definition at line 44 of file TritonClient.h.

45  {
46  start();
47  evaluate();
48  }
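dispatch() is the synchronous entry point: it resets the retry counter via start() and then runs the (possibly retried) inference in evaluate(). A typical per-event call sequence might look like the sketch below; the data-filling and result-unpacking steps are left as comments because they depend on the concrete TritonData interface used by the caller.

// Hedged sketch of a per-event call sequence (hypothetical caller code).
#include "TritonClient.h"

void runEvent(lartriton::TritonClient& client, unsigned nSamples)
{
  client.reset();                     // clear input/output data from the previous event
  if (!client.setBatchSize(nSamples)) // rejected if nSamples exceeds the model's max batch size
    return;
  // ... fill each entry of client.input() with this event's data ...
  client.dispatch();                  // blocking: start() + evaluate(), with retries on failure
  // ... read results back from client.output() ...
}
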
void lartriton::TritonClient::evaluate ( )
protected

Definition at line 210 of file TritonClient.cc.

References batchSize_, client_, finish(), getResults(), getServerSideStatus(), inputsTriton_, MF_LOG_DEBUG, options_, outputsTriton_, reportServerSideStats(), summarizeServerStats(), verbose(), and triton_utils::warnIfError().

Referenced by finish().

211  {
212  //in case there is nothing to process
213  if (batchSize_ == 0) {
214  finish(true);
215  return;
216  }
217 
218  // Get the status of the server prior to the request being made.
219  const auto& start_status = getServerSideStatus();
220 
221  //blocking call
222  auto t1 = std::chrono::steady_clock::now();
223  nic::InferResult* results;
224 
225  nic::Headers http_headers;
226  grpc_compression_algorithm compression_algorithm =
227  grpc_compression_algorithm::GRPC_COMPRESS_NONE;
228 
229  bool status = triton_utils::warnIfError(
230  client_->Infer(
231  &results, options_, inputsTriton_, outputsTriton_, http_headers, compression_algorithm),
232  "evaluate(): unable to run and/or get result");
233  if (!status) {
234  finish(false);
235  return;
236  }
237 
238  auto t2 = std::chrono::steady_clock::now();
239  MF_LOG_DEBUG("TritonClient")
240  << "Remote time: " << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
241 
242  const auto& end_status = getServerSideStatus();
243 
244  if (verbose()) {
245  const auto& stats = summarizeServerStats(start_status, end_status);
246  reportServerSideStats(stats);
247  }
248 
249  std::shared_ptr<nic::InferResult> results_ptr(results);
250  status = getResults(results_ptr);
251 
252  finish(status);
253  }
void lartriton::TritonClient::finish ( bool  success)
protected

Definition at line 255 of file TritonClient.cc.

References allowedTries_, evaluate(), and tries_.

Referenced by evaluate().

256  {
257  //retries are only allowed if no exception was raised
258  if (!success) {
259  ++tries_;
260  //if max retries has not been exceeded, call evaluate again
261  if (tries_ < allowedTries_) {
262  evaluate();
263  //avoid calling doneWaiting() twice
264  return;
265  }
266  //prepare an exception if exceeded
267  throw cet::exception("TritonClient")
268  << "call failed after max " << tries_ << " tries" << std::endl;
269  }
270  }
bool lartriton::TritonClient::getResults ( std::shared_ptr< nic::InferResult >  results)
protected

Definition at line 185 of file TritonClient.cc.

References output(), output_, and triton_utils::warnIfError().

Referenced by evaluate().

186  {
187  for (auto& [oname, output] : output_) {
188  //set shape here before output becomes const
189  if (output.variableDims()) {
190  std::vector<int64_t> tmp_shape;
191  bool status =
192  triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
193  "getResults(): unable to get output shape for " + oname);
194  if (!status) return status;
195  output.setShape(tmp_shape, false);
196  }
197  //extend lifetime
198  output.setResult(results);
199  }
200 
201  return true;
202  }
inference::ModelStatistics lartriton::TritonClient::getServerSideStatus ( ) const
protected

Definition at line 335 of file TritonClient.cc.

References client_, options_, verbose_, and triton_utils::warnIfError().

Referenced by evaluate().

336  {
337  if (verbose_) {
338  inference::ModelStatisticsResponse resp;
339  bool success = triton_utils::warnIfError(
340  client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
341  "getServerSideStatus(): unable to get model statistics");
342  if (success) return *(resp.model_stats().begin());
343  }
344  return inference::ModelStatistics{};
345  }
TritonInputMap& lartriton::TritonClient::input ( )
inline

Definition at line 37 of file TritonClient.h.

37 { return input_; }
const TritonOutputMap& lartriton::TritonClient::output ( ) const
inline

Definition at line 38 of file TritonClient.h.

Referenced by getResults().

38 { return output_; }
void lartriton::TritonClient::reportServerSideStats ( const ServerSideStats stats) const
protected

Definition at line 272 of file TritonClient.cc.

References lartriton::TritonClient::ServerSideStats::compute_infer_time_ns_, lartriton::TritonClient::ServerSideStats::compute_input_time_ns_, lartriton::TritonClient::ServerSideStats::compute_output_time_ns_, lartriton::TritonClient::ServerSideStats::cumm_time_ns_, lartriton::TritonClient::ServerSideStats::execution_count_, lartriton::TritonClient::ServerSideStats::inference_count_, MF_LOG_DEBUG, lartriton::TritonClient::ServerSideStats::queue_time_ns_, and lartriton::TritonClient::ServerSideStats::success_count_.

Referenced by evaluate().

273  {
274  std::ostringstream msg;
275 
276  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
277  const uint64_t count = stats.success_count_;
278  msg << " Inference count: " << stats.inference_count_ << "\n";
279  msg << " Execution count: " << stats.execution_count_ << "\n";
280  msg << " Successful request count: " << count << "\n";
281 
282  if (count > 0) {
283  auto get_avg_us = [count](uint64_t tval) {
284  constexpr uint64_t us_to_ns = 1000;
285  return tval / us_to_ns / count;
286  };
287 
288  const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
289  const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
290  const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
291  const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
292  const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
293  const uint64_t compute_avg_us =
294  compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
295  const uint64_t overhead = (cumm_avg_us > queue_avg_us + compute_avg_us) ?
296  (cumm_avg_us - queue_avg_us - compute_avg_us) :
297  0;
298 
299  msg << " Avg request latency: " << cumm_avg_us << " usec"
300  << "\n"
301  << " (overhead " << overhead << " usec + "
302  << "queue " << queue_avg_us << " usec + "
303  << "compute input " << compute_input_avg_us << " usec + "
304  << "compute infer " << compute_infer_avg_us << " usec + "
305  << "compute output " << compute_output_avg_us << " usec)" << std::endl;
306  }
307 
308  MF_LOG_DEBUG("TritonClient") << msg.str();
309  }
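As a purely illustrative worked example of this decomposition: with count = 2 successful requests, cumm_time_ns_ = 4,000,000 ns, queue_time_ns_ = 400,000 ns, and the three compute components totaling 3,000,000 ns, the report would show an average request latency of 2000 usec, split into queue 200 usec, compute 1500 usec, and overhead 2000 - 200 - 1500 = 300 usec.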
void lartriton::TritonClient::reset ( )

Definition at line 175 of file TritonClient.cc.

References input_, and output_.

176  {
177  for (auto& element : input_) {
178  element.second.reset();
179  }
180  for (auto& element : output_) {
181  element.second.reset();
182  }
183  }
bool lartriton::TritonClient::setBatchSize ( unsigned  bsize)

Definition at line 156 of file TritonClient.cc.

References batchSize_, input_, maxBatchSize_, MF_LOG_WARNING, and output_.

Referenced by TritonClient().

157  {
158  if (bsize > maxBatchSize_) {
159  MF_LOG_WARNING("TritonClient")
160  << "Requested batch size " << bsize << " exceeds server-specified max batch size "
161  << maxBatchSize_ << ". Batch size will remain as" << batchSize_;
162  return false;
163  }
164  batchSize_ = bsize;
165  //set for input and output
166  for (auto& element : input_) {
167  element.second.setBatchSize(bsize);
168  }
169  for (auto& element : output_) {
170  element.second.setBatchSize(bsize);
171  }
172  return true;
173  }
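Note that for a model that does not support batching (server-side max batch size 0), the constructor forces the local maxBatchSize_ to 1, so only bsize values of 0 or 1 are accepted here. A brief illustrative fragment (hypothetical caller code):

// Illustrative only: fall back to a single-sample batch if the request is rejected.
if (!client.setBatchSize(requestedBatch)) {
  client.setBatchSize(1); // always accepted, since the local maxBatchSize_ is at least 1
}
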
void lartriton::TritonClient::start ( )
protected

Definition at line 204 of file TritonClient.cc.

References tries_.

205  {
206  tries_ = 0;
207  }
TritonClient::ServerSideStats lartriton::TritonClient::summarizeServerStats ( const inference::ModelStatistics &  start_status,
const inference::ModelStatistics &  end_status 
) const
protected

Definition at line 311 of file TritonClient.cc.

References lartriton::TritonClient::ServerSideStats::compute_infer_time_ns_, lartriton::TritonClient::ServerSideStats::compute_input_time_ns_, lartriton::TritonClient::ServerSideStats::compute_output_time_ns_, lartriton::TritonClient::ServerSideStats::cumm_time_ns_, lartriton::TritonClient::ServerSideStats::execution_count_, lartriton::TritonClient::ServerSideStats::inference_count_, lartriton::TritonClient::ServerSideStats::queue_time_ns_, and lartriton::TritonClient::ServerSideStats::success_count_.

Referenced by evaluate().

314  {
315  TritonClient::ServerSideStats server_stats;
316 
317  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
318  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
319  server_stats.success_count_ = end_status.inference_stats().success().count() -
320  start_status.inference_stats().success().count();
321  server_stats.cumm_time_ns_ =
322  end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
323  server_stats.queue_time_ns_ =
324  end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
325  server_stats.compute_input_time_ns_ = end_status.inference_stats().compute_input().ns() -
326  start_status.inference_stats().compute_input().ns();
327  server_stats.compute_infer_time_ns_ = end_status.inference_stats().compute_infer().ns() -
328  start_status.inference_stats().compute_infer().ns();
329  server_stats.compute_output_time_ns_ = end_status.inference_stats().compute_output().ns() -
330  start_status.inference_stats().compute_output().ns();
331 
332  return server_stats;
333  }
bool lartriton::TritonClient::verbose ( ) const
inline

Definition at line 40 of file TritonClient.h.

Referenced by evaluate().

40 { return verbose_; }

Member Data Documentation

unsigned lartriton::TritonClient::allowedTries_
protected

Definition at line 70 of file TritonClient.h.

Referenced by finish().

unsigned lartriton::TritonClient::batchSize_
protected

Definition at line 73 of file TritonClient.h.

Referenced by evaluate(), and setBatchSize().

std::unique_ptr<nic::InferenceServerGrpcClient> lartriton::TritonClient::client_
protected

Definition at line 85 of file TritonClient.h.

Referenced by evaluate(), getServerSideStatus(), and TritonClient().

TritonInputMap lartriton::TritonClient::input_
protected

Definition at line 68 of file TritonClient.h.

Referenced by reset(), setBatchSize(), and TritonClient().

std::vector<nic::InferInput*> lartriton::TritonClient::inputsTriton_
protected

Definition at line 82 of file TritonClient.h.

Referenced by evaluate(), and TritonClient().

unsigned lartriton::TritonClient::maxBatchSize_
protected

Definition at line 72 of file TritonClient.h.

Referenced by setBatchSize(), and TritonClient().

bool lartriton::TritonClient::noBatch_
protected

Definition at line 74 of file TritonClient.h.

Referenced by TritonClient().

nic::InferOptions lartriton::TritonClient::options_
protected

Definition at line 87 of file TritonClient.h.

Referenced by evaluate(), getServerSideStatus(), and TritonClient().

TritonOutputMap lartriton::TritonClient::output_
protected

Definition at line 69 of file TritonClient.h.

Referenced by getResults(), reset(), setBatchSize(), and TritonClient().

std::vector<const nic::InferRequestedOutput*> lartriton::TritonClient::outputsTriton_
protected

Definition at line 83 of file TritonClient.h.

Referenced by evaluate(), and TritonClient().

std::string lartriton::TritonClient::serverURL_
protected

Definition at line 71 of file TritonClient.h.

Referenced by TritonClient().

bool lartriton::TritonClient::ssl_
protected

Definition at line 76 of file TritonClient.h.

Referenced by TritonClient().

std::string lartriton::TritonClient::sslCertificateChain_
protected

Definition at line 79 of file TritonClient.h.

Referenced by TritonClient().

std::string lartriton::TritonClient::sslPrivateKey_
protected

Definition at line 78 of file TritonClient.h.

Referenced by TritonClient().

std::string lartriton::TritonClient::sslRootCertificates_
protected

Definition at line 77 of file TritonClient.h.

Referenced by TritonClient().

unsigned lartriton::TritonClient::tries_
protected

Definition at line 70 of file TritonClient.h.

Referenced by finish(), and start().

bool lartriton::TritonClient::verbose_
protected

Definition at line 75 of file TritonClient.h.

Referenced by getServerSideStatus(), and TritonClient().


The documentation for this class was generated from the following files: