9 #include "Helpers/XmlHelper.h" 18 AdaBoostDecisionTree::AdaBoostDecisionTree() :
19 m_pStrongClassifier(nullptr)
53 std::cout <<
"AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" << std::endl;
54 return STATUS_CODE_ALREADY_INITIALIZED;
57 TiXmlDocument xmlDocument(bdtXmlFileName);
59 if (!xmlDocument.LoadFile())
61 std::cout <<
"AdaBoostDecisionTree::Initialize - Invalid xml file." << std::endl;
62 return STATUS_CODE_INVALID_PARAMETER;
65 const TiXmlHandle xmlDocumentHandle(&xmlDocument);
66 TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
68 while (pContainerXmlNode)
70 if (pContainerXmlNode->ValueStr() !=
"AdaBoostDecisionTree")
71 return STATUS_CODE_FAILURE;
73 const TiXmlHandle currentHandle(pContainerXmlNode);
75 std::string currentName;
76 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle,
"Name", currentName));
78 if (currentName.empty() || (currentName.size() > 1000))
80 std::cout <<
"AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." << std::endl;
81 return STATUS_CODE_INVALID_PARAMETER;
84 if (currentName == bdtName)
87 pContainerXmlNode = pContainerXmlNode->NextSibling();
90 if (!pContainerXmlNode)
92 std::cout <<
"AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName << std::endl;
93 return STATUS_CODE_NOT_FOUND;
96 const TiXmlHandle xmlHandle(pContainerXmlNode);
102 catch (StatusCodeException &statusCodeException)
106 if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
107 std::cout <<
"AdaBoostDecisionTree: Initialization failure, unknown component in xml file." << std::endl;
109 if (STATUS_CODE_FAILURE == statusCodeException.GetStatusCode())
110 std::cout <<
"AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." << std::endl;
112 return statusCodeException.GetStatusCode();
115 return STATUS_CODE_SUCCESS;
147 std::cout <<
"AdaBoostDecisionTree: Attempting to use an uninitialized bdt" << std::endl;
148 throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
156 catch (StatusCodeException &statusCodeException)
158 if (STATUS_CODE_NOT_FOUND == statusCodeException.GetStatusCode())
160 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." << std::endl;
162 else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
164 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier." << std::endl;
166 else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.GetStatusCode())
168 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." << std::endl;
172 std::cout <<
"AdaBoostDecisionTree: Unexpected exception thrown." << std::endl;
175 throw statusCodeException;
185 m_leftChildNodeId(0),
186 m_rightChildNodeId(0),
192 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"NodeId", m_nodeId));
193 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"ParentNodeId", m_parentNodeId));
195 const StatusCode leftChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"LeftChildNodeId", m_leftChildNodeId));
196 const StatusCode rightChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"RightChildNodeId", m_rightChildNodeId));
197 const StatusCode thresholdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"Threshold", m_threshold));
198 const StatusCode variableIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"VariableId", m_variableId));
199 const StatusCode outcomeStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"Outcome", m_outcome));
201 if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
202 STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
207 else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
217 throw StatusCodeException(STATUS_CODE_FAILURE);
224 m_nodeId(rhs.m_nodeId),
225 m_parentNodeId(rhs.m_parentNodeId),
226 m_leftChildNodeId(rhs.m_leftChildNodeId),
227 m_rightChildNodeId(rhs.m_rightChildNodeId),
228 m_isLeaf(rhs.m_isLeaf),
229 m_threshold(rhs.m_threshold),
230 m_variableId(rhs.m_variableId),
231 m_outcome(rhs.m_outcome)
267 for (TiXmlElement *pHeadTiXmlElement = pXmlHandle->FirstChildElement().ToElement(); pHeadTiXmlElement != NULL; pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
269 if (
"TreeIndex" == pHeadTiXmlElement->ValueStr())
271 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"TreeIndex", m_treeId));
273 else if (
"TreeWeight" == pHeadTiXmlElement->ValueStr())
275 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"TreeWeight", m_weight));
277 else if (
"Node" == pHeadTiXmlElement->ValueStr())
279 const TiXmlHandle nodeHandle(pHeadTiXmlElement);
280 const Node *pNode =
new Node(&nodeHandle);
281 m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->
GetNodeId(), pNode));
289 m_weight(rhs.m_weight),
290 m_treeId(rhs.m_treeId)
294 const Node *pNode =
new Node(*(mapEntry.second));
307 const Node *pNode =
new Node(*(mapEntry.second));
323 delete mapEntry.second;
337 const Node *pActiveNode(
nullptr);
345 throw StatusCodeException(STATUS_CODE_OUT_OF_RANGE);
348 if (pActiveNode->
IsLeaf())
351 if (static_cast<int>(features.size()) <= pActiveNode->
GetVariableId())
352 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
369 TiXmlElement *pCurrentXmlElement = pXmlHandle->FirstChild().Element();
371 while (pCurrentXmlElement)
373 if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
374 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
376 pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
385 m_weakClassifiers.emplace_back(
new WeakClassifier(*pWeakClassifier));
395 m_weakClassifiers.emplace_back(
new WeakClassifier(*pWeakClassifier));
405 for (
const WeakClassifier *
const pWeakClassifier : m_weakClassifiers)
406 delete pWeakClassifier;
413 double score(0.), weights(0.);
415 for (
const WeakClassifier *
const pWeakClassifier : m_weakClassifiers)
417 weights += pWeakClassifier->GetWeight();
419 if (pWeakClassifier->Predict(features))
421 score += pWeakClassifier->GetWeight();
425 score -= pWeakClassifier->GetWeight();
429 if (weights > std::numeric_limits<double>::epsilon())
435 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
445 const std::string componentName(pCurrentXmlElement->ValueStr());
446 TiXmlHandle currentHandle(pCurrentXmlElement);
448 if ((std::string(
"Name") == componentName) || (std::string(
"Timestamp") == componentName))
449 return STATUS_CODE_SUCCESS;
451 if (std::string(
"DecisionTree") == componentName)
453 m_weakClassifiers.emplace_back(
new WeakClassifier(¤tHandle));
454 return STATUS_CODE_SUCCESS;
457 return STATUS_CODE_INVALID_PARAMETER;
int m_rightChildNodeId
Right child node id.
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
IdToNodeMap m_idToNodeMap
Decision tree nodes.
bool GetOutcome() const
Return outcome.
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
WeakClassifier class containing a decision tree and a weight.
int m_treeId
Decision tree id.
double m_threshold
Threshold used for decision if decision node.
int GetNodeId() const
Return node id.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
AdaBoostDecisionTree class.
cout<< "Opened file "<< fin<< " ixs= "<< ixs<< endl;if(ixs==0) hhh=(TH1F *) fff-> Get("h1")
int m_parentNodeId
Parent node id.
double GetThreshold() const
Return node threshold.
bool m_outcome
Outcome if leaf node.
StrongClassifier class used in application of adaptive boosted decision tree.
Header file for the lar adaptive boosted decision tree class.
Node & operator=(const Node &rhs)
Assignment operator.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
int m_leftChildNodeId
Left child node id.
~WeakClassifier()
Destructor.
Node class used for representing a decision tree.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetRightChildNodeId() const
Return right child node id.
~AdaBoostDecisionTree()
Destructor.
AdaBoostDecisionTree()
Constructor.
double m_weight
Boost weight.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string ¶meterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
int m_variableId
Variable cut on for decision if decision node.
bool IsLeaf() const
Return is the node a leaf.
bool m_isLeaf
Is node a leaf.
~StrongClassifier()
Destructor.