LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
TrainMVA.C File Reference
#include "TFile.h"
#include "TMVA/Factory.h"
#include "TTree.h"
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "MVAPIDResult.h"

Go to the source code of this file.

Functions

void BuildTree (std::string inFile, std::string outFile)
 
void TrainMVA (std::vector< std::string > signalFiles, std::vector< std::string > backgroundFiles, std::string outputFile, std::string jobName)
 
void PrintRes (std::string inFile)
 

Function Documentation

void BuildTree ( std::string  inFile,
std::string  outFile 
)

Definition at line 12 of file TrainMVA.C.

References anab::MVAPIDResult::concentration, anab::MVAPIDResult::conicalness, anab::MVAPIDResult::coreHaloRatio, anab::MVAPIDResult::dEdxEnd, anab::MVAPIDResult::dEdxEndRatio, anab::MVAPIDResult::dEdxStart, anab::MVAPIDResult::evalRatio, anab::MVAPIDResult::isStoppingReco, anab::MVAPIDResult::isTrack, anab::MVAPIDResult::length, and anab::MVAPIDResult::nSpacePoints.

13 {
14  TFile* fIn = new TFile(inFile.c_str());
15  TTree* tr = (TTree*)fIn->Get("pid/MVAPID");
16  std::vector<anab::MVAPIDResult>* mvares = 0;
17  tr->SetBranchAddress("MVAResult", &mvares);
18  TFile* fOut = new TFile(outFile.c_str(), "RECREATE");
19  TTree* mvaTree = new TTree("mvaTree", "mvaTree");
20 
21  float evalRatio, concentration, coreHaloRatio, conicalness;
22  float dEdxStart, dEdxEnd, dEdxEndRatio;
23  float length;
24  int isTrack, isStoppingReco;
25 
26  mvaTree->Branch("evalRatio", &evalRatio, "evalRatio/F");
27  mvaTree->Branch("concentration", &concentration, "concentration/F");
28  mvaTree->Branch("coreHaloRatio", &coreHaloRatio, "coreHaloRatio/F");
29  mvaTree->Branch("conicalness", &conicalness, "conicalness/F");
30  mvaTree->Branch("dEdxStart", &dEdxStart, "dEdxStart/F");
31  mvaTree->Branch("dEdxEnd", &dEdxEnd, "dEdxEnd/F");
32  mvaTree->Branch("dEdxEndRatio", &dEdxEndRatio, "dEdxEndRatio/F");
33  mvaTree->Branch("length", &length, "length/F");
34  mvaTree->Branch("isTrack", &isTrack, "isTrack/I");
35  mvaTree->Branch("isStoppingReco", &isStoppingReco, "isStoppingReco/I");
36 
37  for (int iEntry = 0; iEntry < tr->GetEntries(); ++iEntry) {
38  tr->GetEntry(iEntry);
39  if (!mvares->size()) continue;
40 
41  anab::MVAPIDResult* biggestTrack = &((*mvares)[0]);
42  for (unsigned int iRes = 0; iRes != mvares->size(); ++iRes) {
43  if ((((*mvares)[iRes]).nSpacePoints) > (biggestTrack->nSpacePoints)) {
44  biggestTrack = &((*mvares)[iRes]);
45  }
46  }
47 
48  evalRatio = biggestTrack->evalRatio;
49  concentration = biggestTrack->concentration;
50  coreHaloRatio = biggestTrack->coreHaloRatio;
51  conicalness = biggestTrack->conicalness;
52  dEdxStart = biggestTrack->dEdxStart;
53  dEdxEnd = biggestTrack->dEdxEnd;
54  dEdxEndRatio = biggestTrack->dEdxEndRatio;
55  length = biggestTrack->length;
56  isTrack = biggestTrack->isTrack;
57  isStoppingReco = biggestTrack->isStoppingReco;
58 
59  mvaTree->Fill();
60  }
61 
62  fIn->Close();
63  fOut->Write();
64  fOut->Close();
65 }
void PrintRes ( std::string  inFile)

Definition at line 108 of file TrainMVA.C.

References anab::MVAPIDResult::mvaOutput, and anab::MVAPIDResult::nSpacePoints.

109 {
110 
111  TFile* fIn = new TFile(inFile.c_str());
112  TTree* tr = (TTree*)fIn->Get("pid/ANAB");
113  std::vector<anab::MVAPIDResult>* mvares = 0;
114  tr->SetBranchAddress("MVAResult", &mvares);
115 
116  for (int iEntry = 0; iEntry < tr->GetEntries(); ++iEntry) {
117  tr->GetEntry(iEntry);
118  if (!mvares->size()) continue;
119 
120  anab::MVAPIDResult* biggestTrack = &((*mvares)[0]);
121  for (unsigned int iRes = 0; iRes != mvares->size(); ++iRes) {
122  if ((((*mvares)[iRes]).nSpacePoints) > (biggestTrack->nSpacePoints)) {
123  biggestTrack = &((*mvares)[iRes]);
124  }
125  }
126 
127  std::cout << biggestTrack->mvaOutput.at(string("ANN")) << std::endl;
128  }
129 
130  fIn->Close();
131 }
anab::MVAPIDResult::mvaOutput — type std::map<std::string, double> (definition: MVAPIDResult.h:27).
void TrainMVA ( std::vector< std::string >  signalFiles,
std::vector< std::string >  backgroundFiles,
std::string  outputFile,
std::string  jobName 
)

Definition at line 67 of file TrainMVA.C.

71 {
72 
73  TFile* fOut = new TFile(outputFile.c_str(), "RECREATE");
74  TMVA::Factory* factory = new TMVA::Factory(jobName.c_str(), fOut, "");
75 
76  std::vector<TTree*> sigTrees;
77 
78  for (std::vector<std::string>::iterator fIter = signalFiles.begin(); fIter != signalFiles.end();
79  ++fIter) {
80  TFile* fIn = new TFile(fIter->c_str());
81  factory->AddSignalTree((TTree*)fIn->Get("mvaTree"));
82  }
83 
84  for (std::vector<std::string>::iterator fIter = backgroundFiles.begin();
85  fIter != backgroundFiles.end();
86  ++fIter) {
87  TFile* fIn = new TFile(fIter->c_str());
88  factory->AddBackgroundTree((TTree*)fIn->Get("mvaTree"));
89  }
90 
91  factory->AddVariable("evalRatio", 'F');
92  factory->AddVariable("concentration", 'F');
93  factory->AddVariable("coreHaloRatio", 'F');
94  factory->AddVariable("conicalness", 'F');
95  factory->AddVariable("dEdxStart", 'F');
96  factory->AddVariable("dEdxEnd", 'F');
97  factory->AddVariable("dEdxEndRatio", 'F');
98 
99  factory->BookMethod(TMVA::Types::kTMlpANN, "ANN", "");
100  factory->BookMethod(TMVA::Types::kBDT, "BDT", "");
101  factory->TrainAllMethods();
102  factory->TestAllMethods();
103  factory->EvaluateAllMethods();
104  fOut->Write();
105  fOut->Close();
106 }
intermediate_table::iterator iterator