4 #include "cetlib_except/exception.h" 8 #include "grpc_client.h" 18 namespace nic = triton::client;
TritonClient::TritonClient(const fhicl::ParameterSet& params)
  : allowedTries_(params.get<unsigned>("allowedTries", 0))
  , serverURL_(params.get<std::string>("serverURL"))
  , verbose_(params.get<bool>("verbose"))
  , options_(params.get<std::string>("modelName"))
{
  //connect to the server
  triton_utils::throwIfError(nic::InferenceServerGrpcClient::Create(&client_, serverURL_, false),
                             "TritonClient(): unable to create inference context");
  options_.model_version_ = params.get<std::string>("modelVersion");
  options_.client_timeout_ = params.get<unsigned>("timeout") * 1e6;
  //model config is needed to check the batch size limits
  inference::ModelConfigResponse modelConfigResponse;
  triton_utils::throwIfError(
    client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
    "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

  //Triton uses max_batch_size = 0 for models that do not support batching;
  //keep a local maximum of at least 1 and remember the "no batch" case
  maxBatchSize_ = modelConfig.max_batch_size();
  noBatch_ = (maxBatchSize_ == 0);
  maxBatchSize_ = std::max(1u, maxBatchSize_);
  //get model metadata
  inference::ModelMetadataResponse modelMetadata;
  triton_utils::throwIfError(
    client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
    "TritonClient(): unable to get model metadata");

  //inputs and outputs as described by the server
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();
  //report all model errors at once
  std::ostringstream msg;

  if (nicInputs.empty()) msg << "Model on server appears malformed (zero inputs)\n";
  if (nicOutputs.empty()) msg << "Model on server appears malformed (zero outputs)\n";

  //stop if errors
  std::string msg_str = msg.str();
  if (!msg_str.empty()) throw cet::exception("ModelErrors") << msg_str;
  //setup the input map
  std::ostringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  inputsTriton_.reserve(nicInputs.size());
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.try_emplace(iname, iname, nicInput, noBatch_);
    auto& curr_input = curr_itr->second;
    inputsTriton_.push_back(curr_input.data());
    if (verbose_) {
      io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }
  const auto& v_outputs = params.get<std::vector<std::string>>("outputs");
  std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());
  //setup the output map, restricted to the requested outputs if any were specified
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  outputsTriton_.reserve(nicOutputs.size());
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end()) continue;
    auto [curr_itr, success] = output_.try_emplace(oname, oname, nicOutput, noBatch_);
    auto& curr_output = curr_itr->second;
    outputsTriton_.push_back(curr_output.data());
    if (verbose_) {
      io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty()) s_outputs.erase(oname);
  }
  //check if any requested outputs were not available
  if (!s_outputs.empty())
    throw cet::exception("MissingOutput")
      << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //propagate batch size to the input and output objects
  setBatchSize(1);

  //print model information
  std::ostringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_.model_name_ << "\n"
              << "Model version: " << options_.model_version_ << "\n"
              << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
    MF_LOG_INFO("TritonClient") << model_msg.str() << io_msg.str();
  }
}
bool TritonClient::setBatchSize(unsigned bsize)
{
  if (bsize > maxBatchSize_) {
    MF_LOG_WARNING("TritonClient") << "Requested batch size " << bsize
                                   << " exceeds server-specified max batch size " << maxBatchSize_;
    return false;
  }
  batchSize_ = bsize;
  //propagate to the input and output objects
  for (auto& element : input_) {
    element.second.setBatchSize(bsize);
  }
  for (auto& element : output_) {
    element.second.setBatchSize(bsize);
  }
  return true;
}
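// Usage sketch (hypothetical caller code): setBatchSize() refuses requests above the
// server-reported maximum and returns false without changing the current batch size,
// so a caller can react before filling the inputs:
//
//   if (!client.setBatchSize(nSamples)) {
//     // e.g. split nSamples into several smaller batches
//   }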
void TritonClient::reset()
{
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}
bool TritonClient::getResults(std::shared_ptr<nic::InferResult> results)
{
  for (auto& [oname, output] : output_) {
    //set the concrete shape here, before the output is read
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      bool status = triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
                                              "getResults(): unable to get output shape for " + oname);
      if (!status) return status;
      output.setShape(tmp_shape, false);
    }
    output.setResult(results);
  }
  return true;
}
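// Note on the loop above: outputs with variable dimensions only get their concrete shape from
// the InferResult, which is why the shape is queried and set here; passing the shared_ptr into
// setResult() keeps the result (and the buffers it owns) alive while the output is read back.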
void TritonClient::evaluate()
{
  // ... (skip empty batches; snapshot the server-side statistics)

  //blocking inference call
  auto t1 = std::chrono::steady_clock::now();
  nic::InferResult* results;
  bool status = triton_utils::warnIfError(client_->Infer(&results, options_, inputsTriton_, outputsTriton_),
                                          "evaluate(): unable to run and/or get result");
  if (!status) {
    finish(false);
    return;
  }
  auto t2 = std::chrono::steady_clock::now();

  // ... (log the remote time t2 - t1 and report the per-call server-side statistics when verbose)

  //transfer ownership so the result lives as long as the outputs that reference it
  std::shared_ptr<nic::InferResult> results_ptr(results);
  status = getResults(results_ptr);
  finish(status);
}
void TritonClient::finish(bool success)
{
  //retries are allowed only up to allowedTries_; beyond that the call is a hard failure
  if (!success) {
    ++tries_;
    if (tries_ < allowedTries_) {
      evaluate();
      return;
    }
    throw cet::exception("TritonClient")  //exception category assumed
      << "call failed after max " << tries_ << " tries" << std::endl;
  }
}
void TritonClient::reportServerSideStats(const ServerSideStats& stats) const
{
  std::ostringstream msg;

  //reported in the same style as the Triton perf client
  const uint64_t count = stats.success_count_;
  msg << "  Successful request count: " << count << "\n";

  if (count > 0) {
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us =
      compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead = (cumm_avg_us > queue_avg_us + compute_avg_us) ?
                                (cumm_avg_us - queue_avg_us - compute_avg_us) :
                                0;

    msg << "  Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << "  (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  MF_LOG_INFO("TritonClient") << msg.str();
}
TritonClient::ServerSideStats TritonClient::summarizeServerStats(
  const inference::ModelStatistics& start_status,
  const inference::ModelStatistics& end_status) const
{
  ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ = end_status.inference_stats().success().count() -
                                start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
    end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
    end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ = end_status.inference_stats().compute_input().ns() -
                                        start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ = end_status.inference_stats().compute_infer().ns() -
                                        start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ = end_status.inference_stats().compute_output().ns() -
                                         start_status.inference_stats().compute_output().ns();

  return server_stats;
}
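// The end - start differences above restrict the server-side counters to the window between two
// ModelStatistics snapshots (see getServerSideStatus() below), i.e. approximately to the requests
// issued in between, assuming no other client is using the same model concurrently.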
inference::ModelStatistics TritonClient::getServerSideStatus() const
{
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    bool success = triton_utils::warnIfError(
      client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
      "getServerSideStatus(): unable to get model statistics");
    if (success) return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}
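// End-to-end usage sketch (hypothetical driver code; the per-tensor fill/read API of the
// TritonData wrappers held in input_/output_ is not part of this file and is only assumed here):
//
//   TritonClient client(pset);   // pset: fhicl::ParameterSet as sketched above
//   client.setBatchSize(n);
//   // ... fill the client's input tensors for n samples ...
//   client.evaluate();           // blocking gRPC Infer call, with retries via finish()
//   // ... read the selected outputs back via client.output() ...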