12 #include <torch/script.h> 13 #include <torch/torch.h> 15 #include "Pandora/Pandora.h" 31 DlSNSignalAlgorithm::DlSNSignalAlgorithm() :
32 m_trainingMode{
false},
62 catch (StatusCodeException
e)
64 std::cout <<
"SignalAssessmentAlgorithm: Unable to write to ROOT tree" << std::endl;
75 PANDORA_MONITORING_API(SetEveDisplayParameters(this->GetPandora(),
true, DETECTOR_VIEW_XZ, -1.
f, 1.
f, 1.
f));
88 return STATUS_CODE_SUCCESS;
94 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, this->
GetMCToHitsMap(mcToHitsMap));
95 if (mcToHitsMap.empty())
96 throw StatusCodeException(STATUS_CODE_NOT_ALLOWED);
99 std::map<int, float> wireMin, wireMax;
100 std::map<int, bool> viewCalculated;
101 float driftMin{std::numeric_limits<float>::max()}, driftMax{-std::numeric_limits<float>::max()};
104 const CaloHitList *pCaloHitList{
nullptr};
107 PandoraContentApi::GetList(*
this, listname, pCaloHitList);
109 catch (
const StatusCodeException &
e)
113 if (!pCaloHitList || pCaloHitList->empty())
116 HitType view{pCaloHitList->front()->GetHitType()};
117 float viewDriftMin{driftMin}, viewDriftMax{driftMax};
120 this->
GetHitRegion(*pCaloHitList, viewDriftMin, viewDriftMax, wireMin[view], wireMax[view]);
121 viewCalculated[view] =
true;
123 catch (
const StatusCodeException &e)
125 if (e.GetStatusCode() == STATUS_CODE_NOT_FOUND)
131 driftMin = std::min(viewDriftMin, driftMin);
132 driftMax = std::max(viewDriftMax, driftMax);
137 const CaloHitList *pCaloHitList{
nullptr};
138 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_INITIALIZED, !=, PandoraContentApi::GetList(*
this, listName, pCaloHitList));
142 std::cout <<
"ERR: Could not find full CaloHitList - DlSNSignalAlgorithm unable to proceed" << std::endl;
146 HitType view{pCaloHitList->front()->GetHitType()};
147 float viewDriftMin{driftMin}, viewDriftMax{driftMax};
148 if (!viewCalculated[view])
150 const LArTPC *
const pTPC(this->GetPandora().GetGeometry()->GetLArTPCMap().
begin()->
second);
151 const float pitch(view == TPC_VIEW_U ? pTPC->GetWirePitchU()
152 : view == TPC_VIEW_V ? pTPC->GetWirePitchV()
153 : pTPC->GetWirePitchW());
154 const float zSpan{pitch * (
m_height - 1)};
155 bool projected{
false};
158 if (viewCalculated[TPC_VIEW_W] && viewCalculated[TPC_VIEW_U])
160 const float z1{(wireMax[TPC_VIEW_W] - wireMin[TPC_VIEW_W]) * 0.5
f};
161 const float z2{(wireMax[TPC_VIEW_U] - wireMin[TPC_VIEW_U]) * 0.5
f};
162 const float m_projectedCoordinate{
static_cast<float>(this->GetPandora().GetPlugins()->GetLArTransformationPlugin()->WUtoV(z1, z2))};
163 wireMin[view] = m_projectedCoordinate - zSpan;
164 wireMax[view] = m_projectedCoordinate + zSpan;
168 if (viewCalculated[TPC_VIEW_W] && viewCalculated[TPC_VIEW_V])
170 const float z1{(wireMax[TPC_VIEW_W] - wireMin[TPC_VIEW_W]) * 0.5
f};
171 const float z2{(wireMax[TPC_VIEW_V] - wireMin[TPC_VIEW_V]) * 0.5
f};
173 const float m_projectedCoordinate{
static_cast<float>(this->GetPandora().GetPlugins()->GetLArTransformationPlugin()->VWtoU(z1, z2))};
174 wireMin[view] = m_projectedCoordinate - zSpan;
175 wireMax[view] = m_projectedCoordinate + zSpan;
179 if (viewCalculated[TPC_VIEW_U] && viewCalculated[TPC_VIEW_V])
181 const float z1{(wireMax[TPC_VIEW_U] - wireMin[TPC_VIEW_U]) * 0.5
f};
182 const float z2{(wireMax[TPC_VIEW_V] - wireMin[TPC_VIEW_V]) * 0.5
f};
183 const float m_projectedCoordinate{
static_cast<float>(this->GetPandora().GetPlugins()->GetLArTransformationPlugin()->UVtoW(z1, z2))};
184 wireMin[view] = m_projectedCoordinate - zSpan;
185 wireMax[view] = m_projectedCoordinate + zSpan;
193 this->
GetHitRegion(*pCaloHitList, viewDriftMin, viewDriftMax, wireMin[view], wireMax[view]);
196 catch (
const StatusCodeException &
e)
198 if (e.GetStatusCode() == STATUS_CODE_NOT_FOUND)
200 std::cout <<
"ERR: Could not calculate zoom region - DlSNSignalAlgorithm unable to proceed" << std::endl;
206 driftMin = std::min(viewDriftMin, driftMin);
207 driftMax = std::max(viewDriftMax, driftMax);
210 for (
const std::string &listname : m_caloHitListNames)
212 const CaloHitList *pCaloHitList(
nullptr);
213 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*
this, listname, pCaloHitList));
214 if (pCaloHitList->empty())
217 HitType view{pCaloHitList->front()->GetHitType()};
218 const bool isU{view == TPC_VIEW_U}, isV{view == TPC_VIEW_V}, isW{view == TPC_VIEW_W};
219 if (!(isU || isV || isW))
220 return STATUS_CODE_NOT_ALLOWED;
223 unsigned long nHits{0};
226 double xMin{driftMin}, xMax{driftMax}, zMin{wireMin[view]}, zMax{wireMax[view]};
229 featureVector.emplace_back(xMin);
230 featureVector.emplace_back(xMax);
231 featureVector.emplace_back(zMin);
232 featureVector.emplace_back(zMax);
234 for (
const CaloHit *pCaloHit : *pCaloHitList)
236 const float x{pCaloHit->GetPositionVector().GetX()},
z{pCaloHit->GetPositionVector().GetZ()}, adc{pCaloHit->GetMipEquivalentEnergy()};
238 if (
m_pass > 1 && (x < xMin || x > xMax || z < zMin || z > zMax))
240 featureVector.emplace_back(static_cast<double>(
x));
241 featureVector.emplace_back(static_cast<double>(
z));
242 featureVector.emplace_back(static_cast<double>(adc));
245 const MCParticle *pMainMCParticle(
nullptr);
248 pMainMCParticle = MCParticleHelper::GetMainMCParticle(pCaloHit);
250 catch (
const StatusCodeException &)
256 const MCParticle *
const pParentMCParticle(LArMCParticleHelper::GetParentMCParticle(pMainMCParticle));
258 if (LArMCParticleHelper::IsNeutrino(pParentMCParticle))
260 const int pdg{pMainMCParticle->GetParticleId()};
261 if (pdg == E_MINUS || pdg == PHOTON)
262 featureVector.emplace_back(pdg);
264 featureVector.emplace_back(0);
268 featureVector.emplace_back(0);
273 featureVector.emplace_back(0);
276 featureVector.insert(featureVector.begin() + 4,
static_cast<double>(nHits));
277 LArMvaHelper::ProduceTrainingExample(trainingFilename,
true, featureVector);
280 return STATUS_CODE_SUCCESS;
290 std::map<int, float> wireMin, wireMax;
291 std::map<int, bool> viewCalculated;
292 float driftMin{std::numeric_limits<float>::max()}, driftMax{-std::numeric_limits<float>::max()};
295 const CaloHitList *pCaloHitList{
nullptr};
296 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_INITIALIZED, !=, PandoraContentApi::GetList(*
this, listname, pCaloHitList));
304 HitType view{pCaloHitList->front()->GetHitType()};
305 float viewDriftMin{driftMin}, viewDriftMax{driftMax};
308 this->
GetHitRegion(*pCaloHitList, viewDriftMin, viewDriftMax, wireMin[view], wireMax[view]);
309 viewCalculated[view] =
true;
311 catch (
const StatusCodeException &
e)
313 if (e.GetStatusCode() == STATUS_CODE_NOT_FOUND)
318 driftMin = std::min(viewDriftMin, driftMin);
319 driftMax = std::max(viewDriftMax, driftMax);
324 const CaloHitList *pCaloHitList{
nullptr};
325 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_INITIALIZED, !=, PandoraContentApi::GetList(*
this, listName, pCaloHitList));
329 std::cout <<
"ERR: Could not find full CaloHitList - DlSNSignalAlgorithm unable to proceed" << std::endl;
333 HitType view{pCaloHitList->front()->GetHitType()};
334 float viewDriftMin{driftMin}, viewDriftMax{driftMax};
335 if (viewCalculated[view] !=
true)
337 const LArTPC *
const pTPC(this->GetPandora().GetGeometry()->GetLArTPCMap().
begin()->
second);
338 const float pitch(view == TPC_VIEW_U ? pTPC->GetWirePitchU()
339 : view == TPC_VIEW_V ? pTPC->GetWirePitchV()
340 : pTPC->GetWirePitchW());
341 const float zSpan{pitch * (
m_height - 1)};
342 bool projected{
false};
345 if (viewCalculated[TPC_VIEW_W] && viewCalculated[TPC_VIEW_U])
347 const float z1{(wireMax[TPC_VIEW_W] - wireMin[TPC_VIEW_W]) * 0.5
f};
348 const float z2{(wireMax[TPC_VIEW_U] - wireMin[TPC_VIEW_U]) * 0.5
f};
349 const float m_projectedCoordinate{
static_cast<float>(this->GetPandora().GetPlugins()->GetLArTransformationPlugin()->WUtoV(z1, z2))};
350 wireMin[view] = m_projectedCoordinate - zSpan;
351 wireMax[view] = m_projectedCoordinate + zSpan;
355 if (viewCalculated[TPC_VIEW_W] && viewCalculated[TPC_VIEW_V])
357 const float z1{(wireMax[TPC_VIEW_W] - wireMin[TPC_VIEW_W]) * 0.5
f};
358 const float z2{(wireMax[TPC_VIEW_V] - wireMin[TPC_VIEW_V]) * 0.5
f};
360 const float m_projectedCoordinate{
static_cast<float>(this->GetPandora().GetPlugins()->GetLArTransformationPlugin()->VWtoU(z1, z2))};
361 wireMin[view] = m_projectedCoordinate - zSpan;
362 wireMax[view] = m_projectedCoordinate + zSpan;
366 if (viewCalculated[TPC_VIEW_U] && viewCalculated[TPC_VIEW_V])
368 const float z1{(wireMax[TPC_VIEW_U] - wireMin[TPC_VIEW_U]) * 0.5
f};
369 const float z2{(wireMax[TPC_VIEW_V] - wireMin[TPC_VIEW_V]) * 0.5
f};
370 const float m_projectedCoordinate{
static_cast<float>(this->GetPandora().GetPlugins()->GetLArTransformationPlugin()->UVtoW(z1, z2))};
371 wireMin[view] = m_projectedCoordinate - zSpan;
372 wireMax[view] = m_projectedCoordinate + zSpan;
381 this->
GetHitRegion(*pCaloHitList, viewDriftMin, viewDriftMax, wireMin[view], wireMax[view]);
384 catch (
const StatusCodeException &
e)
386 if (e.GetStatusCode() == STATUS_CODE_NOT_FOUND)
388 std::cout <<
"ERR: Could not calculate zoom region - DlSNSignalAlgorithm unable to proceed" << std::endl;
394 driftMin = std::min(viewDriftMin, driftMin);
395 driftMax = std::max(viewDriftMax, driftMax);
398 CaloHitList signalCandidatesU, signalCandidatesV, signalCandidatesW, signalCandidates2D, backgroundCaloHitList, photonCandidatesU,
399 photonCandidatesV, photonCandidatesW, electronCandidatesU, electronCandidatesV, electronCandidatesW;
400 for (
const std::string &listName : m_caloHitListNames)
402 const CaloHitList *pCaloHitList{
nullptr};
403 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_INITIALIZED, !=, PandoraContentApi::GetList(*
this, listName, pCaloHitList));
405 if (!pCaloHitList || pCaloHitList->empty())
408 HitType view{pCaloHitList->front()->GetHitType()};
409 const bool isU{view == TPC_VIEW_U}, isV{view == TPC_VIEW_V}, isW{view == TPC_VIEW_W};
410 if (!isU && !isV && !isW)
411 return STATUS_CODE_NOT_ALLOWED;
415 this->
MakeNetworkInputFromHits(*pCaloHitList, view, driftMin, driftMax, wireMin[view], wireMax[view], input, pixelMap);
419 inputs.push_back(input);
429 auto classes{torch::argmax(output, 1)};
431 auto classesAccessor{classes.accessor<int64_t, 3>()};
432 std::map<int, bool> haveSeenMap;
434 for (
const auto &[pCaloHit, pixel] : pixelMap)
437 const auto cls{classesAccessor[0][pixel.second][pixel.first]};
442 std::cout <<
"*Electron identified*" << std::endl;
446 std::cout <<
"*Photon identified*" << std::endl;
450 std::cout <<
"*Signal Pixel identified*" << std::endl;
456 signalCandidates2D.emplace_back(pCaloHit);
460 signalCandidatesU.emplace_back(pCaloHit);
461 photonCandidatesU.emplace_back(pCaloHit);
465 signalCandidatesV.emplace_back(pCaloHit);
466 photonCandidatesV.emplace_back(pCaloHit);
470 signalCandidatesW.emplace_back(pCaloHit);
471 photonCandidatesW.emplace_back(pCaloHit);
476 signalCandidates2D.emplace_back(pCaloHit);
480 signalCandidatesU.emplace_back(pCaloHit);
481 electronCandidatesU.emplace_back(pCaloHit);
485 signalCandidatesV.emplace_back(pCaloHit);
486 electronCandidatesV.emplace_back(pCaloHit);
490 signalCandidatesW.emplace_back(pCaloHit);
491 electronCandidatesW.emplace_back(pCaloHit);
496 backgroundCaloHitList.emplace_back(pCaloHit);
503 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &signalCandidatesU,
"candidate signal U", BLUE));
506 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &photonCandidatesU,
"candidate photon U", BLACK));
507 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &electronCandidatesU,
"candidate electron U", RED));
513 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &signalCandidatesV,
"candidate signal V", BLUE));
516 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &photonCandidatesV,
"candidate photon V", BLACK));
517 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &electronCandidatesV,
"candidate electron V", RED));
523 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &signalCandidatesW,
"candidate signal W", BLUE));
526 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &photonCandidatesW,
"candidate photon W", BLACK));
527 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &electronCandidatesW,
"candidate electron W", RED));
531 PANDORA_MONITORING_API(ViewEvent(this->GetPandora()));
536 PANDORA_MONITORING_API(SetEveDisplayParameters(this->GetPandora(),
true, DETECTOR_VIEW_XZ, -1.
f, 1.
f, 1.
f));
539 for (
const CaloHit *pCaloHit : *pCaloHitList)
541 const float x{pCaloHit->GetPositionVector().GetX()},
z{pCaloHit->GetPositionVector().GetZ()};
542 const MCParticle *pMainMCParticle(
nullptr);
545 pMainMCParticle = MCParticleHelper::GetMainMCParticle(pCaloHit);
547 catch (
const StatusCodeException &)
553 const MCParticle *
const pParentMCParticle(LArMCParticleHelper::GetParentMCParticle(pMainMCParticle));
555 if (LArMCParticleHelper::IsNeutrino(pParentMCParticle))
557 const CartesianVector signalHit(
x, 0.
f,
z);
558 const int pdg{pMainMCParticle->GetParticleId()};
559 std::string electronLabel{
"True electron "}, photonLabel{
"True photon "};
563 PANDORA_MONITORING_API(AddMarkerToVisualization(this->GetPandora(), &signalHit, label, YELLOW, 2));
568 PANDORA_MONITORING_API(AddMarkerToVisualization(this->GetPandora(), &signalHit, label, GREEN, 2));
574 catch (StatusCodeException &
e)
576 std::cerr <<
"DlSNSignalAlgorithm: Warning. Couldn't find signal hits." << std::endl;
578 PANDORA_MONITORING_API(ViewEvent(this->GetPandora()));
584 std::cout <<
"Printing U view candidate length: " << signalCandidatesU.size() << std::endl;
585 std::cout <<
"Printing V view candidate length: " << signalCandidatesV.size() << std::endl;
586 std::cout <<
"Printing W view candidate length: " << signalCandidatesW.size() << std::endl;
587 std::cout <<
"Printing 2D view candidate length: " << signalCandidates2D.size() << std::endl;
588 std::cout <<
"Printing background CaloHitList length: " << backgroundCaloHitList.size() << std::endl;
589 std::cout <<
"Printing New CaloHitList Names: " << std::endl;
594 if (signalCandidatesU.empty() || signalCandidatesV.empty() || signalCandidatesW.empty() || signalCandidates2D.empty())
596 std::cout <<
"Error: A CaloHitList is empty" << std::endl;
598 if (!signalCandidatesU.empty())
599 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, signalCandidatesU,
m_signalListNameU));
601 if (!signalCandidatesV.empty())
602 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, signalCandidatesV,
m_signalListNameV));
604 if (!signalCandidatesW.empty())
605 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, signalCandidatesW,
m_signalListNameW));
607 if (!signalCandidates2D.empty())
608 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, signalCandidates2D,
m_signalListName2D));
610 if (!backgroundCaloHitList.empty())
611 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, backgroundCaloHitList,
m_backgroundListName));
613 return STATUS_CODE_SUCCESS;
620 CaloHitList signalCandidatesU, signalCandidatesV, signalCandidatesW, signalCandidates2D, backgroundCaloHitList, photonCandidatesU,
621 photonCandidatesV, photonCandidatesW, electronCandidatesU, electronCandidatesV, electronCandidatesW;
624 const CaloHitList *pCaloHitList(
nullptr);
625 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*
this, listname, pCaloHitList));
626 if (!pCaloHitList || pCaloHitList->empty())
629 HitType view{pCaloHitList->front()->GetHitType()};
630 const bool isU{view == TPC_VIEW_U}, isV{view == TPC_VIEW_V}, isW{view == TPC_VIEW_W};
631 if (!isU && !isV && !isW)
632 return STATUS_CODE_NOT_ALLOWED;
634 for (
const CaloHit *pCaloHit : *pCaloHitList)
636 const MCParticle *pMainMCParticle(
nullptr);
639 pMainMCParticle = MCParticleHelper::GetMainMCParticle(pCaloHit);
641 catch (
const StatusCodeException &)
647 const MCParticle *
const pParentMCParticle(LArMCParticleHelper::GetParentMCParticle(pMainMCParticle));
649 if (LArMCParticleHelper::IsNeutrino(pParentMCParticle))
651 if (
std::abs(pMainMCParticle->GetParticleId()) == PHOTON)
653 signalCandidates2D.emplace_back(pCaloHit);
657 signalCandidatesU.emplace_back(pCaloHit);
658 photonCandidatesU.emplace_back(pCaloHit);
662 signalCandidatesV.emplace_back(pCaloHit);
663 photonCandidatesV.emplace_back(pCaloHit);
667 signalCandidatesW.emplace_back(pCaloHit);
668 photonCandidatesU.emplace_back(pCaloHit);
670 if (
std::abs(pMainMCParticle->GetParticleId()) == E_MINUS)
674 signalCandidatesU.emplace_back(pCaloHit);
675 electronCandidatesU.emplace_back(pCaloHit);
679 signalCandidatesV.emplace_back(pCaloHit);
680 electronCandidatesV.emplace_back(pCaloHit);
684 signalCandidatesW.emplace_back(pCaloHit);
685 electronCandidatesW.emplace_back(pCaloHit);
690 backgroundCaloHitList.emplace_back(pCaloHit);
695 backgroundCaloHitList.emplace_back(pCaloHit);
700 backgroundCaloHitList.emplace_back(pCaloHit);
708 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &signalCandidatesU,
"true signal U", BLUE));
709 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &photonCandidatesU,
"true photon U", BLACK));
710 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &electronCandidatesU,
"true electron U", RED));
714 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &signalCandidatesV,
"true signal V", BLUE));
715 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &photonCandidatesV,
"true photon V", BLACK));
716 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &electronCandidatesV,
"true electron V", RED));
720 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &signalCandidatesW,
"true signal W", BLUE));
721 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &photonCandidatesW,
"true photon W", BLACK));
722 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &electronCandidatesW,
"true electron W", RED));
725 PANDORA_MONITORING_API(SetEveDisplayParameters(this->GetPandora(),
true, DETECTOR_VIEW_XZ, -1.
f, 1.
f, 1.
f));
726 PANDORA_MONITORING_API(ViewEvent(this->GetPandora()));
730 if (signalCandidatesU.empty() || signalCandidatesV.empty() || signalCandidatesW.empty() || signalCandidates2D.empty())
732 std::cout <<
"Error: A CaloHitList is empty" << std::endl;
734 if (!signalCandidatesU.empty())
735 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, signalCandidatesU,
m_signalListNameU));
737 if (!signalCandidatesV.empty())
738 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, signalCandidatesV,
m_signalListNameV));
740 if (!signalCandidatesW.empty())
741 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, signalCandidatesW,
m_signalListNameW));
743 if (!signalCandidates2D.empty())
744 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, signalCandidates2D,
m_signalListName2D));
746 if (!backgroundCaloHitList.empty())
747 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::SaveList(*
this, backgroundCaloHitList,
m_backgroundListName));
749 return STATUS_CODE_SUCCESS;
758 const LArTPC *
const pTPC(this->GetPandora().GetGeometry()->GetLArTPCMap().
begin()->
second);
759 const float pitch(view == TPC_VIEW_U ? pTPC->GetWirePitchU() : view == TPC_VIEW_V ? pTPC->GetWirePitchV() : pTPC->GetWirePitchW());
762 std::vector<double> xBinEdges(
m_width + 1);
763 std::vector<double> zBinEdges(
m_height + 1);
766 for (
int i = 1; i <
m_width + 1; ++i)
767 xBinEdges[i] = xBinEdges[i - 1] + dx;
768 zBinEdges[0] = zMin - 0.5f * pitch;
769 const double dz = ((zMax + 0.5f * pitch) - zBinEdges[0]) /
m_height;
770 for (
int i = 1; i <
m_height + 1; ++i)
771 zBinEdges[i] = zBinEdges[i - 1] + dz;
774 auto accessor = networkInput.accessor<float, 4>();
776 for (
const CaloHit *pCaloHit : caloHits)
778 const float x{pCaloHit->GetPositionVector().GetX()};
779 const float z{pCaloHit->GetPositionVector().GetZ()};
782 if (x < xMin || x > xMax || z < zMin || z > zMax)
785 const float adc{pCaloHit->GetMipEquivalentEnergy()};
786 const int pixelX{
static_cast<int>(std::floor((
x - xBinEdges[0]) / dx))};
787 const int pixelZ{
static_cast<int>(std::floor((
z - zBinEdges[0]) / dz))};
788 accessor[0][0][pixelZ][pixelX] += adc;
789 pixelMap[pCaloHit] = std::make_pair(pixelX, pixelZ);
792 return STATUS_CODE_SUCCESS;
799 const CaloHitList *pCaloHitList2D(
nullptr);
800 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*
this,
m_caloHitListName2D, pCaloHitList2D));
801 const MCParticleList *pMCParticleList(
nullptr);
802 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetCurrentList(*
this, pMCParticleList));
803 if (pMCParticleList->empty() || pCaloHitList2D->empty())
804 throw StatusCodeException(STATUS_CODE_NOT_ALLOWED);
811 LArMCParticleHelper::SelectReconstructableMCParticles(
812 pMCParticleList, pCaloHitList2D, parameters, LArMCParticleHelper::IsBeamNeutrinoFinalState, mcToHitsMap);
814 return STATUS_CODE_SUCCESS;
823 for (
const auto &[mc,
hits] : mcToHitsMap)
826 mcHierarchy.push_back(mc);
827 LArMCParticleHelper::GetAllAncestorMCParticles(mc, mcHierarchy);
830 catch (
const StatusCodeException &
e)
832 return e.GetStatusCode();
837 std::find_if(mcHierarchy.begin(), mcHierarchy.end(), [](
const MCParticle *mc) ->
bool {
return LArMCParticleHelper::IsNeutrino(mc); });
838 if (pivot != mcHierarchy.end())
839 std::rotate(mcHierarchy.begin(), pivot, std::next(pivot));
841 return STATUS_CODE_NOT_FOUND;
843 return STATUS_CODE_SUCCESS;
852 xMin = std::numeric_limits<float>::max();
853 xMax = -std::numeric_limits<float>::max();
854 zMin = std::numeric_limits<float>::max();
855 zMax = -std::numeric_limits<float>::max();
858 if (caloHitList.empty())
859 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
861 for (
const CaloHit *pCaloHit : caloHitList)
863 const float x{pCaloHit->GetPositionVector().GetX()};
864 const float z{pCaloHit->GetPositionVector().GetZ()};
865 xMin = std::min(
x, xMin);
866 xMax = std::max(
x, xMax);
867 zMin = std::min(
z, zMin);
868 zMax = std::max(
z, zMax);
870 HitType view{caloHitList.front()->GetHitType()};
871 const bool isU{view == TPC_VIEW_U}, isV{view == TPC_VIEW_V}, isW{view == TPC_VIEW_W};
872 if (!(isU || isV || isW))
873 throw StatusCodeException(STATUS_CODE_NOT_ALLOWED);
876 const LArTPC *
const pTPC(this->GetPandora().GetGeometry()->GetLArTPCMap().
begin()->
second);
877 const float pitch(view == TPC_VIEW_U ? pTPC->GetWirePitchU() : view == TPC_VIEW_V ? pTPC->GetWirePitchV() : pTPC->GetWirePitchW());
881 float xSum{0.f}, zSum{0.f}, nSum{0.f};
883 for (
const CaloHit *pCaloHit : caloHitList)
885 const float x{pCaloHit->GetPositionVector().GetX()},
z{pCaloHit->GetPositionVector().GetZ()};
891 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
892 const CartesianVector ¢re{xSum / nSum, 0.f, zSum / nSum};
895 int nHitsLeft{0}, nHitsRight{0};
898 int nHitsUpstream{0}, nHitsDownstream{0};
900 const float xCtr{centre.GetX()};
901 const float zCtr{centre.GetZ()};
903 for (
const CaloHit *pCaloHit : caloHitList)
905 const float x{pCaloHit->GetPositionVector().GetX()},
z{pCaloHit->GetPositionVector().GetZ()};
917 const int nHitsTotal{nHitsLeft + nHitsRight};
919 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
920 const float xAsymmetry{nHitsLeft /
static_cast<float>(nHitsTotal)};
922 const int nHitsViewTotal{nHitsUpstream + nHitsDownstream};
923 if (nHitsViewTotal == 0)
924 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
925 const float zAsymmetry{nHitsUpstream /
static_cast<float>(nHitsViewTotal)};
928 xMin = xCtr - xAsymmetry * xSpan;
930 const float zSpan{pitch * (
m_height - 1)};
931 zMin = zCtr - zAsymmetry * zSpan;
937 float xPos{-std::numeric_limits<float>::max()}, zPos{-std::numeric_limits<float>::max()}, adcMax{0.f};
939 for (
const CaloHit *pCaloHit : caloHitList)
941 const float xC{pCaloHit->GetPositionVector().GetX()}, zC{pCaloHit->GetPositionVector().GetZ()}, adc{pCaloHit->GetMipEquivalentEnergy()};
943 if (xC >= xMin && xC < xMax)
955 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
960 const float zSpan{pitch * (
m_height - 1)};
968 const float xRange{xMax - xMin}, zRange{zMax - zMin};
970 if (xRange < minXSpan)
972 const float padding{0.5f * (minXSpan - xRange)};
976 const float minZSpan{pitch * (
m_height - 1)};
977 if (zRange < minZSpan)
979 const float padding{0.5f * (minZSpan - zRange)};
989 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"TrainingMode",
m_trainingMode));
990 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"Visualise",
m_visualise));
991 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"Pass",
m_pass));
992 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ImageHeight",
m_height));
993 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ImageWidth",
m_width));
994 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"DriftStep",
m_driftStep));
995 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"SignalListNameU",
m_signalListNameU));
996 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"SignalListNameV",
m_signalListNameV));
997 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"SignalListNameW",
m_signalListNameW));
998 PANDORA_RETURN_RESULT_IF_AND_IF(
999 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"SignalListName2D",
m_signalListName2D));
1003 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"TrainingOutputFileName",
m_trainingOutputFile));
1007 std::string modelName;
1008 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameU", modelName));
1009 modelName = LArFileHelper::FindFileInPath(modelName,
"FW_SEARCH_PATH");
1011 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameV", modelName));
1012 modelName = LArFileHelper::FindFileInPath(modelName,
"FW_SEARCH_PATH");
1014 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameW", modelName));
1015 modelName = LArFileHelper::FindFileInPath(modelName,
"FW_SEARCH_PATH");
1017 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"WriteTree",
m_writeTree));
1020 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"RootTreeName",
m_rootTreeName));
1021 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"RootFileName",
m_rootFileName));
1025 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=,
1027 PANDORA_RETURN_RESULT_IF_AND_IF(
1028 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(xmlHandle,
"CaloHitListNames",
m_caloHitListNames));
1029 PANDORA_RETURN_RESULT_IF_AND_IF(
1030 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"CaloHitListName2D",
m_caloHitListName2D));
1031 PANDORA_RETURN_RESULT_IF_AND_IF(
1032 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"PassOneTrustThreshold",
m_passOneTrustThreshold));
1033 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"PrintOut",
m_printOut));
1034 PANDORA_RETURN_RESULT_IF_AND_IF(
1035 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"BackgroundListName",
m_backgroundListName));
1036 PANDORA_RETURN_RESULT_IF_AND_IF(
1037 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ApplyCheatedSeparation",
m_applyCheatedSeparation));
1039 return STATUS_CODE_SUCCESS;
pandora::StatusCode Run()
std::string m_signalListNameV
Output signal CaloHitListV name.
std::string m_rootFileName
The ROOT file name.
unsigned int m_minPrimaryGoodViews
the minimum number of primary good views
pandora::StatusCode CompleteMCHierarchy(const LArMCParticleHelper::MCContributionMap &mcToHitsMap, pandora::MCParticleList &mcHierarchy) const
std::unordered_map< const pandora::MCParticle *, pandora::CaloHitList > MCContributionMap
MvaTypes::MvaFeatureVector MvaFeatureVector
const int ELECTRON_CLASS
Constant for network classification for electrons.
int m_width
The width of the images.
pandora::StatusCode MakeNetworkInputFromHits(const pandora::CaloHitList &caloHits, const pandora::HitType view, const float xMin, const float xMax, const float zMin, const float zMax, LArDLHelper::TorchInput &networkInput, PixelMap &pixelMap) const
unsigned int m_minPrimaryGoodHits
the minimum number of primary good Hits
std::string m_caloHitListName2D
Input CaloHitList2D name.
int m_height
The height of the images.
constexpr auto abs(T v)
Returns the absolute value of the argument.
const int PHOTON_CLASS
Constant for network classification for photons.
pandora::StringVector m_caloHitListNames
Names of input calo hit lists.
std::string m_signalListNameW
Output signal CaloHitListW name.
bool m_trainingMode
Training mode.
int m_pass
The pass of the train/infer step.
unsigned int m_minHitsForGoodView
the minimum number of Hits for a good view
LArDLHelper::TorchModel m_modelU
The model for the U view.
float m_maxPhotonPropagation
the maximum photon propagation length
LArDLHelper::TorchModel m_modelW
The model for the W view.
std::string m_trainingOutputFile
Output file name for training examples.
bool m_visualise
Whether or not to visualise the candidate vertices.
Header file for the geometry helper class.
std::string m_rootTreeName
The ROOT tree name.
Header file for the lar monte carlo particle helper helper class.
decltype(auto) constexpr to_string(T &&obj)
ADL-aware version of std::to_string.
std::string m_backgroundListName
Input Background CaloHitList name.
pandora::StatusCode CheatedSeparation()
pandora::StatusCode PrepareTrainingSample()
Header file for the file helper class.
static void Forward(TorchModel &model, const TorchInputVector &input, TorchOutput &output)
Run a deep learning model.
long unsigned int m_passOneTrustThreshold
Number of pixels in pass one required to trust the wire finding ability, below this threshold...
pandora::StringVector m_inputCaloHitListNames
Names of input calo hit lists, passed from Pass 1 of DLSignalAlg.
Header file for the vertex helper class.
int m_event
The current event number.
std::string m_signalListName2D
Output signal CaloHitList2D name.
const int SIGNAL_CLASS
Constant for network classification for signal.
bool m_applyCheatedSeparation
Whether cheating to separate background and signal hits.
static pandora::StatusCode LoadModel(const std::string &filename, TorchModel &model)
Loads a deep learning model.
LArDLHelper::TorchModel m_modelV
The model for the V view.
float m_driftStep
The size of a pixel in the drift direction in cm (most relevant in pass 2)
bool m_printOut
Whether or not to print out network outputs of CaloHitList names and sizes.
bool m_simpleZoom
Decide whethere to run a simple loop to find highest adc hit or run network.
std::map< const pandora::CaloHit *, Pixel > PixelMap
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
bool m_writeTree
Whether or not to write validation details to a ROOT tree.
virtual ~DlSNSignalAlgorithm()
std::string m_signalListNameU
Output signal CaloHitListU name.
pandora::StatusCode Infer()
pandora::StatusCode GetMCToHitsMap(LArMCParticleHelper::MCContributionMap &mcToHitsMap) const
pandora::StatusCode ReadSettings(const pandora::TiXmlHandle xmlHandle)
second_as<> second
Type of time stored in seconds, in double precision.
static void InitialiseInput(const at::IntArrayRef dimensions, TorchInput &tensor)
Create a torch input tensor.
void GetHitRegion(const pandora::CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax) const
std::vector< torch::jit::IValue > TorchInputVector