LArSoft  v09_90_00
Liquid Argon Software toolkit - https://larsoft.org/
TrainMVA.C
Go to the documentation of this file.
#include "TFile.h"
#include "TMVA/Factory.h"
#include "TTree.h"

#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "MVAPIDResult.h"
11 
12 void BuildTree(std::string inFile, std::string outFile)
13 {
14  TFile* fIn = new TFile(inFile.c_str());
15  TTree* tr = (TTree*)fIn->Get("pid/MVAPID");
16  std::vector<anab::MVAPIDResult>* mvares = 0;
17  tr->SetBranchAddress("MVAResult", &mvares);
18  TFile* fOut = new TFile(outFile.c_str(), "RECREATE");
19  TTree* mvaTree = new TTree("mvaTree", "mvaTree");
20 
21  float evalRatio, concentration, coreHaloRatio, conicalness;
22  float dEdxStart, dEdxEnd, dEdxEndRatio;
23  float length;
24  int isTrack, isStoppingReco;
25 
26  mvaTree->Branch("evalRatio", &evalRatio, "evalRatio/F");
27  mvaTree->Branch("concentration", &concentration, "concentration/F");
28  mvaTree->Branch("coreHaloRatio", &coreHaloRatio, "coreHaloRatio/F");
29  mvaTree->Branch("conicalness", &conicalness, "conicalness/F");
30  mvaTree->Branch("dEdxStart", &dEdxStart, "dEdxStart/F");
31  mvaTree->Branch("dEdxEnd", &dEdxEnd, "dEdxEnd/F");
32  mvaTree->Branch("dEdxEndRatio", &dEdxEndRatio, "dEdxEndRatio/F");
33  mvaTree->Branch("length", &length, "length/F");
34  mvaTree->Branch("isTrack", &isTrack, "isTrack/I");
35  mvaTree->Branch("isStoppingReco", &isStoppingReco, "isStoppingReco/I");
36 
37  for (int iEntry = 0; iEntry < tr->GetEntries(); ++iEntry) {
38  tr->GetEntry(iEntry);
39  if (!mvares->size()) continue;
40 
41  anab::MVAPIDResult* biggestTrack = &((*mvares)[0]);
42  for (unsigned int iRes = 0; iRes != mvares->size(); ++iRes) {
43  if ((((*mvares)[iRes]).nSpacePoints) > (biggestTrack->nSpacePoints)) {
44  biggestTrack = &((*mvares)[iRes]);
45  }
46  }
47 
48  evalRatio = biggestTrack->evalRatio;
49  concentration = biggestTrack->concentration;
50  coreHaloRatio = biggestTrack->coreHaloRatio;
51  conicalness = biggestTrack->conicalness;
52  dEdxStart = biggestTrack->dEdxStart;
53  dEdxEnd = biggestTrack->dEdxEnd;
54  dEdxEndRatio = biggestTrack->dEdxEndRatio;
55  length = biggestTrack->length;
56  isTrack = biggestTrack->isTrack;
57  isStoppingReco = biggestTrack->isStoppingReco;
58 
59  mvaTree->Fill();
60  }
61 
62  fIn->Close();
63  fOut->Write();
64  fOut->Close();
65 }
66 
67 void TrainMVA(std::vector<std::string> signalFiles,
68  std::vector<std::string> backgroundFiles,
69  std::string outputFile,
70  std::string jobName)
71 {
72 
73  TFile* fOut = new TFile(outputFile.c_str(), "RECREATE");
74  TMVA::Factory* factory = new TMVA::Factory(jobName.c_str(), fOut, "");
75 
76  std::vector<TTree*> sigTrees;
77 
78  for (std::vector<std::string>::iterator fIter = signalFiles.begin(); fIter != signalFiles.end();
79  ++fIter) {
80  TFile* fIn = new TFile(fIter->c_str());
81  factory->AddSignalTree((TTree*)fIn->Get("mvaTree"));
82  }
83 
84  for (std::vector<std::string>::iterator fIter = backgroundFiles.begin();
85  fIter != backgroundFiles.end();
86  ++fIter) {
87  TFile* fIn = new TFile(fIter->c_str());
88  factory->AddBackgroundTree((TTree*)fIn->Get("mvaTree"));
89  }
90 
91  factory->AddVariable("evalRatio", 'F');
92  factory->AddVariable("concentration", 'F');
93  factory->AddVariable("coreHaloRatio", 'F');
94  factory->AddVariable("conicalness", 'F');
95  factory->AddVariable("dEdxStart", 'F');
96  factory->AddVariable("dEdxEnd", 'F');
97  factory->AddVariable("dEdxEndRatio", 'F');
98 
99  factory->BookMethod(TMVA::Types::kTMlpANN, "ANN", "");
100  factory->BookMethod(TMVA::Types::kBDT, "BDT", "");
101  factory->TrainAllMethods();
102  factory->TestAllMethods();
103  factory->EvaluateAllMethods();
104  fOut->Write();
105  fOut->Close();
106 }
107 
108 void PrintRes(std::string inFile)
109 {
110 
111  TFile* fIn = new TFile(inFile.c_str());
112  TTree* tr = (TTree*)fIn->Get("pid/ANAB");
113  std::vector<anab::MVAPIDResult>* mvares = 0;
114  tr->SetBranchAddress("MVAResult", &mvares);
115 
116  for (int iEntry = 0; iEntry < tr->GetEntries(); ++iEntry) {
117  tr->GetEntry(iEntry);
118  if (!mvares->size()) continue;
119 
120  anab::MVAPIDResult* biggestTrack = &((*mvares)[0]);
121  for (unsigned int iRes = 0; iRes != mvares->size(); ++iRes) {
122  if ((((*mvares)[iRes]).nSpacePoints) > (biggestTrack->nSpacePoints)) {
123  biggestTrack = &((*mvares)[iRes]);
124  }
125  }
126 
127  std::cout << biggestTrack->mvaOutput.at(string("ANN")) << std::endl;
128  }
129 
130  fIn->Close();
131 }
132 /*
133  std::vector<std::string> methods;
134  methods.push_back("TMlp_ANN");
135  methods.push_back("BDT");
136  //methods.push_back("Cuts");
137  //methods.push_back("kNN");
138  //methods.push_back("Likelihood");
139  gStyle->SetOptStat("0000");
140 
141  TFile fOut(outFileName.c_str(),"RECREATE");
142 
143  TSystemDirectory dir("./","./");
144  TIter nextFile(dir.GetListOfFiles());
145  TSystemFile* file;
146  while(file = (TSystemFile*)nextFile.Next()){
147  std::string fName=file->GetName();
148  std::cout<<fName<<std::endl;
149  if(fName.find("mva_")!=std::string::npos){
150  fOut.cd();
151  std::string caName=fName.substr(0,fName.size()-5);
152  TCanvas* ca=new TCanvas(caName.c_str(),caName.c_str());
153  TLegend* le=new TLegend(0.2,0.1,0.4,0.3);
154  TFile f(fName.c_str());
155  int iMethod=0;
156  for(std::vector<std::string>::iterator mIter=methods.begin();mIter!=methods.end();++mIter){
157  std::string methodFolder;
158  if(mIter->find("ANN")!=std::string::npos){
159  methodFolder="TMlpANN";
160  }
161  else{
162  methodFolder=*mIter;
163  }
164  std::cout<<(std::string("Method_")+methodFolder+"/"+*mIter+"/MVA_"+*mIter+"_rejBvsS").c_str()<<std::endl;
165  TH1D* h=(TH1D*)f.Get((std::string("Method_")+methodFolder+"/"+*mIter+"/MVA_"+*mIter+"_rejBvsS").c_str());
166  le->AddEntry(h);
167  if(iMethod==0){
168  h->SetTitle(caName.c_str());
169  h->SetLineColor(kRed);
170  h->Draw();
171  h->GetYaxis().SetRangeUser(0,1.2);
172  firstMethod=false;
173  }
174  else{
175  h->SetLineColor(iMethod==1?kBlue:kGreen);
176  h->Draw("same");
177  }
178  ++iMethod;
179  }
180  le->Draw();
181  fOut.cd();
182  std::cout<<caName.c_str()<<std::endl;
183  ca->Write(caName.c_str());
184  ca->SaveAs((std::string("./plots/")+caName+".png").c_str());
185  }
186  }
187  fOut.Close();
188 }
189 */
intermediate_table::iterator iterator
std::map< std::string, double > mvaOutput
Definition: MVAPIDResult.h:27
void BuildTree(std::string inFile, std::string outFile)
Definition: TrainMVA.C:12
void PrintRes(std::string inFile)
Definition: TrainMVA.C:108
void TrainMVA(std::vector< std::string > signalFiles, std::vector< std::string > backgroundFiles, std::string outputFile, std::string jobName)
Definition: TrainMVA.C:67