12 #include <torch/script.h> 13 #include <torch/torch.h> 30 DlVertexingBaseAlgorithm::DlVertexingBaseAlgorithm() :
31 m_trainingMode{
false},
52 xMin = std::numeric_limits<float>::max();
53 xMax = -std::numeric_limits<float>::max();
54 zMin = std::numeric_limits<float>::max();
55 zMax = -std::numeric_limits<float>::max();
57 for (
const CaloHit *pCaloHit : caloHitList)
59 const float x{pCaloHit->GetPositionVector().GetX()};
60 const float z{pCaloHit->GetPositionVector().GetZ()};
61 xMin = std::min(
x, xMin);
62 xMax = std::max(
x, xMax);
63 zMin = std::min(
z, zMin);
64 zMax = std::max(
z, zMax);
67 if (caloHitList.empty())
68 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
70 const HitType view{caloHitList.front()->GetHitType()};
71 const bool isU{view == TPC_VIEW_U}, isV{view == TPC_VIEW_V}, isW{view == TPC_VIEW_W};
72 if (!(isU || isV || isW))
73 throw StatusCodeException(STATUS_CODE_NOT_ALLOWED);
76 const LArTPC *
const pTPC(this->GetPandora().GetGeometry()->GetLArTPCMap().
begin()->
second);
77 const float pitch(view == TPC_VIEW_U ? pTPC->GetWirePitchU() : view == TPC_VIEW_V ? pTPC->GetWirePitchV() : pTPC->GetWirePitchW());
82 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*
this,
m_inputVertexListName, pVertexList));
83 if (pVertexList->empty())
84 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
85 const CartesianVector &
vertex{pVertexList->front()->GetPosition()};
88 int nHitsLeft{0}, nHitsRight{0};
89 const double xVtx{
vertex.GetX()};
92 const CaloHitList *pCaloHitList(
nullptr);
93 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*
this, listname, pCaloHitList));
94 if (pCaloHitList->empty())
96 for (
const CaloHit *
const pCaloHit : *pCaloHitList)
98 const CartesianVector &pos{pCaloHit->GetPositionVector()};
99 if (pos.GetX() <= xVtx)
105 const int nHitsTotal{nHitsLeft + nHitsRight};
107 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
108 const float xAsymmetry{nHitsLeft /
static_cast<float>(nHitsTotal)};
111 const LArTransformationPlugin *transform{this->GetPandora().GetPlugins()->GetLArTransformationPlugin()};
121 int nHitsUpstream{0}, nHitsDownstream{0};
122 for (
const CaloHit *
const pCaloHit : caloHitList)
124 const CartesianVector &pos{pCaloHit->GetPositionVector()};
125 if (pos.GetZ() <= zVtx)
130 const int nHitsViewTotal{nHitsUpstream + nHitsDownstream};
131 if (nHitsViewTotal == 0)
132 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
133 const float zAsymmetry{nHitsUpstream /
static_cast<float>(nHitsViewTotal)};
136 xMin = xVtx - xAsymmetry * xSpan;
138 const float zSpan{pitch * (
m_height - 1)};
139 zMin = zVtx - zAsymmetry * zSpan;
146 const float xRange{xMax - xMin}, zRange{zMax - zMin};
148 if (xRange < minXSpan)
150 const float padding{0.5f * (minXSpan - xRange)};
154 const float minZSpan{pitch * (
m_height - 1)};
155 if (zRange < minZSpan)
157 const float padding{0.5f * (minZSpan - zRange)};
166 int &colOffset,
int &rowOffset,
int &width,
int &height)
const 171 auto classes{torch::argmax(networkOutput, 1)};
173 auto classesAccessor{classes.accessor<int64_t, 3>()};
174 int colOffsetMin{0}, colOffsetMax{0}, rowOffsetMin{0}, rowOffsetMax{0};
175 for (
const auto &[row,
col] : pixelVector)
177 const auto cls{classesAccessor[0][row][
col]};
179 if (threshold > 0. && threshold < 1.)
181 const int distance =
static_cast<int>(std::round(std::ceil(scaleFactor * threshold)));
182 if ((row - distance) < rowOffsetMin)
183 rowOffsetMin = row - distance;
184 if ((row + distance) > rowOffsetMax)
185 rowOffsetMax = row + distance;
186 if ((
col - distance) < colOffsetMin)
187 colOffsetMin =
col - distance;
188 if ((
col + distance) > colOffsetMax)
189 colOffsetMax =
col + distance;
192 colOffset = colOffsetMin < 0 ? -colOffsetMin : 0;
193 rowOffset = rowOffsetMin < 0 ? -rowOffsetMin : 0;
194 width = std::max(colOffsetMax + colOffset + 1,
m_width);
195 height = std::max(rowOffsetMax + rowOffset + 1,
m_height);
202 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"TrainingMode",
m_trainingMode));
203 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"Pass",
m_pass));
204 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ImageHeight",
m_height));
205 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ImageWidth",
m_width));
206 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"DriftStep",
m_driftStep));
207 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadVectorOfValues(xmlHandle,
"DistanceThresholds",
m_thresholds));
211 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"InputVertexListName",
m_inputVertexListName));
216 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"TrainingOutputFileName",
m_trainingOutputFile));
220 std::string modelName;
221 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameU", modelName));
222 modelName = LArFileHelper::FindFileInPath(modelName,
"FW_SEARCH_PATH");
224 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameV", modelName));
225 modelName = LArFileHelper::FindFileInPath(modelName,
"FW_SEARCH_PATH");
227 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameW", modelName));
228 modelName = LArFileHelper::FindFileInPath(modelName,
"FW_SEARCH_PATH");
230 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"OutputVertexListName",
m_outputVertexListName));
233 PANDORA_RETURN_RESULT_IF_AND_IF(
234 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(xmlHandle,
"CaloHitListNames",
m_caloHitListNames));
235 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"VolumeType",
m_volumeType));
237 return STATUS_CODE_SUCCESS;
int m_pass
The pass of the train/infer step.
Header file for the lar deep learning helper class.
pandora::StatusCode ReadSettings(const pandora::TiXmlHandle xmlHandle)
~DlVertexingBaseAlgorithm()
LArDLHelper::TorchModel m_modelU
The model for the U view.
std::vector< double > m_thresholds
Distance class thresholds.
int m_width
The width of the images.
float m_driftStep
The size of a pixel in the drift direction in cm (most relevant in pass 2).
std::string m_trainingOutputFile
Output file name for training examples.
std::string m_outputVertexListName
Output vertex list name.
Header file for the geometry helper class.
std::vector< Pixel > PixelVector
pandora::StringVector m_caloHitListNames
Names of input calo hit lists.
Header file for the file helper class.
std::string m_inputVertexListName
Input vertex list name if 2nd pass.
void GetHitRegion(const pandora::CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax) const
void GetCanvasParameters(const LArDLHelper::TorchOutput &networkOutput, const PixelVector &pixelVector, int &columnOffset, int &rowOffset, int &width, int &height) const
Determines the parameters of the canvas for extracting the vertex location. The network predicts the ...
Header file for the vertex helper class.
LArDLHelper::TorchModel m_modelW
The model for the W view.
static pandora::StatusCode LoadModel(const std::string &filename, TorchModel &model)
Loads a deep learning model.
int m_height
The height of the images.
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
int m_nClasses
The number of distance classes.
second_as<> second
Type of time stored in seconds, in double precision.
std::list< Vertex > VertexList
std::string m_volumeType
The name of the fiducial volume type for the monitoring output.
LArDLHelper::TorchModel m_modelV
The model for the V view.
bool m_trainingMode
Training mode.