LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
LArAdaBoostDecisionTree.cc
Go to the documentation of this file.
1 
9 #include "Helpers/XmlHelper.h"
10 
12 
13 using namespace pandora;
14 
15 namespace lar_content
16 {
17 
18 AdaBoostDecisionTree::AdaBoostDecisionTree() :
19  m_pStrongClassifier(nullptr)
20 {
21 }
22 
23 //------------------------------------------------------------------------------------------------------------------------------------------
24 
26 {
28 }
29 
30 //------------------------------------------------------------------------------------------------------------------------------------------
31 
33 {
34  if (this != &rhs)
36 
37  return *this;
38 }
39 
40 //------------------------------------------------------------------------------------------------------------------------------------------
41 
43 {
44  delete m_pStrongClassifier;
45 }
46 
47 //------------------------------------------------------------------------------------------------------------------------------------------
48 
49 StatusCode AdaBoostDecisionTree::Initialize(const std::string &bdtXmlFileName, const std::string &bdtName)
50 {
52  {
53  std::cout << "AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" << std::endl;
54  return STATUS_CODE_ALREADY_INITIALIZED;
55  }
56 
57  TiXmlDocument xmlDocument(bdtXmlFileName);
58 
59  if (!xmlDocument.LoadFile())
60  {
61  std::cout << "AdaBoostDecisionTree::Initialize - Invalid xml file." << std::endl;
62  return STATUS_CODE_INVALID_PARAMETER;
63  }
64 
65  const TiXmlHandle xmlDocumentHandle(&xmlDocument);
66  TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
67 
68  while (pContainerXmlNode)
69  {
70  if (pContainerXmlNode->ValueStr() != "AdaBoostDecisionTree")
71  return STATUS_CODE_FAILURE;
72 
73  const TiXmlHandle currentHandle(pContainerXmlNode);
74 
75  std::string currentName;
76  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle, "Name", currentName));
77 
78  if (currentName.empty() || (currentName.size() > 1000))
79  {
80  std::cout << "AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." << std::endl;
81  return STATUS_CODE_INVALID_PARAMETER;
82  }
83 
84  if (currentName == bdtName)
85  break;
86 
87  pContainerXmlNode = pContainerXmlNode->NextSibling();
88  }
89 
90  if (!pContainerXmlNode)
91  {
92  std::cout << "AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName << std::endl;
93  return STATUS_CODE_NOT_FOUND;
94  }
95 
96  const TiXmlHandle xmlHandle(pContainerXmlNode);
97 
98  try
99  {
100  m_pStrongClassifier = new StrongClassifier(&xmlHandle);
101  }
102  catch (StatusCodeException &statusCodeException)
103  {
104  delete m_pStrongClassifier;
105 
106  if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
107  std::cout << "AdaBoostDecisionTree: Initialization failure, unknown component in xml file." << std::endl;
108 
109  if (STATUS_CODE_FAILURE == statusCodeException.GetStatusCode())
110  std::cout << "AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." << std::endl;
111 
112  return statusCodeException.GetStatusCode();
113  }
114 
115  return STATUS_CODE_SUCCESS;
116 }
117 
118 //------------------------------------------------------------------------------------------------------------------------------------------
119 
121 {
122  return ((this->CalculateScore(features) > 0.) ? true : false);
123 }
124 
125 //------------------------------------------------------------------------------------------------------------------------------------------
126 
128 {
129  return this->CalculateScore(features);
130 }
131 
132 //------------------------------------------------------------------------------------------------------------------------------------------
133 
135 {
136  // ATTN: BDT score, once normalised by total weight, is confined to the range -1 to +1. This linear mapping places the score in the
137  // range 0 to 1 so that it may be interpreted as a probability.
138  return (this->CalculateScore(features) + 1.) * 0.5;
139 }
140 
141 //------------------------------------------------------------------------------------------------------------------------------------------
142 
144 {
145  if (!m_pStrongClassifier)
146  {
147  std::cout << "AdaBoostDecisionTree: Attempting to use an uninitialized bdt" << std::endl;
148  throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
149  }
150 
151  try
152  {
153  // TODO: Add consistency check for number of features, bearing in mind not all features in a bdt may be used
154  return m_pStrongClassifier->Predict(features);
155  }
156  catch (StatusCodeException &statusCodeException)
157  {
158  if (STATUS_CODE_NOT_FOUND == statusCodeException.GetStatusCode())
159  {
160  std::cout << "AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." << std::endl;
161  }
162  else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
163  {
164  std::cout << "AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier."
165  << std::endl;
166  }
167  else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.GetStatusCode())
168  {
169  std::cout << "AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." << std::endl;
170  }
171  else
172  {
173  std::cout << "AdaBoostDecisionTree: Unexpected exception thrown." << std::endl;
174  }
175 
176  throw statusCodeException;
177  }
178 }
179 
180 //------------------------------------------------------------------------------------------------------------------------------------------
181 //------------------------------------------------------------------------------------------------------------------------------------------
182 
183 AdaBoostDecisionTree::Node::Node(const TiXmlHandle *const pXmlHandle) :
184  m_nodeId(0),
185  m_parentNodeId(0),
186  m_leftChildNodeId(0),
187  m_rightChildNodeId(0),
188  m_isLeaf(false),
189  m_threshold(0.),
190  m_variableId(0),
191  m_outcome(false)
192 {
193  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "NodeId", m_nodeId));
194  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "ParentNodeId", m_parentNodeId));
195 
196  const StatusCode leftChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "LeftChildNodeId", m_leftChildNodeId));
197  const StatusCode rightChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "RightChildNodeId", m_rightChildNodeId));
198  const StatusCode thresholdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Threshold", m_threshold));
199  const StatusCode variableIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "VariableId", m_variableId));
200  const StatusCode outcomeStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Outcome", m_outcome));
201 
202  if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
203  STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
204  {
205  m_isLeaf = false;
206  m_outcome = false;
207  }
208  else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
209  {
210  m_isLeaf = true;
211  m_leftChildNodeId = std::numeric_limits<int>::max();
212  m_rightChildNodeId = std::numeric_limits<int>::max();
213  m_threshold = std::numeric_limits<double>::max();
214  m_variableId = std::numeric_limits<int>::max();
215  }
216  else
217  {
218  throw StatusCodeException(STATUS_CODE_FAILURE);
219  }
220 }
221 
222 //------------------------------------------------------------------------------------------------------------------------------------------
223 
225  m_nodeId(rhs.m_nodeId),
226  m_parentNodeId(rhs.m_parentNodeId),
227  m_leftChildNodeId(rhs.m_leftChildNodeId),
228  m_rightChildNodeId(rhs.m_rightChildNodeId),
229  m_isLeaf(rhs.m_isLeaf),
230  m_threshold(rhs.m_threshold),
231  m_variableId(rhs.m_variableId),
232  m_outcome(rhs.m_outcome)
233 {
234 }
235 
236 //------------------------------------------------------------------------------------------------------------------------------------------
237 
239 {
240  if (this != &rhs)
241  {
242  m_nodeId = rhs.m_nodeId;
246  m_isLeaf = rhs.m_isLeaf;
247  m_threshold = rhs.m_threshold;
249  m_outcome = rhs.m_outcome;
250  }
251 
252  return *this;
253 }
254 
255 //------------------------------------------------------------------------------------------------------------------------------------------
256 
258 {
259 }
260 
261 //------------------------------------------------------------------------------------------------------------------------------------------
262 //------------------------------------------------------------------------------------------------------------------------------------------
263 
264 AdaBoostDecisionTree::WeakClassifier::WeakClassifier(const TiXmlHandle *const pXmlHandle) :
265  m_weight(0.),
266  m_treeId(0)
267 {
268  for (TiXmlElement *pHeadTiXmlElement = pXmlHandle->FirstChildElement().ToElement(); pHeadTiXmlElement != NULL;
269  pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
270  {
271  if ("TreeIndex" == pHeadTiXmlElement->ValueStr())
272  {
273  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeIndex", m_treeId));
274  }
275  else if ("TreeWeight" == pHeadTiXmlElement->ValueStr())
276  {
277  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeWeight", m_weight));
278  }
279  else if ("Node" == pHeadTiXmlElement->ValueStr())
280  {
281  const TiXmlHandle nodeHandle(pHeadTiXmlElement);
282  const Node *pNode = new Node(&nodeHandle);
283  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
284  }
285  }
286 }
287 
288 //------------------------------------------------------------------------------------------------------------------------------------------
289 
291  m_weight(rhs.m_weight),
292  m_treeId(rhs.m_treeId)
293 {
294  for (const auto &mapEntry : rhs.m_idToNodeMap)
295  {
296  const Node *pNode = new Node(*(mapEntry.second));
297  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
298  }
299 }
300 
301 //------------------------------------------------------------------------------------------------------------------------------------------
302 
304 {
305  if (this != &rhs)
306  {
307  for (const auto &mapEntry : rhs.m_idToNodeMap)
308  {
309  const Node *pNode = new Node(*(mapEntry.second));
310  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
311  }
312 
313  m_weight = rhs.m_weight;
314  m_treeId = rhs.m_treeId;
315  }
316 
317  return *this;
318 }
319 
320 //------------------------------------------------------------------------------------------------------------------------------------------
321 
323 {
324  for (const auto &mapEntry : m_idToNodeMap)
325  delete mapEntry.second;
326 }
327 
328 //------------------------------------------------------------------------------------------------------------------------------------------
329 
331 {
332  return this->EvaluateNode(0, features);
333 }
334 
335 //------------------------------------------------------------------------------------------------------------------------------------------
336 
338 {
339  const Node *pActiveNode(nullptr);
340 
341  if (m_idToNodeMap.find(nodeId) != m_idToNodeMap.end())
342  {
343  pActiveNode = m_idToNodeMap.at(nodeId);
344  }
345  else
346  {
347  throw StatusCodeException(STATUS_CODE_OUT_OF_RANGE);
348  }
349 
350  if (pActiveNode->IsLeaf())
351  return pActiveNode->GetOutcome();
352 
353  if (static_cast<int>(features.size()) <= pActiveNode->GetVariableId())
354  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
355 
356  if (features.at(pActiveNode->GetVariableId()).Get() <= pActiveNode->GetThreshold())
357  {
358  return this->EvaluateNode(pActiveNode->GetLeftChildNodeId(), features);
359  }
360  else
361  {
362  return this->EvaluateNode(pActiveNode->GetRightChildNodeId(), features);
363  }
364 }
365 
366 //------------------------------------------------------------------------------------------------------------------------------------------
367 //------------------------------------------------------------------------------------------------------------------------------------------
368 
369 AdaBoostDecisionTree::StrongClassifier::StrongClassifier(const TiXmlHandle *const pXmlHandle)
370 {
371  TiXmlElement *pCurrentXmlElement = pXmlHandle->FirstChild().Element();
372 
373  while (pCurrentXmlElement)
374  {
375  if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
376  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
377 
378  pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
379  }
380 }
381 
382 //------------------------------------------------------------------------------------------------------------------------------------------
383 
385 {
386  for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
387  m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
388 }
389 
390 //------------------------------------------------------------------------------------------------------------------------------------------
391 
393 {
394  if (this != &rhs)
395  {
396  for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
397  m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
398  }
399 
400  return *this;
401 }
402 
403 //------------------------------------------------------------------------------------------------------------------------------------------
404 
406 {
407  for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
408  delete pWeakClassifier;
409 }
410 
411 //------------------------------------------------------------------------------------------------------------------------------------------
412 
414 {
415  double score(0.), weights(0.);
416 
417  for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
418  {
419  weights += pWeakClassifier->GetWeight();
420 
421  if (pWeakClassifier->Predict(features))
422  {
423  score += pWeakClassifier->GetWeight();
424  }
425  else
426  {
427  score -= pWeakClassifier->GetWeight();
428  }
429  }
430 
431  if (weights > std::numeric_limits<double>::epsilon())
432  {
433  score /= weights;
434  }
435  else
436  {
437  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
438  }
439 
440  return score;
441 }
442 
443 //------------------------------------------------------------------------------------------------------------------------------------------
444 
445 StatusCode AdaBoostDecisionTree::StrongClassifier::ReadComponent(TiXmlElement *pCurrentXmlElement)
446 {
447  const std::string componentName(pCurrentXmlElement->ValueStr());
448  TiXmlHandle currentHandle(pCurrentXmlElement);
449 
450  if ((std::string("Name") == componentName) || (std::string("Timestamp") == componentName))
451  return STATUS_CODE_SUCCESS;
452 
453  if (std::string("DecisionTree") == componentName)
454  {
455  m_weakClassifiers.emplace_back(new WeakClassifier(&currentHandle));
456  return STATUS_CODE_SUCCESS;
457  }
458 
459  return STATUS_CODE_INVALID_PARAMETER;
460 }
461 
462 } // namespace lar_content
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
Definition: LArMvaHelper.h:75
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
cout<< "Opened file "<< fin<< " ixs= "<< ixs<< endl;if(ixs==0) hhh=(TH1F *) fff-> Get("h1")
Definition: AddMC.C:8
WeakClassifier class containing a decision tree and a weight.
double m_threshold
Threshold used for decision if decision node.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
double GetThreshold() const
Return node threshold.
StrongClassifier class used in application of adaptive boost decision tree.
Header file for the lar adaptive boosted decision tree class.
Node & operator=(const Node &rhs)
Assignment operator.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
Node class used for representing a decision tree.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetRightChildNodeId() const
Return right child node id.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
int m_variableId
Variable cut on for decision if decision node.
bool IsLeaf() const
Return whether the node is a leaf.