4 #include "cetlib_except/exception.h" 8 #include "grpc_client.h" 18 namespace nic = triton::client;
TritonClient::TritonClient(const fhicl::ParameterSet& params)
  : allowedTries_(params.get<unsigned>("allowedTries", 0))
  , serverURL_(params.get<std::string>("serverURL"))
  , verbose_(params.get<bool>("verbose"))
  , ssl_(params.get<bool>("ssl", false))
  , sslRootCertificates_(params.get<std::string>("sslRootCertificates", ""))
  , sslPrivateKey_(params.get<std::string>("sslPrivateKey", ""))
  , sslCertificateChain_(params.get<std::string>("sslCertificateChain", ""))
  , options_(params.get<std::string>("modelName"))
{
  // Connect to the inference server, with or without SSL.
  if (ssl_) {
    nic::SslOptions ssl_options = nic::SslOptions();
    ssl_options.root_certificates = sslRootCertificates_;
    ssl_options.private_key = sslPrivateKey_;
    ssl_options.certificate_chain = sslCertificateChain_;
    triton_utils::throwIfError(
      nic::InferenceServerGrpcClient::Create(&client_, serverURL_, false, true, ssl_options),
      "TritonClient(): unable to create inference context");
  }
  else {
    triton_utils::throwIfError(
      nic::InferenceServerGrpcClient::Create(&client_, serverURL_, false),
      "TritonClient(): unable to create inference context");
  }
  options_.model_version_ = params.get<std::string>("modelVersion");
  // The Triton client timeout is in microseconds; the "timeout" parameter is in seconds.
  options_.client_timeout_ = params.get<unsigned>("timeout") * 1e6;
  // Query the model configuration, needed for the batching properties.
  inference::ModelConfigResponse modelConfigResponse;
  triton_utils::throwIfError(
    client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
    "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

  // Triton reports max_batch_size = 0 for models that do not support batching;
  // remember that case and clamp the local maximum to 1.
  maxBatchSize_ = modelConfig.max_batch_size();
  noBatch_ = (maxBatchSize_ == 0);
  maxBatchSize_ = std::max(1u, maxBatchSize_);
  // Query the model metadata, which describes the inputs and outputs.
  inference::ModelMetadataResponse modelMetadata;
  triton_utils::throwIfError(
    client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
    "TritonClient(): unable to get model metadata");

  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();
  // Report all model errors at once.
  std::ostringstream msg;
  if (nicInputs.empty()) msg << "Model on server appears malformed (zero inputs)\n";
  if (nicOutputs.empty()) msg << "Model on server appears malformed (zero outputs)\n";

  const std::string msg_str = msg.str();
  if (!msg_str.empty()) throw cet::exception("ModelErrors") << msg_str;
  // Set up the input map, and collect a printout of the inputs if requested.
  std::ostringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  inputsTriton_.reserve(nicInputs.size());
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.try_emplace(iname, iname, nicInput, noBatch_);
    auto& curr_input = curr_itr->second;
    inputsTriton_.push_back(curr_input.data());
    if (verbose_) {
      io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }
  const auto& v_outputs = params.get<std::vector<std::string>>("outputs");
  std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());
  // Set up the output map, keeping only the requested outputs (if any were specified).
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  outputsTriton_.reserve(nicOutputs.size());
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end()) continue;
    auto [curr_itr, success] = output_.try_emplace(oname, oname, nicOutput, noBatch_);
    auto& curr_output = curr_itr->second;
    outputsTriton_.push_back(curr_output.data());
    if (verbose_) {
      io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty()) s_outputs.erase(oname);
  }
  // Check that all requested outputs were found on the server.
  if (!s_outputs.empty())
    throw cet::exception("MissingOutput")
      << "Some requested outputs were not available on the server: "
      << triton_utils::printColl(s_outputs);

  // Propagate the batch size to the inputs and outputs.
  setBatchSize(1);

  // Print a summary of the model.
  std::ostringstream model_msg;
  model_msg << "Model name: " << options_.model_name_ << "\n"
            << "Model version: " << options_.model_version_ << "\n"
            << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
  MF_LOG_INFO("TritonClient") << model_msg.str() << io_msg.str();
}

bool TritonClient::setBatchSize(unsigned bsize)
{
  if (bsize > maxBatchSize_) {
    MF_LOG_WARNING("TritonClient")
      << "Requested batch size " << bsize << " exceeds server-specified max batch size "
      << maxBatchSize_ << ". Batch size will remain as " << batchSize_;
    return false;
  }
  else {
    batchSize_ = bsize;
    // Propagate the new batch size to every input and output.
    for (auto& element : input_) {
      element.second.setBatchSize(bsize);
    }
    for (auto& element : output_) {
      element.second.setBatchSize(bsize);
    }
    return true;
  }
}

void TritonClient::reset()
{
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

bool TritonClient::getResults(std::shared_ptr<nic::InferResult> results)
{
  for (auto& [oname, output] : output_) {
    // For outputs with variable dimensions, the concrete shape comes from the result.
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      bool status = triton_utils::warnIfError(
        results->Shape(oname, &tmp_shape),
        "getResults(): unable to get output shape for " + oname);
      if (!status) return status;
      output.setShape(tmp_shape, false);
    }
    // Keep the result alive for as long as the output refers to it.
    output.setResult(results);
  }

  return true;
}

void TritonClient::evaluate()
{
  // ...

  // Blocking inference call.
  auto t1 = std::chrono::steady_clock::now();
  nic::InferResult* results;
  nic::Headers http_headers;
  grpc_compression_algorithm compression_algorithm =
    grpc_compression_algorithm::GRPC_COMPRESS_NONE;
  bool status = triton_utils::warnIfError(
    client_->Infer(
      &results, options_, inputsTriton_, outputsTriton_, http_headers, compression_algorithm),
    "evaluate(): unable to run and/or get result");
  if (!status) {
    finish(false);
    return;
  }

  auto t2 = std::chrono::steady_clock::now();
  // ...

  std::shared_ptr<nic::InferResult> results_ptr(results);
  // ...
}

void TritonClient::finish(bool success)
{
  // Retries are only allowed if no exception was raised.
  if (!success) {
    ++tries_;
    // If the maximum number of tries has not been exceeded, call evaluate() again.
    if (tries_ < allowedTries_) {
      evaluate();
      return;
    }
    // Otherwise give up; the exception category below is an assumption.
    throw cet::exception("TritonClient")
      << "call failed after max " << tries_ << " tries" << std::endl;
  }
  // ...
}

void TritonClient::reportServerSideStats(const ServerSideStats& stats) const
{
  std::ostringstream msg;

  const uint64_t count = stats.success_count_;
  msg << "  Inference count: " << stats.inference_count_ << "\n";
  msg << "  Execution count: " << stats.execution_count_ << "\n";
  msg << "  Successful request count: " << count << "\n";

  if (count > 0) {
    // Convert cumulative nanosecond totals into per-request averages in microseconds.
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us =
      compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead = (cumm_avg_us > queue_avg_us + compute_avg_us) ?
                                (cumm_avg_us - queue_avg_us - compute_avg_us) :
                                0;

    msg << "  Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << "  (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  MF_LOG_INFO("TritonClient") << msg.str();
}

TritonClient::ServerSideStats TritonClient::summarizeServerStats(
  const inference::ModelStatistics& start_status,
  const inference::ModelStatistics& end_status) const
{
  ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ = end_status.inference_stats().success().count() -
                                start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
    end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
    end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ = end_status.inference_stats().compute_input().ns() -
                                        start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ = end_status.inference_stats().compute_infer().ns() -
                                        start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ = end_status.inference_stats().compute_output().ns() -
                                         start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const
{
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    bool success = triton_utils::warnIfError(
      client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
      "getServerSideStatus(): unable to get model statistics");
    if (success) return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}
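
A minimal usage sketch, assuming a fhicl::ParameterSet "pset" configured with the
parameters listed after the constructor, and assuming the member functions shown above
are accessible to the calling code; filling the inputs and unpacking the outputs goes
through the TritonData interface, which is defined elsewhere:

  TritonClient client(pset);  // connect to the server and build the input/output maps
  client.setBatchSize(1);     // propagate the batch size to all inputs and outputs
  // ... fill the input data for this event ...
  client.evaluate();          // blocking inference call; retries up to allowedTries times
  // ... read the results back from client.output() ...
  client.reset();             // clear per-event input/output state before the next call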