LArSoft  v10_04_05
Liquid Argon Software toolkit - https://larsoft.org/
DlVertexingBaseAlgorithm.cc
Go to the documentation of this file.
1 
9 #include <chrono>
10 #include <cmath>
11 
12 #include <torch/script.h>
13 #include <torch/torch.h>
14 
19 
21 
23 
24 using namespace pandora;
25 using namespace lar_content;
26 
27 namespace lar_dl_content
28 {
29 
// Default constructor.
// Sets the defaults for the configurable members: training mode off, pass 1, a 256 x 256 image,
// a drift step of 0.5 and the "dune_fd_hd" volume type. ReadSettings may override each of these
// from XML ("TrainingMode", "Pass", "ImageHeight", "ImageWidth", "DriftStep", "VolumeType").
// m_nClasses starts at 0 and is recomputed in ReadSettings from the number of distance thresholds.
30 DlVertexingBaseAlgorithm::DlVertexingBaseAlgorithm() :
31  m_trainingMode{false},
33  m_pass{1},
34  m_nClasses{0},
35  m_height{256},
36  m_width{256},
37  m_driftStep{0.5f},
38  m_volumeType{"dune_fd_hd"}
39 {
40 }
41 
42 //-----------------------------------------------------------------------------------------------------------------------------------------
43 
45 {
46 }
47 
48 //-----------------------------------------------------------------------------------------------------------------------------------------
49 
50 void DlVertexingBaseAlgorithm::GetHitRegion(const CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax) const
51 {
52  xMin = std::numeric_limits<float>::max();
53  xMax = -std::numeric_limits<float>::max();
54  zMin = std::numeric_limits<float>::max();
55  zMax = -std::numeric_limits<float>::max();
56  // Find the range of x and z values in the view
57  for (const CaloHit *pCaloHit : caloHitList)
58  {
59  const float x{pCaloHit->GetPositionVector().GetX()};
60  const float z{pCaloHit->GetPositionVector().GetZ()};
61  xMin = std::min(x, xMin);
62  xMax = std::max(x, xMax);
63  zMin = std::min(z, zMin);
64  zMax = std::max(z, zMax);
65  }
66 
67  if (caloHitList.empty())
68  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
69 
70  const HitType view{caloHitList.front()->GetHitType()};
71  const bool isU{view == TPC_VIEW_U}, isV{view == TPC_VIEW_V}, isW{view == TPC_VIEW_W};
72  if (!(isU || isV || isW))
73  throw StatusCodeException(STATUS_CODE_NOT_ALLOWED);
74 
75  // ATTN If wire w pitches vary between TPCs, exception will be raised in initialisation of lar pseudolayer plugin
76  const LArTPC *const pTPC(this->GetPandora().GetGeometry()->GetLArTPCMap().begin()->second);
77  const float pitch(view == TPC_VIEW_U ? pTPC->GetWirePitchU() : view == TPC_VIEW_V ? pTPC->GetWirePitchV() : pTPC->GetWirePitchW());
78 
79  if (m_pass > 1)
80  {
81  const VertexList *pVertexList(nullptr);
82  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*this, m_inputVertexListName, pVertexList));
83  if (pVertexList->empty())
84  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
85  const CartesianVector &vertex{pVertexList->front()->GetPosition()};
86 
87  // Get hit distribution left/right asymmetry
88  int nHitsLeft{0}, nHitsRight{0};
89  const double xVtx{vertex.GetX()};
90  for (const std::string &listname : m_caloHitListNames)
91  {
92  const CaloHitList *pCaloHitList(nullptr);
93  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*this, listname, pCaloHitList));
94  if (pCaloHitList->empty())
95  continue;
96  for (const CaloHit *const pCaloHit : *pCaloHitList)
97  {
98  const CartesianVector &pos{pCaloHit->GetPositionVector()};
99  if (pos.GetX() <= xVtx)
100  ++nHitsLeft;
101  else
102  ++nHitsRight;
103  }
104  }
105  const int nHitsTotal{nHitsLeft + nHitsRight};
106  if (nHitsTotal == 0)
107  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
108  const float xAsymmetry{nHitsLeft / static_cast<float>(nHitsTotal)};
109 
110  // Vertices
111  const LArTransformationPlugin *transform{this->GetPandora().GetPlugins()->GetLArTransformationPlugin()};
112  double zVtx{0.};
113  if (isW)
114  zVtx += transform->YZtoW(vertex.GetY(), vertex.GetZ());
115  else if (isV)
116  zVtx += transform->YZtoV(vertex.GetY(), vertex.GetZ());
117  else
118  zVtx = transform->YZtoU(vertex.GetY(), vertex.GetZ());
119 
120  // Get hit distribution upstream/downstream asymmetry
121  int nHitsUpstream{0}, nHitsDownstream{0};
122  for (const CaloHit *const pCaloHit : caloHitList)
123  {
124  const CartesianVector &pos{pCaloHit->GetPositionVector()};
125  if (pos.GetZ() <= zVtx)
126  ++nHitsUpstream;
127  else
128  ++nHitsDownstream;
129  }
130  const int nHitsViewTotal{nHitsUpstream + nHitsDownstream};
131  if (nHitsViewTotal == 0)
132  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
133  const float zAsymmetry{nHitsUpstream / static_cast<float>(nHitsViewTotal)};
134 
135  const float xSpan{m_driftStep * (m_width - 1)};
136  xMin = xVtx - xAsymmetry * xSpan;
137  xMax = xMin + (m_driftStep * (m_width - 1));
138  const float zSpan{pitch * (m_height - 1)};
139  zMin = zVtx - zAsymmetry * zSpan;
140  zMax = zMin + zSpan;
141  }
142 
143  // Avoid unreasonable rescaling of very small hit regions, pixels are assumed to be 0.5cm in x and wire pitch in z
144  // ATTN: Rescaling is to a size 1 pixel smaller than the intended image to ensure all hits fit within an imaged binned
145  // to be one pixel wider than this
146  const float xRange{xMax - xMin}, zRange{zMax - zMin};
147  const float minXSpan{m_driftStep * (m_width - 1)};
148  if (xRange < minXSpan)
149  {
150  const float padding{0.5f * (minXSpan - xRange)};
151  xMin -= padding;
152  xMax += padding;
153  }
154  const float minZSpan{pitch * (m_height - 1)};
155  if (zRange < minZSpan)
156  {
157  const float padding{0.5f * (minZSpan - zRange)};
158  zMin -= padding;
159  zMax += padding;
160  }
161 }
162 
163 //-----------------------------------------------------------------------------------------------------------------------------------------
164 
166  int &colOffset, int &rowOffset, int &width, int &height) const
167 {
    // Purpose: work out how large a canvas is needed, and how far the image must be shifted within it,
    // so that a square of half-width 'distance' around every pixel in pixelVector fits on the canvas.
    // Outputs: colOffset/rowOffset are the non-negative shifts to apply to column/row indices;
    // width/height are the canvas dimensions, never smaller than m_width/m_height.

    // Length of the image diagonal in pixels; thresholds are multiplied by this to give pixel distances
168  const double scaleFactor{std::sqrt(m_height * m_height + m_width * m_width)};
169  // output is a 1 x num_classes x height x width tensor
170  // we want the maximum value in the num_classes dimension (1) for every pixel
171  auto classes{torch::argmax(networkOutput, 1)};
172  // the argmax result is a 1 x height x width tensor where each element is a class id
173  auto classesAccessor{classes.accessor<int64_t, 3>()};
    // Extents start at 0, so the minima can only go negative and the maxima can only grow
174  int colOffsetMin{0}, colOffsetMax{0}, rowOffsetMin{0}, rowOffsetMax{0};
175  for (const auto &[row, col] : pixelVector)
176  {
177  const auto cls{classesAccessor[0][row][col]};
    // NOTE(review): unchecked index - assumes every predicted class id is < m_thresholds.size(); confirm
178  const double threshold{m_thresholds[cls]};
    // Only thresholds strictly between 0 and 1 contribute to the canvas extent
179  if (threshold > 0. && threshold < 1.)
180  {
    // NOTE(review): std::ceil already yields an integral value, so the std::round looks redundant
181  const int distance = static_cast<int>(std::round(std::ceil(scaleFactor * threshold)));
    // Grow the extents to include [row - distance, row + distance] x [col - distance, col + distance]
182  if ((row - distance) < rowOffsetMin)
183  rowOffsetMin = row - distance;
184  if ((row + distance) > rowOffsetMax)
185  rowOffsetMax = row + distance;
186  if ((col - distance) < colOffsetMin)
187  colOffsetMin = col - distance;
188  if ((col + distance) > colOffsetMax)
189  colOffsetMax = col + distance;
190  }
191  }
    // A negative minimum extent becomes a positive offset, so that shifted indices start at 0
192  colOffset = colOffsetMin < 0 ? -colOffsetMin : 0;
193  rowOffset = rowOffsetMin < 0 ? -rowOffsetMin : 0;
    // Canvas must hold the largest shifted index and is never smaller than the network image
194  width = std::max(colOffsetMax + colOffset + 1, m_width);
195  height = std::max(rowOffsetMax + rowOffset + 1, m_height);
196 }
197 
198 //-----------------------------------------------------------------------------------------------------------------------------------------
199 
200 StatusCode DlVertexingBaseAlgorithm::ReadSettings(const TiXmlHandle xmlHandle)
201 {
202  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "TrainingMode", m_trainingMode));
203  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "Pass", m_pass));
204  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "ImageHeight", m_height));
205  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "ImageWidth", m_width));
206  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "DriftStep", m_driftStep));
207  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadVectorOfValues(xmlHandle, "DistanceThresholds", m_thresholds));
208  m_nClasses = m_thresholds.size() - 1;
209  if (m_pass > 1)
210  {
211  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle, "InputVertexListName", m_inputVertexListName));
212  }
213 
214  if (m_trainingMode)
215  {
216  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle, "TrainingOutputFileName", m_trainingOutputFile));
217  }
218  else
219  {
220  std::string modelName;
221  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle, "ModelFileNameU", modelName));
222  modelName = LArFileHelper::FindFileInPath(modelName, "FW_SEARCH_PATH");
223  LArDLHelper::LoadModel(modelName, m_modelU);
224  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle, "ModelFileNameV", modelName));
225  modelName = LArFileHelper::FindFileInPath(modelName, "FW_SEARCH_PATH");
226  LArDLHelper::LoadModel(modelName, m_modelV);
227  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle, "ModelFileNameW", modelName));
228  modelName = LArFileHelper::FindFileInPath(modelName, "FW_SEARCH_PATH");
229  LArDLHelper::LoadModel(modelName, m_modelW);
230  PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle, "OutputVertexListName", m_outputVertexListName));
231  }
232 
233  PANDORA_RETURN_RESULT_IF_AND_IF(
234  STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(xmlHandle, "CaloHitListNames", m_caloHitListNames));
235  PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle, "VolumeType", m_volumeType));
236 
237  return STATUS_CODE_SUCCESS;
238 }
239 
240 } // namespace lar_dl_content
Float_t x
Definition: compare.C:6
int m_pass
The pass of the train/infer step.
Header file for the lar deep learning helper class.
pandora::StatusCode ReadSettings(const pandora::TiXmlHandle xmlHandle)
LArDLHelper::TorchModel m_modelU
The model for the U view.
Double_t z
Definition: plot.C:276
std::vector< double > m_thresholds
Distance class thresholds.
float m_driftStep
The size of a pixel in the drift direction in cm (most relevant in pass 2)
std::string m_trainingOutputFile
Output file name for training examples.
std::string m_outputVertexListName
Output vertex list name.
Header file for the geometry helper class.
Int_t col[ntarg]
Definition: Style.C:29
pandora::StringVector m_caloHitListNames
Names of input calo hit lists.
Header file for the file helper class.
std::string m_inputVertexListName
Input vertex list name if 2nd pass.
void GetHitRegion(const pandora::CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax) const
void GetCanvasParameters(const LArDLHelper::TorchOutput &networkOutput, const PixelVector &pixelVector, int &columnOffset, int &rowOffset, int &width, int &height) const
Determines the parameters of the canvas for extracting the vertex location. The network predicts the ...
Header file for the vertex helper class.
LArDLHelper::TorchModel m_modelW
The model for the W view.
static pandora::StatusCode LoadModel(const std::string &filename, TorchModel &model)
Loads a deep learning model.
Definition: LArDLHelper.cc:16
HitType
Definition: HitType.h:12
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:69
int m_nClasses
The number of distance classes.
second_as<> second
Type of time stored in seconds, in double precision.
Definition: spacetime.h:82
std::list< Vertex > VertexList
Definition: DCEL.h:169
std::string m_volumeType
The name of the fiducial volume type for the monitoring output.
LArDLHelper::TorchModel m_modelV
The model for the V view.
vertex reconstruction