LArSoft v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
lar_dl_content::LArDLHelper Class Reference

LArDLHelper class.

#include "LArDLHelper.h"

Public Types

typedef torch::jit::script::Module TorchModel
typedef torch::Tensor TorchInput
typedef std::vector<torch::jit::IValue> TorchInputVector
typedef at::Tensor TorchOutput

Static Public Member Functions

static pandora::StatusCode LoadModel (const std::string &filename, TorchModel &model)
 Loads a TorchScript model from file.

static void InitialiseInput (const at::IntArrayRef dimensions, TorchInput &tensor)
 Creates a zero-initialised torch input tensor.

static void Forward (TorchModel &model, const TorchInputVector &input, TorchOutput &output)
 Runs a forward pass of a deep learning model.

Detailed Description

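LArDLHelper class: static helper functions for loading TorchScript models, creating input tensors and running inference, used by the deep-learning algorithms in lar_dl_content (DlHitTrackShowerIdAlgorithm and DlVertexingAlgorithm).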

Definition at line 22 of file LArDLHelper.h.

Member Typedef Documentation

typedef torch::Tensor lar_dl_content::LArDLHelper::TorchInput

Definition at line 26 of file LArDLHelper.h.

typedef std::vector<torch::jit::IValue> lar_dl_content::LArDLHelper::TorchInputVector

Definition at line 27 of file LArDLHelper.h.

typedef torch::jit::script::Module lar_dl_content::LArDLHelper::TorchModel

Definition at line 25 of file LArDLHelper.h.

typedef at::Tensor lar_dl_content::LArDLHelper::TorchOutput

Definition at line 28 of file LArDLHelper.h.

Member Function Documentation

static void lar_dl_content::LArDLHelper::Forward(TorchModel &model, const TorchInputVector &input, TorchOutput &output)

Runs a forward pass of a deep learning model over the given inputs and stores the result, converted to a tensor, in output.

Parameters
model: the model to run
input: the input values to run the model over
output: the tensor in which to store the output

Definition at line 41 of file LArDLHelper.cc.

Referenced by lar_dl_content::DlHitTrackShowerIdAlgorithm::Infer(), and lar_dl_content::DlVertexingAlgorithm::Infer().

{
    output = model.forward(input).toTensor();
}
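Forward does not catch exceptions: if the model returns something other than a single tensor (a tuple, for example), toTensor() throws and the error propagates to the caller. The sketch below illustrates this input/output flow in standard C++ without LibTorch. Tensor, IValue, ToTensor and DoublingModel are hypothetical stand-ins (a std::variant in place of torch::jit::IValue, a flat float vector in place of a tensor); only the shape of Forward mirrors the helper above.

```cpp
#include <stdexcept>
#include <variant>
#include <vector>

// Hypothetical stand-ins: a "tensor" is a flat vector of floats, and the IValue-like
// type holds either a tensor or a double (torch::jit::IValue supports many more kinds).
using Tensor = std::vector<float>;
using IValue = std::variant<Tensor, double>;

// Mirrors IValue::toTensor(): succeed only if the value actually holds a tensor.
Tensor ToTensor(const IValue &value)
{
    if (const Tensor *t = std::get_if<Tensor>(&value))
        return *t;
    throw std::runtime_error("Expected a tensor output");
}

// Hypothetical model: doubles every element of its first (tensor) input.
struct DoublingModel
{
    IValue forward(const std::vector<IValue> &inputs) const
    {
        Tensor out = ToTensor(inputs.at(0));
        for (float &x : out)
            x *= 2.f;
        return out;
    }
};

// Same shape as LArDLHelper::Forward: run the model, then unwrap its single tensor output.
// Any exception from forward() or ToTensor() is left for the caller to handle.
template <typename Model>
void Forward(Model &model, const std::vector<IValue> &input, Tensor &output)
{
    output = ToTensor(model.forward(input));
}
```

With the real helper, the input vector is built the same way: push the prepared TorchInput into a TorchInputVector and pass it to Forward together with an empty TorchOutput.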

static void lar_dl_content::LArDLHelper::InitialiseInput(const at::IntArrayRef dimensions, TorchInput &tensor)

Creates a zero-initialised torch input tensor with the given dimensions.

Parameters
dimensions: the size of each dimension of the tensor, passed as a brace-enclosed list such as {a, b, c, d}
tensor: the tensor to be initialised (filled with zeros)

Definition at line 34 of file LArDLHelper.cc.

Referenced by lar_dl_content::DlHitTrackShowerIdAlgorithm::Infer(), and lar_dl_content::DlVertexingAlgorithm::MakeNetworkInputFromHits().

{
    tensor = torch::zeros(dimensions);
}
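torch::zeros(dimensions) allocates one element per index combination (the product of the dimensions), all set to zero; with default options the element type is 32-bit float and the storage is contiguous in row-major order. at::IntArrayRef is a non-owning view of int64_t values, which is why a braced list can be passed directly. Callers such as MakeNetworkInputFromHits presumably then fill selected elements before calling Forward. The following sketch models that layout without LibTorch; ZeroTensor, InitialiseZeros and FlatIndex are hypothetical names, and reading a four-dimensional shape as batch, channel, height, width (NCHW) is an assumption for illustration, not something this helper enforces.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a zero-initialised tensor: a shape plus a contiguous,
// row-major buffer of floats (the default layout and dtype of torch::zeros).
struct ZeroTensor
{
    std::vector<std::int64_t> shape;
    std::vector<float> data;
};

// Mirrors InitialiseInput: allocate prod(dimensions) elements, all set to zero.
ZeroTensor InitialiseZeros(const std::vector<std::int64_t> &dimensions)
{
    std::size_t count = 1;
    for (const std::int64_t d : dimensions)
    {
        assert(d >= 0);
        count *= static_cast<std::size_t>(d);
    }
    return ZeroTensor{dimensions, std::vector<float>(count, 0.f)};
}

// Row-major flat offset of a multi-dimensional index, e.g. (n, c, h, w) for an NCHW shape.
std::size_t FlatIndex(const ZeroTensor &t, const std::vector<std::int64_t> &index)
{
    assert(index.size() == t.shape.size());
    std::size_t offset = 0;
    for (std::size_t i = 0; i < index.size(); ++i)
    {
        assert(index[i] >= 0 && index[i] < t.shape[i]);
        offset = offset * static_cast<std::size_t>(t.shape[i]) + static_cast<std::size_t>(index[i]);
    }
    return offset;
}
```

For example, InitialiseZeros({1, 3, 4, 4}) holds 48 zeros, and the element at index (0, 2, 1, 3) lives at flat offset 2*16 + 1*4 + 3 = 39.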

static pandora::StatusCode lar_dl_content::LArDLHelper::LoadModel(const std::string &filename, TorchModel &model)

Loads a TorchScript model from file.

Parameters
filename: the name of the TorchScript model file to load
model: the TorchModel in which to store the loaded model
Returns
STATUS_CODE_SUCCESS if the model was loaded successfully, STATUS_CODE_FAILURE otherwise (the exception message is printed to std::cout).

Definition at line 16 of file LArDLHelper.cc.


Referenced by lar_dl_content::DlHitTrackShowerIdAlgorithm::ReadSettings(), and lar_dl_content::DlVertexingAlgorithm::ReadSettings().

{
    try
    {
        model = torch::jit::load(filename);
        std::cout << "Loaded the TorchScript model \'" << filename << "\'" << std::endl;
    }
    catch (const std::exception &e)
    {
        std::cout << "Error loading the TorchScript model \'" << filename << "\':\n" << e.what() << std::endl;
        return STATUS_CODE_FAILURE;
    }

    return STATUS_CODE_SUCCESS;
}
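LoadModel converts any std::exception thrown by torch::jit::load into STATUS_CODE_FAILURE, so callers only need to check the returned status code. The sketch below reproduces that exception-to-status pattern in standard C++. StatusCode here is a hypothetical two-value stand-in for pandora::StatusCode, and FakeLoad is a hypothetical loader used in place of torch::jit::load.

```cpp
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for pandora::StatusCode (only the two values LoadModel returns).
enum class StatusCode { Success, Failure };

// Same pattern as LArDLHelper::LoadModel: call a loader that may throw, report the
// outcome, and translate any std::exception into a failure status. 'Loader' is any
// callable taking the filename and returning the loaded object (an assumption).
template <typename Model, typename Loader>
StatusCode LoadWithStatus(const std::string &filename, Model &model, Loader loader)
{
    try
    {
        model = loader(filename);
        std::cout << "Loaded the model '" << filename << "'" << std::endl;
    }
    catch (const std::exception &e)
    {
        std::cout << "Error loading the model '" << filename << "':\n" << e.what() << std::endl;
        return StatusCode::Failure;
    }
    return StatusCode::Success;
}

// Hypothetical loader for illustration: accepts only filenames ending in ".pt".
std::string FakeLoad(const std::string &filename)
{
    if (filename.size() < 3 || filename.compare(filename.size() - 3, 3, ".pt") != 0)
        throw std::runtime_error("unsupported file: " + filename);
    return "model from " + filename;
}
```

Because the loader throws before the assignment, a failed load leaves the target model untouched. Since LoadModel is called from the algorithms' ReadSettings(), a failure status there can be returned directly, so a missing or unreadable model file is reported at configuration time rather than at inference.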

The documentation for this class was generated from the following files: