LArSoft  v06_85_00
Liquid Argon Software toolkit - http://larsoft.org/
LArAdaBoostDecisionTree.cc
Go to the documentation of this file.
1 
9 #include "Helpers/XmlHelper.h"
10 
12 
13 using namespace pandora;
14 
15 namespace lar_content
16 {
17 
18 AdaBoostDecisionTree::AdaBoostDecisionTree() :
19  m_pStrongClassifier(nullptr)
20 {
21 }
22 
23 //------------------------------------------------------------------------------------------------------------------------------------------
24 
26 {
28 }
29 
30 //------------------------------------------------------------------------------------------------------------------------------------------
31 
33 {
34  if (this != &rhs)
36 
37  return *this;
38 }
39 
40 //------------------------------------------------------------------------------------------------------------------------------------------
41 
43 {
44  delete m_pStrongClassifier;
45 }
46 
47 //------------------------------------------------------------------------------------------------------------------------------------------
48 
49 StatusCode AdaBoostDecisionTree::Initialize(const std::string &bdtXmlFileName, const std::string &bdtName)
50 {
52  {
53  std::cout << "AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" << std::endl;
54  return STATUS_CODE_ALREADY_INITIALIZED;
55  }
56 
57  TiXmlDocument xmlDocument(bdtXmlFileName);
58 
59  if (!xmlDocument.LoadFile())
60  {
61  std::cout << "AdaBoostDecisionTree::Initialize - Invalid xml file." << std::endl;
62  return STATUS_CODE_INVALID_PARAMETER;
63  }
64 
65  const TiXmlHandle xmlDocumentHandle(&xmlDocument);
66  TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
67 
68  while (pContainerXmlNode)
69  {
70  if (pContainerXmlNode->ValueStr() != "AdaBoostDecisionTree")
71  return STATUS_CODE_FAILURE;
72 
73  const TiXmlHandle currentHandle(pContainerXmlNode);
74 
75  std::string currentName;
76  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle, "Name", currentName));
77 
78  if (currentName.empty() || (currentName.size() > 1000))
79  {
80  std::cout << "AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." << std::endl;
81  return STATUS_CODE_INVALID_PARAMETER;
82  }
83 
84  if (currentName == bdtName)
85  break;
86 
87  pContainerXmlNode = pContainerXmlNode->NextSibling();
88  }
89 
90  if (!pContainerXmlNode)
91  {
92  std::cout << "AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName << std::endl;
93  return STATUS_CODE_NOT_FOUND;
94  }
95 
96  const TiXmlHandle xmlHandle(pContainerXmlNode);
97 
98  try
99  {
100  m_pStrongClassifier = new StrongClassifier(&xmlHandle);
101  }
102  catch (StatusCodeException &statusCodeException)
103  {
104  delete m_pStrongClassifier;
105 
106  if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
107  std::cout << "AdaBoostDecisionTree: Initialization failure, unknown component in xml file." << std::endl;
108 
109  if (STATUS_CODE_FAILURE == statusCodeException.GetStatusCode())
110  std::cout << "AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." << std::endl;
111 
112  return statusCodeException.GetStatusCode();
113  }
114 
115  return STATUS_CODE_SUCCESS;
116 }
117 
118 //------------------------------------------------------------------------------------------------------------------------------------------
119 
121 {
122  return ((this->CalculateScore(features) > 0.) ? true : false);
123 }
124 
125 //------------------------------------------------------------------------------------------------------------------------------------------
126 
128 {
129  return this->CalculateScore(features);
130 }
131 
132 //------------------------------------------------------------------------------------------------------------------------------------------
133 
135 {
136  // ATTN: BDT score, once normalised by total weight, is confined to the range -1 to +1. This linear mapping places the score in the
137  // range 0 to 1 so that it may be interpreted as a probability.
138  return (this->CalculateScore(features) + 1.) * 0.5;
139 }
140 
141 //------------------------------------------------------------------------------------------------------------------------------------------
142 
144 {
145  if (!m_pStrongClassifier)
146  {
147  std::cout << "AdaBoostDecisionTree: Attempting to use an uninitialized bdt" << std::endl;
148  throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
149  }
150 
151  try
152  {
153  // TODO: Add consistency check for number of features, bearing in mind not all features in a bdt may be used
154  return m_pStrongClassifier->Predict(features);
155  }
156  catch (StatusCodeException &statusCodeException)
157  {
158  if (STATUS_CODE_NOT_FOUND == statusCodeException.GetStatusCode())
159  {
160  std::cout << "AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." << std::endl;
161  }
162  else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
163  {
164  std::cout << "AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier." << std::endl;
165  }
166  else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.GetStatusCode())
167  {
168  std::cout << "AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." << std::endl;
169  }
170  else
171  {
172  std::cout << "AdaBoostDecisionTree: Unexpected exception thrown." << std::endl;
173  }
174 
175  throw statusCodeException;
176  }
177 }
178 
179 //------------------------------------------------------------------------------------------------------------------------------------------
180 //------------------------------------------------------------------------------------------------------------------------------------------
181 
182 AdaBoostDecisionTree::Node::Node(const TiXmlHandle *const pXmlHandle) :
183  m_nodeId(0),
184  m_parentNodeId(0),
185  m_leftChildNodeId(0),
186  m_rightChildNodeId(0),
187  m_isLeaf(false),
188  m_threshold(0.),
189  m_variableId(0),
190  m_outcome(false)
191 {
192  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "NodeId", m_nodeId));
193  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "ParentNodeId", m_parentNodeId));
194 
195  const StatusCode leftChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "LeftChildNodeId", m_leftChildNodeId));
196  const StatusCode rightChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "RightChildNodeId", m_rightChildNodeId));
197  const StatusCode thresholdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Threshold", m_threshold));
198  const StatusCode variableIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "VariableId", m_variableId));
199  const StatusCode outcomeStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Outcome", m_outcome));
200 
201  if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
202  STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
203  {
204  m_isLeaf = false;
205  m_outcome = false;
206  }
207  else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
208  {
209  m_isLeaf = true;
210  m_leftChildNodeId = std::numeric_limits<int>::max();
211  m_rightChildNodeId = std::numeric_limits<int>::max();
212  m_threshold = std::numeric_limits<double>::max();
213  m_variableId = std::numeric_limits<int>::max();
214  }
215  else
216  {
217  throw StatusCodeException(STATUS_CODE_FAILURE);
218  }
219 }
220 
221 //------------------------------------------------------------------------------------------------------------------------------------------
222 
224  m_nodeId(rhs.m_nodeId),
225  m_parentNodeId(rhs.m_parentNodeId),
226  m_leftChildNodeId(rhs.m_leftChildNodeId),
227  m_rightChildNodeId(rhs.m_rightChildNodeId),
228  m_isLeaf(rhs.m_isLeaf),
229  m_threshold(rhs.m_threshold),
230  m_variableId(rhs.m_variableId),
231  m_outcome(rhs.m_outcome)
232 {
233 }
234 
235 //------------------------------------------------------------------------------------------------------------------------------------------
236 
238 {
239  if (this != &rhs)
240  {
241  m_nodeId = rhs.m_nodeId;
245  m_isLeaf = rhs.m_isLeaf;
246  m_threshold = rhs.m_threshold;
248  m_outcome = rhs.m_outcome;
249  }
250 
251  return *this;
252 }
253 
254 //------------------------------------------------------------------------------------------------------------------------------------------
255 
257 {
258 }
259 
260 //------------------------------------------------------------------------------------------------------------------------------------------
261 //------------------------------------------------------------------------------------------------------------------------------------------
262 
263 AdaBoostDecisionTree::WeakClassifier::WeakClassifier(const TiXmlHandle *const pXmlHandle) :
264  m_weight(0.),
265  m_treeId(0)
266 {
267  for (TiXmlElement *pHeadTiXmlElement = pXmlHandle->FirstChildElement().ToElement(); pHeadTiXmlElement != NULL; pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
268  {
269  if ("TreeIndex" == pHeadTiXmlElement->ValueStr())
270  {
271  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeIndex", m_treeId));
272  }
273  else if ("TreeWeight" == pHeadTiXmlElement->ValueStr())
274  {
275  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeWeight", m_weight));
276  }
277  else if ("Node" == pHeadTiXmlElement->ValueStr())
278  {
279  const TiXmlHandle nodeHandle(pHeadTiXmlElement);
280  const Node *pNode = new Node(&nodeHandle);
281  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
282  }
283  }
284 }
285 
286 //------------------------------------------------------------------------------------------------------------------------------------------
287 
289  m_weight(rhs.m_weight),
290  m_treeId(rhs.m_treeId)
291 {
292  for (const auto &mapEntry : rhs.m_idToNodeMap)
293  {
294  const Node *pNode = new Node(*(mapEntry.second));
295  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
296  }
297 }
298 
299 //------------------------------------------------------------------------------------------------------------------------------------------
300 
302 {
303  if (this != &rhs)
304  {
305  for (const auto &mapEntry : rhs.m_idToNodeMap)
306  {
307  const Node *pNode = new Node(*(mapEntry.second));
308  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
309  }
310 
311  m_weight = rhs.m_weight;
312  m_treeId = rhs.m_treeId;
313  }
314 
315  return *this;
316 }
317 
318 //------------------------------------------------------------------------------------------------------------------------------------------
319 
321 {
322  for (const auto &mapEntry : m_idToNodeMap)
323  delete mapEntry.second;
324 }
325 
326 //------------------------------------------------------------------------------------------------------------------------------------------
327 
329 {
330  return this->EvaluateNode(0, features);
331 }
332 
333 //------------------------------------------------------------------------------------------------------------------------------------------
334 
336 {
337  const Node *pActiveNode(nullptr);
338 
339  if (m_idToNodeMap.find(nodeId) != m_idToNodeMap.end())
340  {
341  pActiveNode = m_idToNodeMap.at(nodeId);
342  }
343  else
344  {
345  throw StatusCodeException(STATUS_CODE_OUT_OF_RANGE);
346  }
347 
348  if (pActiveNode->IsLeaf())
349  return pActiveNode->GetOutcome();
350 
351  if (static_cast<int>(features.size()) <= pActiveNode->GetVariableId())
352  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
353 
354  if (features.at(pActiveNode->GetVariableId()).Get() <= pActiveNode->GetThreshold())
355  {
356  return this->EvaluateNode(pActiveNode->GetLeftChildNodeId(), features);
357  }
358  else
359  {
360  return this->EvaluateNode(pActiveNode->GetRightChildNodeId(), features);
361  }
362 }
363 
364 //------------------------------------------------------------------------------------------------------------------------------------------
365 //------------------------------------------------------------------------------------------------------------------------------------------
366 
367 AdaBoostDecisionTree::StrongClassifier::StrongClassifier(const TiXmlHandle *const pXmlHandle)
368 {
369  TiXmlElement *pCurrentXmlElement = pXmlHandle->FirstChild().Element();
370 
371  while (pCurrentXmlElement)
372  {
373  if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
374  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
375 
376  pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
377  }
378 }
379 
380 //------------------------------------------------------------------------------------------------------------------------------------------
381 
383 {
384  for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
385  m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
386 }
387 
388 //------------------------------------------------------------------------------------------------------------------------------------------
389 
391 {
392  if (this != &rhs)
393  {
394  for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
395  m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
396  }
397 
398  return *this;
399 }
400 
401 //------------------------------------------------------------------------------------------------------------------------------------------
402 
404 {
405  for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
406  delete pWeakClassifier;
407 }
408 
409 //------------------------------------------------------------------------------------------------------------------------------------------
410 
412 {
413  double score(0.), weights(0.);
414 
415  for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
416  {
417  weights += pWeakClassifier->GetWeight();
418 
419  if (pWeakClassifier->Predict(features))
420  {
421  score += pWeakClassifier->GetWeight();
422  }
423  else
424  {
425  score -= pWeakClassifier->GetWeight();
426  }
427  }
428 
429  if (weights > std::numeric_limits<double>::epsilon())
430  {
431  score /= weights;
432  }
433  else
434  {
435  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
436  }
437 
438  return score;
439 }
440 
441 //------------------------------------------------------------------------------------------------------------------------------------------
442 
443 StatusCode AdaBoostDecisionTree::StrongClassifier::ReadComponent(TiXmlElement *pCurrentXmlElement)
444 {
445  const std::string componentName(pCurrentXmlElement->ValueStr());
446  TiXmlHandle currentHandle(pCurrentXmlElement);
447 
448  if ((std::string("Name") == componentName) || (std::string("Timestamp") == componentName))
449  return STATUS_CODE_SUCCESS;
450 
451  if (std::string("DecisionTree") == componentName)
452  {
453  m_weakClassifiers.emplace_back(new WeakClassifier(&currentHandle));
454  return STATUS_CODE_SUCCESS;
455  }
456 
457  return STATUS_CODE_INVALID_PARAMETER;
458 }
459 
460 } // namespace lar_content
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
Definition: LArMvaHelper.h:58
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
WeakClassifier class containing a decision tree and a weight.
double m_threshold
Threshold used for decision if decision node.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
cout<< "Opened file "<< fin<< " ixs= "<< ixs<< endl;if(ixs==0) hhh=(TH1F *) fff-> Get("h1")
Definition: AddMC.C:8
double GetThreshold() const
Return node threshold.
Int_t max
Definition: plot.C:27
StrongClassifier class used in application of adaptive boost decision tree.
Header file for the lar adaptive boosted decision tree class.
Node & operator=(const Node &rhs)
Assignment operator.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
Node class used for representing a decision tree.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetRightChildNodeId() const
Return right child node id.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
int m_variableId
Variable cut on for decision if decision node.
bool IsLeaf() const
Return whether the node is a leaf.