LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
LArAdaBoostDecisionTree.h
Go to the documentation of this file.
1 
8 #ifndef LAR_ADABOOST_DECISION_TREE_H
9 #define LAR_ADABOOST_DECISION_TREE_H 1
10 
12 
14 
15 #include "Pandora/StatusCodes.h"
16 
17 #include <functional>
18 #include <map>
19 #include <vector>
20 
21 namespace lar_content
22 {
23 
28 {
29 public:
34 
41 
48 
53 
62  pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName);
63 
71  bool Classify(const LArMvaHelper::MvaFeatureVector &features) const;
72 
81 
89  double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const;
90 
91 private:
95  class Node
96  {
97  public:
103  Node(const pandora::TiXmlHandle *const pXmlHandle);
104 
110  Node(const Node &rhs);
111 
117  Node &operator=(const Node &rhs);
118 
122  ~Node();
123 
129  int GetNodeId() const;
130 
136  int GetParentNodeId() const;
137 
143  int GetLeftChildNodeId() const;
144 
150  int GetRightChildNodeId() const;
151 
157  bool IsLeaf() const;
158 
164  double GetThreshold() const;
165 
171  int GetVariableId() const;
172 
178  bool GetOutcome() const;
179 
180  private:
181  int m_nodeId;
185  bool m_isLeaf;
186  double m_threshold;
188  bool m_outcome;
189  };
190 
191  typedef std::map<int, const Node *> IdToNodeMap;
192 
197  {
198  public:
204  WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle);
205 
211  WeakClassifier(const WeakClassifier &rhs);
212 
219 
223  ~WeakClassifier();
224 
232  bool Predict(const LArMvaHelper::MvaFeatureVector &features) const;
233 
242  bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const;
243 
249  double GetWeight() const;
250 
256  int GetTreeId() const;
257 
258  private:
259  IdToNodeMap m_idToNodeMap;
260  double m_weight;
261  int m_treeId;
262  };
263 
264  typedef std::vector<const WeakClassifier *> WeakClassifiers;
265 
270  {
271  public:
277  StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle);
278 
285 
292 
296  ~StrongClassifier();
297 
305  double Predict(const LArMvaHelper::MvaFeatureVector &features) const;
306 
307  private:
311  pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement);
312 
313  WeakClassifiers m_weakClassifiers;
314  };
315 
323  double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const;
324 
326 };
327 
328 //------------------------------------------------------------------------------------------------------------------------------------------
329 
331 {
332  return m_nodeId;
333 }
334 
335 //------------------------------------------------------------------------------------------------------------------------------------------
336 
338 {
339  return m_parentNodeId;
340 }
341 
342 //------------------------------------------------------------------------------------------------------------------------------------------
343 
345 {
346  return m_leftChildNodeId;
347 }
348 
349 //------------------------------------------------------------------------------------------------------------------------------------------
350 
352 {
353  return m_rightChildNodeId;
354 }
355 
356 //------------------------------------------------------------------------------------------------------------------------------------------
357 
359 {
360  return m_isLeaf;
361 }
362 
363 //------------------------------------------------------------------------------------------------------------------------------------------
364 
366 {
367  return m_threshold;
368 }
369 
370 //------------------------------------------------------------------------------------------------------------------------------------------
371 
373 {
374  return m_variableId;
375 }
376 
377 //------------------------------------------------------------------------------------------------------------------------------------------
378 
380 {
381  return m_outcome;
382 }
383 
384 //------------------------------------------------------------------------------------------------------------------------------------------
385 
387 {
388  return m_weight;
389 }
390 
391 //------------------------------------------------------------------------------------------------------------------------------------------
392 
394 {
395  return m_treeId;
396 }
397 
398 } // namespace lar_content
399 
400 #endif // #ifndef LAR_ADABOOST_DECISION_TREE_H
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return the id of the cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
Definition: LArMvaHelper.h:75
double GetWeight() const
Get boost weight for weak classifier.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
MvaInterface class.
WeakClassifier class containing a decision tree and a weight.
double m_threshold
Threshold used for the decision, if this is a decision node.
double GetThreshold() const
Return node threshold.
StrongClassifier class used in application of adaptive boost decision tree.
Node & operator=(const Node &rhs)
Assignment operator.
int GetParentNodeId() const
Return parent node id.
std::vector< const WeakClassifier * > WeakClassifiers
Node class used for representing a decision tree.
std::map< int, const Node * > IdToNodeMap
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
int GetRightChildNodeId() const
Return right child node id.
int GetTreeId() const
Get tree id for weak classifier.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
Header file for the lar multivariate analysis interface class.
int m_variableId
Id of the variable that is cut on for the decision, if this is a decision node.
bool IsLeaf() const
Return whether the node is a leaf.