9 #include "Helpers/XmlHelper.h" 18 AdaBoostDecisionTree::AdaBoostDecisionTree() :
19 m_pStrongClassifier(nullptr)
53 std::cout <<
"AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" << std::endl;
54 return STATUS_CODE_ALREADY_INITIALIZED;
57 TiXmlDocument xmlDocument(bdtXmlFileName);
59 if (!xmlDocument.LoadFile())
61 std::cout <<
"AdaBoostDecisionTree::Initialize - Invalid xml file." << std::endl;
62 return STATUS_CODE_INVALID_PARAMETER;
65 const TiXmlHandle xmlDocumentHandle(&xmlDocument);
66 TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
68 while (pContainerXmlNode)
70 if (pContainerXmlNode->ValueStr() !=
"AdaBoostDecisionTree")
71 return STATUS_CODE_FAILURE;
73 const TiXmlHandle currentHandle(pContainerXmlNode);
75 std::string currentName;
76 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle,
"Name", currentName));
78 if (currentName.empty() || (currentName.size() > 1000))
80 std::cout <<
"AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." << std::endl;
81 return STATUS_CODE_INVALID_PARAMETER;
84 if (currentName == bdtName)
87 pContainerXmlNode = pContainerXmlNode->NextSibling();
90 if (!pContainerXmlNode)
92 std::cout <<
"AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName << std::endl;
93 return STATUS_CODE_NOT_FOUND;
96 const TiXmlHandle xmlHandle(pContainerXmlNode);
102 catch (StatusCodeException &statusCodeException)
106 if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
107 std::cout <<
"AdaBoostDecisionTree: Initialization failure, unknown component in xml file." << std::endl;
109 if (STATUS_CODE_FAILURE == statusCodeException.GetStatusCode())
110 std::cout <<
"AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." << std::endl;
112 return statusCodeException.GetStatusCode();
115 return STATUS_CODE_SUCCESS;
147 std::cout <<
"AdaBoostDecisionTree: Attempting to use an uninitialized bdt" << std::endl;
148 throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
156 catch (StatusCodeException &statusCodeException)
158 if (STATUS_CODE_NOT_FOUND == statusCodeException.GetStatusCode())
160 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." << std::endl;
162 else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
164 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier." 167 else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.GetStatusCode())
169 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." << std::endl;
173 std::cout <<
"AdaBoostDecisionTree: Unexpected exception thrown." << std::endl;
176 throw statusCodeException;
186 m_leftChildNodeId(0),
187 m_rightChildNodeId(0),
193 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"NodeId", m_nodeId));
194 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"ParentNodeId", m_parentNodeId));
196 const StatusCode leftChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"LeftChildNodeId", m_leftChildNodeId));
197 const StatusCode rightChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"RightChildNodeId", m_rightChildNodeId));
198 const StatusCode thresholdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"Threshold", m_threshold));
199 const StatusCode variableIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"VariableId", m_variableId));
200 const StatusCode outcomeStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"Outcome", m_outcome));
202 if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
203 STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
208 else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
211 m_leftChildNodeId = std::numeric_limits<int>::max();
212 m_rightChildNodeId = std::numeric_limits<int>::max();
213 m_threshold = std::numeric_limits<double>::max();
214 m_variableId = std::numeric_limits<int>::max();
218 throw StatusCodeException(STATUS_CODE_FAILURE);
225 m_nodeId(rhs.m_nodeId),
226 m_parentNodeId(rhs.m_parentNodeId),
227 m_leftChildNodeId(rhs.m_leftChildNodeId),
228 m_rightChildNodeId(rhs.m_rightChildNodeId),
229 m_isLeaf(rhs.m_isLeaf),
230 m_threshold(rhs.m_threshold),
231 m_variableId(rhs.m_variableId),
232 m_outcome(rhs.m_outcome)
268 for (TiXmlElement *pHeadTiXmlElement = pXmlHandle->FirstChildElement().ToElement(); pHeadTiXmlElement != NULL;
269 pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
271 if (
"TreeIndex" == pHeadTiXmlElement->ValueStr())
273 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"TreeIndex", m_treeId));
275 else if (
"TreeWeight" == pHeadTiXmlElement->ValueStr())
277 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"TreeWeight", m_weight));
279 else if (
"Node" == pHeadTiXmlElement->ValueStr())
281 const TiXmlHandle nodeHandle(pHeadTiXmlElement);
282 const Node *pNode =
new Node(&nodeHandle);
283 m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->
GetNodeId(), pNode));
291 m_weight(rhs.m_weight),
292 m_treeId(rhs.m_treeId)
296 const Node *pNode =
new Node(*(mapEntry.second));
309 const Node *pNode =
new Node(*(mapEntry.second));
325 delete mapEntry.second;
339 const Node *pActiveNode(
nullptr);
347 throw StatusCodeException(STATUS_CODE_OUT_OF_RANGE);
350 if (pActiveNode->
IsLeaf())
353 if (static_cast<int>(features.size()) <= pActiveNode->
GetVariableId())
354 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
371 TiXmlElement *pCurrentXmlElement = pXmlHandle->FirstChild().Element();
373 while (pCurrentXmlElement)
375 if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
376 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
378 pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
387 m_weakClassifiers.emplace_back(
new WeakClassifier(*pWeakClassifier));
397 m_weakClassifiers.emplace_back(
new WeakClassifier(*pWeakClassifier));
407 for (
const WeakClassifier *
const pWeakClassifier : m_weakClassifiers)
408 delete pWeakClassifier;
415 double score(0.), weights(0.);
417 for (
const WeakClassifier *
const pWeakClassifier : m_weakClassifiers)
419 weights += pWeakClassifier->GetWeight();
421 if (pWeakClassifier->Predict(features))
423 score += pWeakClassifier->GetWeight();
427 score -= pWeakClassifier->GetWeight();
431 if (weights > std::numeric_limits<double>::epsilon())
437 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
447 const std::string componentName(pCurrentXmlElement->ValueStr());
448 TiXmlHandle currentHandle(pCurrentXmlElement);
450 if ((std::string(
"Name") == componentName) || (std::string(
"Timestamp") == componentName))
451 return STATUS_CODE_SUCCESS;
453 if (std::string(
"DecisionTree") == componentName)
455 m_weakClassifiers.emplace_back(
new WeakClassifier(¤tHandle));
456 return STATUS_CODE_SUCCESS;
459 return STATUS_CODE_INVALID_PARAMETER;
int m_rightChildNodeId
Right child node id.
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
IdToNodeMap m_idToNodeMap
Decision tree nodes.
bool GetOutcome() const
Return outcome.
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
cout<< "Opened file "<< fin<< " ixs= "<< ixs<< endl;if(ixs==0) hhh=(TH1F *) fff-> Get("h1")
WeakClassifier class containing a decision tree and a weight.
int m_treeId
Decision tree id.
double m_threshold
Threshold used for decision if decision node.
int GetNodeId() const
Return node id.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
AdaBoostDecisionTree class.
int m_parentNodeId
Parent node id.
double GetThreshold() const
Return node threshold.
bool m_outcome
Outcome if leaf node.
StrongClassifier class used in application of adaptive boost decision tree.
Header file for the lar adaptive boosted decision tree class.
Node & operator=(const Node &rhs)
Assignment operator.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
int m_leftChildNodeId
Left child node id.
~WeakClassifier()
Destructor.
Node class used for representing a decision tree.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetRightChildNodeId() const
Return right child node id.
~AdaBoostDecisionTree()
Destructor.
AdaBoostDecisionTree()
Constructor.
double m_weight
Boost weight.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string ¶meterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
int m_variableId
Variable cut on for decision if decision node.
bool IsLeaf() const
Return is the node a leaf.
bool m_isLeaf
Is node a leaf.
~StrongClassifier()
Destructor.