LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
lartriton::TritonClient Class Reference

#include "TritonClient.h"

Classes

struct  ServerSideStats
 

Public Member Functions

 TritonClient (const fhicl::ParameterSet &params)
 
TritonInputMap & input ()
 
const TritonOutputMap & output () const
 
unsigned batchSize () const
 
bool verbose () const
 
bool setBatchSize (unsigned bsize)
 
void dispatch ()
 
void reset ()
 

Protected Member Functions

bool getResults (std::shared_ptr< nic::InferResult > results)
 
void start ()
 
void evaluate ()
 
void finish (bool success)
 
void reportServerSideStats (const ServerSideStats &stats) const
 
ServerSideStats summarizeServerStats (const inference::ModelStatistics &start_status, const inference::ModelStatistics &end_status) const
 
inference::ModelStatistics getServerSideStatus () const
 

Protected Attributes

TritonInputMap input_
 
TritonOutputMap output_
 
unsigned allowedTries_
 
unsigned tries_
 
std::string serverURL_
 
unsigned maxBatchSize_
 
unsigned batchSize_
 
bool noBatch_
 
bool verbose_
 
std::vector< nic::InferInput * > inputsTriton_
 
std::vector< const nic::InferRequestedOutput * > outputsTriton_
 
std::unique_ptr< nic::InferenceServerGrpcClient > client_
 
nic::InferOptions options_
 

Detailed Description

Definition at line 20 of file TritonClient.h.

Constructor & Destructor Documentation

lartriton::TritonClient::TritonClient ( const fhicl::ParameterSet & params)

Definition at line 25 of file TritonClient.cc.

References client_, fhicl::ParameterSet::get(), input_, inputsTriton_, maxBatchSize_, MF_LOG_INFO, noBatch_, options_, output_, outputsTriton_, triton_utils::printColl(), serverURL_, setBatchSize(), triton_utils::throwIfError(), and verbose_.

26  : allowedTries_(params.get<unsigned>("allowedTries", 0))
27  , serverURL_(params.get<std::string>("serverURL"))
28  , verbose_(params.get<bool>("verbose"))
29  , options_(params.get<std::string>("modelName"))
30  {
31  //get appropriate server for this model
32  if (verbose_) MF_LOG_INFO("TritonClient") << "Using server: " << serverURL_;
33 
34  //connect to the server
35  //TODO: add SSL options
36  triton_utils::throwIfError(nic::InferenceServerGrpcClient::Create(&client_, serverURL_, false),
37  "TritonClient(): unable to create inference context");
38 
39  //set options
40  options_.model_version_ = params.get<std::string>("modelVersion");
41  //convert seconds to microseconds
42  options_.client_timeout_ = params.get<unsigned>("timeout") * 1e6;
43 
44  //config needed for batch size
45  inference::ModelConfigResponse modelConfigResponse;
46  triton_utils::throwIfError(
47  client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
48  "TritonClient(): unable to get model config");
49  inference::ModelConfig modelConfig(modelConfigResponse.config());
50 
51  //check batch size limitations (after i/o setup)
52  //triton uses max batch size = 0 to denote a model that does not support batching
53  //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
54  //so set the local max to 1 and keep track of "no batch" case
55  maxBatchSize_ = modelConfig.max_batch_size();
56  noBatch_ = maxBatchSize_ == 0;
57  maxBatchSize_ = std::max(1u, maxBatchSize_);
58 
59  //get model info
60  inference::ModelMetadataResponse modelMetadata;
61  triton_utils::throwIfError(
62  client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
63  "TritonClient(): unable to get model metadata");
64 
65  //get input and output (which know their sizes)
66  const auto& nicInputs = modelMetadata.inputs();
67  const auto& nicOutputs = modelMetadata.outputs();
68 
69  //report all model errors at once
70  std::ostringstream msg;
71  std::string msg_str;
72 
73  //currently no use case is foreseen for a model with zero inputs or outputs
74  if (nicInputs.empty()) msg << "Model on server appears malformed (zero inputs)\n";
75 
76  if (nicOutputs.empty()) msg << "Model on server appears malformed (zero outputs)\n";
77 
78  //stop if errors
79  msg_str = msg.str();
80  if (!msg_str.empty()) throw cet::exception("ModelErrors") << msg_str;
81 
82  //setup input map
83  std::ostringstream io_msg;
84  if (verbose_)
85  io_msg << "Model inputs: "
86  << "\n";
87  inputsTriton_.reserve(nicInputs.size());
88  for (const auto& nicInput : nicInputs) {
89  const auto& iname = nicInput.name();
90  auto [curr_itr, success] = input_.try_emplace(iname, iname, nicInput, noBatch_);
91  auto& curr_input = curr_itr->second;
92  inputsTriton_.push_back(curr_input.data());
93  if (verbose_) {
94  io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
95  << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
96  }
97  }
98 
99  //allow selecting only some outputs from server
100  const auto& v_outputs = params.get<std::vector<std::string>>("outputs");
101  std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());
102 
103  //setup output map
104  if (verbose_)
105  io_msg << "Model outputs: "
106  << "\n";
107  outputsTriton_.reserve(nicOutputs.size());
108  for (const auto& nicOutput : nicOutputs) {
109  const auto& oname = nicOutput.name();
110  if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end()) continue;
111  auto [curr_itr, success] = output_.try_emplace(oname, oname, nicOutput, noBatch_);
112  auto& curr_output = curr_itr->second;
113  outputsTriton_.push_back(curr_output.data());
114  if (verbose_) {
115  io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
116  << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
117  }
118  if (!s_outputs.empty()) s_outputs.erase(oname);
119  }
120 
121  //check if any requested outputs were not available
122  if (!s_outputs.empty())
123  throw cet::exception("MissingOutput")
124  << "Some requested outputs were not available on the server: "
125  << triton_utils::printColl(s_outputs);
126 
127  //propagate batch size to inputs and outputs
128  setBatchSize(1);
129 
130  //print model info
131  if (verbose_) {
132  std::ostringstream model_msg;
133  model_msg << "Model name: " << options_.model_name_ << "\n"
134  << "Model version: " << options_.model_version_ << "\n"
135  << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
136  MF_LOG_INFO("TritonClient") << model_msg.str() << io_msg.str();
137  }
138  }
nic::InferOptions options_
Definition: TritonClient.h:83
std::string printColl(const C &coll, const std::string &delim)
Definition: triton_utils.cc:15
void throwIfError(const Error &err, std::string_view msg)
Definition: triton_utils.cc:26
std::vector< const nic::InferRequestedOutput * > outputsTriton_
Definition: TritonClient.h:79
TritonOutputMap output_
Definition: TritonClient.h:69
std::unique_ptr< nic::InferenceServerGrpcClient > client_
Definition: TritonClient.h:81
T get(std::string const &key) const
Definition: ParameterSet.h:314
TritonInputMap input_
Definition: TritonClient.h:68
#define MF_LOG_INFO(category)
bool setBatchSize(unsigned bsize)
std::vector< nic::InferInput * > inputsTriton_
Definition: TritonClient.h:78
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33

Member Function Documentation

unsigned lartriton::TritonClient::batchSize ( ) const
inline

Definition at line 39 of file TritonClient.h.

39 { return batchSize_; }
void lartriton::TritonClient::dispatch ( )
inline

Definition at line 44 of file TritonClient.h.

45  {
46  start();
47  evaluate();
48  }
void lartriton::TritonClient::evaluate ( )
protected

Definition at line 194 of file TritonClient.cc.

References batchSize_, client_, finish(), getResults(), getServerSideStatus(), inputsTriton_, MF_LOG_DEBUG, options_, outputsTriton_, reportServerSideStats(), summarizeServerStats(), t1, t2, verbose(), and triton_utils::warnIfError().

Referenced by finish().

195  {
196  //in case there is nothing to process
197  if (batchSize_ == 0) {
198  finish(true);
199  return;
200  }
201 
202  // Get the status of the server prior to the request being made.
203  const auto& start_status = getServerSideStatus();
204 
205  //blocking call
206  auto t1 = std::chrono::steady_clock::now();
207  nic::InferResult* results;
208  bool status =
209  triton_utils::warnIfError(client_->Infer(&results, options_, inputsTriton_, outputsTriton_),
210  "evaluate(): unable to run and/or get result");
211  if (!status) {
212  finish(false);
213  return;
214  }
215 
216  auto t2 = std::chrono::steady_clock::now();
217  MF_LOG_DEBUG("TritonClient")
218  << "Remote time: " << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
219 
220  const auto& end_status = getServerSideStatus();
221 
222  if (verbose()) {
223  const auto& stats = summarizeServerStats(start_status, end_status);
224  reportServerSideStats(stats);
225  }
226 
227  std::shared_ptr<nic::InferResult> results_ptr(results);
228  status = getResults(results_ptr);
229 
230  finish(status);
231  }
nic::InferOptions options_
Definition: TritonClient.h:83
ServerSideStats summarizeServerStats(const inference::ModelStatistics &start_status, const inference::ModelStatistics &end_status) const
TTree * t1
Definition: plottest35.C:26
void reportServerSideStats(const ServerSideStats &stats) const
bool warnIfError(const Error &err, std::string_view msg)
Definition: triton_utils.cc:31
microsecond microseconds
Alias for common language habits.
Definition: spacetime.h:119
void finish(bool success)
std::vector< const nic::InferRequestedOutput * > outputsTriton_
Definition: TritonClient.h:79
inference::ModelStatistics getServerSideStatus() const
std::unique_ptr< nic::InferenceServerGrpcClient > client_
Definition: TritonClient.h:81
TTree * t2
Definition: plottest35.C:36
bool verbose() const
Definition: TritonClient.h:40
std::vector< nic::InferInput * > inputsTriton_
Definition: TritonClient.h:78
#define MF_LOG_DEBUG(id)
bool getResults(std::shared_ptr< nic::InferResult > results)
void lartriton::TritonClient::finish ( bool  success)
protected

Definition at line 233 of file TritonClient.cc.

References allowedTries_, evaluate(), and tries_.

Referenced by evaluate().

234  {
235  //retries are only allowed if no exception was raised
236  if (!success) {
237  ++tries_;
238  //if max retries has not been exceeded, call evaluate again
239  if (tries_ < allowedTries_) {
240  evaluate();
241  //avoid calling doneWaiting() twice
242  return;
243  }
244  //prepare an exception if exceeded
245  throw cet::exception("TritonClient")
246  << "call failed after max " << tries_ << " tries" << std::endl;
247  }
248  }
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
bool lartriton::TritonClient::getResults ( std::shared_ptr< nic::InferResult >  results)
protected

Definition at line 169 of file TritonClient.cc.

References output(), output_, and triton_utils::warnIfError().

Referenced by evaluate().

170  {
171  for (auto& [oname, output] : output_) {
172  //set shape here before output becomes const
173  if (output.variableDims()) {
174  std::vector<int64_t> tmp_shape;
175  bool status =
176  triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
177  "getResults(): unable to get output shape for " + oname);
178  if (!status) return status;
179  output.setShape(tmp_shape, false);
180  }
181  //extend lifetime
182  output.setResult(results);
183  }
184 
185  return true;
186  }
const TritonOutputMap & output() const
Definition: TritonClient.h:38
bool warnIfError(const Error &err, std::string_view msg)
Definition: triton_utils.cc:31
TritonOutputMap output_
Definition: TritonClient.h:69
inference::ModelStatistics lartriton::TritonClient::getServerSideStatus ( ) const
protected

Definition at line 313 of file TritonClient.cc.

References client_, options_, verbose_, and triton_utils::warnIfError().

Referenced by evaluate().

314  {
315  if (verbose_) {
316  inference::ModelStatisticsResponse resp;
317  bool success = triton_utils::warnIfError(
318  client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
319  "getServerSideStatus(): unable to get model statistics");
320  if (success) return *(resp.model_stats().begin());
321  }
322  return inference::ModelStatistics{};
323  }
nic::InferOptions options_
Definition: TritonClient.h:83
bool warnIfError(const Error &err, std::string_view msg)
Definition: triton_utils.cc:31
std::unique_ptr< nic::InferenceServerGrpcClient > client_
Definition: TritonClient.h:81
TritonInputMap& lartriton::TritonClient::input ( )
inline

Definition at line 37 of file TritonClient.h.

37 { return input_; }
TritonInputMap input_
Definition: TritonClient.h:68
const TritonOutputMap& lartriton::TritonClient::output ( ) const
inline

Definition at line 38 of file TritonClient.h.

Referenced by getResults().

38 { return output_; }
TritonOutputMap output_
Definition: TritonClient.h:69
void lartriton::TritonClient::reportServerSideStats ( const ServerSideStats & stats) const
protected

Definition at line 250 of file TritonClient.cc.

References lartriton::TritonClient::ServerSideStats::compute_infer_time_ns_, lartriton::TritonClient::ServerSideStats::compute_input_time_ns_, lartriton::TritonClient::ServerSideStats::compute_output_time_ns_, lartriton::TritonClient::ServerSideStats::cumm_time_ns_, lartriton::TritonClient::ServerSideStats::execution_count_, lartriton::TritonClient::ServerSideStats::inference_count_, MF_LOG_DEBUG, lartriton::TritonClient::ServerSideStats::queue_time_ns_, and lartriton::TritonClient::ServerSideStats::success_count_.

Referenced by evaluate().

251  {
252  std::ostringstream msg;
253 
254  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
255  const uint64_t count = stats.success_count_;
256  msg << " Inference count: " << stats.inference_count_ << "\n";
257  msg << " Execution count: " << stats.execution_count_ << "\n";
258  msg << " Successful request count: " << count << "\n";
259 
260  if (count > 0) {
261  auto get_avg_us = [count](uint64_t tval) {
262  constexpr uint64_t us_to_ns = 1000;
263  return tval / us_to_ns / count;
264  };
265 
266  const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
267  const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
268  const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
269  const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
270  const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
271  const uint64_t compute_avg_us =
272  compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
273  const uint64_t overhead = (cumm_avg_us > queue_avg_us + compute_avg_us) ?
274  (cumm_avg_us - queue_avg_us - compute_avg_us) :
275  0;
276 
277  msg << " Avg request latency: " << cumm_avg_us << " usec"
278  << "\n"
279  << " (overhead " << overhead << " usec + "
280  << "queue " << queue_avg_us << " usec + "
281  << "compute input " << compute_input_avg_us << " usec + "
282  << "compute infer " << compute_infer_avg_us << " usec + "
283  << "compute output " << compute_output_avg_us << " usec)" << std::endl;
284  }
285 
286  MF_LOG_DEBUG("TritonClient") << msg.str();
287  }
#define MF_LOG_DEBUG(id)
void lartriton::TritonClient::reset ( )

Definition at line 159 of file TritonClient.cc.

References input_, and output_.

160  {
161  for (auto& element : input_) {
162  element.second.reset();
163  }
164  for (auto& element : output_) {
165  element.second.reset();
166  }
167  }
TritonOutputMap output_
Definition: TritonClient.h:69
TritonInputMap input_
Definition: TritonClient.h:68
bool lartriton::TritonClient::setBatchSize ( unsigned  bsize)

Definition at line 140 of file TritonClient.cc.

References batchSize_, input_, maxBatchSize_, MF_LOG_WARNING, and output_.

Referenced by TritonClient().

141  {
142  if (bsize > maxBatchSize_) {
143  MF_LOG_WARNING("TritonClient")
144  << "Requested batch size " << bsize << " exceeds server-specified max batch size "
145  << maxBatchSize_ << ". Batch size will remain as" << batchSize_;
146  return false;
147  }
148  batchSize_ = bsize;
149  //set for input and output
150  for (auto& element : input_) {
151  element.second.setBatchSize(bsize);
152  }
153  for (auto& element : output_) {
154  element.second.setBatchSize(bsize);
155  }
156  return true;
157  }
TritonOutputMap output_
Definition: TritonClient.h:69
TritonInputMap input_
Definition: TritonClient.h:68
#define MF_LOG_WARNING(category)
void lartriton::TritonClient::start ( )
protected

Definition at line 188 of file TritonClient.cc.

References tries_.

189  {
190  tries_ = 0;
191  }
TritonClient::ServerSideStats lartriton::TritonClient::summarizeServerStats ( const inference::ModelStatistics &  start_status,
const inference::ModelStatistics &  end_status 
) const
protected

Definition at line 289 of file TritonClient.cc.

References lartriton::TritonClient::ServerSideStats::compute_infer_time_ns_, lartriton::TritonClient::ServerSideStats::compute_input_time_ns_, lartriton::TritonClient::ServerSideStats::compute_output_time_ns_, lartriton::TritonClient::ServerSideStats::cumm_time_ns_, lartriton::TritonClient::ServerSideStats::execution_count_, lartriton::TritonClient::ServerSideStats::inference_count_, lartriton::TritonClient::ServerSideStats::queue_time_ns_, and lartriton::TritonClient::ServerSideStats::success_count_.

Referenced by evaluate().

292  {
293  TritonClient::ServerSideStats server_stats;
294 
295  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
296  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
297  server_stats.success_count_ = end_status.inference_stats().success().count() -
298  start_status.inference_stats().success().count();
299  server_stats.cumm_time_ns_ =
300  end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
301  server_stats.queue_time_ns_ =
302  end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
303  server_stats.compute_input_time_ns_ = end_status.inference_stats().compute_input().ns() -
304  start_status.inference_stats().compute_input().ns();
305  server_stats.compute_infer_time_ns_ = end_status.inference_stats().compute_infer().ns() -
306  start_status.inference_stats().compute_infer().ns();
307  server_stats.compute_output_time_ns_ = end_status.inference_stats().compute_output().ns() -
308  start_status.inference_stats().compute_output().ns();
309 
310  return server_stats;
311  }
bool lartriton::TritonClient::verbose ( ) const
inline

Definition at line 40 of file TritonClient.h.

Referenced by evaluate().

40 { return verbose_; }

Member Data Documentation

unsigned lartriton::TritonClient::allowedTries_
protected

Definition at line 70 of file TritonClient.h.

Referenced by finish().

unsigned lartriton::TritonClient::batchSize_
protected

Definition at line 73 of file TritonClient.h.

Referenced by evaluate(), and setBatchSize().

std::unique_ptr<nic::InferenceServerGrpcClient> lartriton::TritonClient::client_
protected

Definition at line 81 of file TritonClient.h.

Referenced by evaluate(), getServerSideStatus(), and TritonClient().

TritonInputMap lartriton::TritonClient::input_
protected

Definition at line 68 of file TritonClient.h.

Referenced by reset(), setBatchSize(), and TritonClient().

std::vector<nic::InferInput*> lartriton::TritonClient::inputsTriton_
protected

Definition at line 78 of file TritonClient.h.

Referenced by evaluate(), and TritonClient().

unsigned lartriton::TritonClient::maxBatchSize_
protected

Definition at line 72 of file TritonClient.h.

Referenced by setBatchSize(), and TritonClient().

bool lartriton::TritonClient::noBatch_
protected

Definition at line 74 of file TritonClient.h.

Referenced by TritonClient().

nic::InferOptions lartriton::TritonClient::options_
protected

Definition at line 83 of file TritonClient.h.

Referenced by evaluate(), getServerSideStatus(), and TritonClient().

TritonOutputMap lartriton::TritonClient::output_
protected

Definition at line 69 of file TritonClient.h.

Referenced by getResults(), reset(), setBatchSize(), and TritonClient().

std::vector<const nic::InferRequestedOutput*> lartriton::TritonClient::outputsTriton_
protected

Definition at line 79 of file TritonClient.h.

Referenced by evaluate(), and TritonClient().

std::string lartriton::TritonClient::serverURL_
protected

Definition at line 71 of file TritonClient.h.

Referenced by TritonClient().

unsigned lartriton::TritonClient::tries_
protected

Definition at line 70 of file TritonClient.h.

Referenced by finish(), and start().

bool lartriton::TritonClient::verbose_
protected

Definition at line 75 of file TritonClient.h.

Referenced by getServerSideStatus(), and TritonClient().


The documentation for this class was generated from the following files: