LArSoft  v07_13_02
Liquid Argon Software toolkit - http://larsoft.org/
TrainMVA.C
Go to the documentation of this file.
1 #include "TFile.h"
2 #include "TTree.h"
3 #include "TMVA/Factory.h"
4 
5 #include <string>
6 #include <vector>
7 #include <map>
8 #include <iostream>
9 
10 #include "MVAPIDResult.h"
11 
12 void BuildTree(std::string inFile,std::string outFile){
13  TFile* fIn=new TFile(inFile.c_str());
14  TTree* tr=(TTree*)fIn->Get("pid/MVAPID");
15  std::vector<anab::MVAPIDResult>* mvares=0;
16  tr->SetBranchAddress("MVAResult",&mvares);
17  TFile* fOut=new TFile(outFile.c_str(),"RECREATE");
18  TTree* mvaTree = new TTree("mvaTree","mvaTree");
19 
20  float evalRatio, concentration, coreHaloRatio, conicalness;
21  float dEdxStart, dEdxEnd, dEdxEndRatio;
22  float length;
23  int isTrack, isStoppingReco;
24 
25  mvaTree->Branch("evalRatio",&evalRatio,"evalRatio/F");
26  mvaTree->Branch("concentration",&concentration,"concentration/F");
27  mvaTree->Branch("coreHaloRatio",&coreHaloRatio,"coreHaloRatio/F");
28  mvaTree->Branch("conicalness",&conicalness,"conicalness/F");
29  mvaTree->Branch("dEdxStart",&dEdxStart,"dEdxStart/F");
30  mvaTree->Branch("dEdxEnd",&dEdxEnd,"dEdxEnd/F");
31  mvaTree->Branch("dEdxEndRatio",&dEdxEndRatio,"dEdxEndRatio/F");
32  mvaTree->Branch("length",&length,"length/F");
33  mvaTree->Branch("isTrack",&isTrack,"isTrack/I");
34  mvaTree->Branch("isStoppingReco",&isStoppingReco,"isStoppingReco/I");
35 
36  for(int iEntry=0;iEntry<tr->GetEntries();++iEntry){
37  tr->GetEntry(iEntry);
38  if(!mvares->size()) continue;
39 
40  anab::MVAPIDResult* biggestTrack=&((*mvares)[0]);
41  for(unsigned int iRes=0;iRes!=mvares->size();++iRes){
42  if((((*mvares)[iRes]).nSpacePoints)>(biggestTrack->nSpacePoints)){
43  biggestTrack=&((*mvares)[iRes]);
44  }
45  }
46 
47  evalRatio=biggestTrack->evalRatio;
48  concentration=biggestTrack->concentration;
49  coreHaloRatio=biggestTrack->coreHaloRatio;
50  conicalness=biggestTrack->conicalness;
51  dEdxStart=biggestTrack->dEdxStart;
52  dEdxEnd=biggestTrack->dEdxEnd;
53  dEdxEndRatio=biggestTrack->dEdxEndRatio;
54  length=biggestTrack->length;
55  isTrack=biggestTrack->isTrack;
56  isStoppingReco=biggestTrack->isStoppingReco;
57 
58  mvaTree->Fill();
59  }
60 
61  fIn->Close();
62  fOut->Write();
63  fOut->Close();
64 }
65 
66 void TrainMVA(std::vector<std::string> signalFiles,std::vector<std::string> backgroundFiles,std::string outputFile,std::string jobName){
67 
68  TFile* fOut = new TFile(outputFile.c_str(),"RECREATE");
69  TMVA::Factory* factory = new TMVA::Factory( jobName.c_str(), fOut, "" );
70 
71  std::vector<TTree*> sigTrees;
72 
73  for(std::vector<std::string>::iterator fIter=signalFiles.begin();fIter!=signalFiles.end();++fIter){
74  TFile* fIn=new TFile(fIter->c_str());
75  factory->AddSignalTree((TTree*)fIn->Get("mvaTree"));
76  }
77 
78  for(std::vector<std::string>::iterator fIter=backgroundFiles.begin();fIter!=backgroundFiles.end();++fIter){
79  TFile* fIn=new TFile(fIter->c_str());
80  factory->AddBackgroundTree((TTree*)fIn->Get("mvaTree"));
81  }
82 
83  factory->AddVariable("evalRatio",'F');
84  factory->AddVariable("concentration",'F');
85  factory->AddVariable("coreHaloRatio",'F');
86  factory->AddVariable("conicalness",'F');
87  factory->AddVariable("dEdxStart",'F');
88  factory->AddVariable("dEdxEnd",'F');
89  factory->AddVariable("dEdxEndRatio",'F');
90 
91  factory->BookMethod( TMVA::Types::kTMlpANN, "ANN", "" );
92  factory->BookMethod( TMVA::Types::kBDT, "BDT", "" );
93  factory->TrainAllMethods();
94  factory->TestAllMethods();
95  factory->EvaluateAllMethods();
96  fOut->Write();
97  fOut->Close();
98 }
99 
100 
101 void PrintRes(std::string inFile){
102 
103  TFile* fIn=new TFile(inFile.c_str());
104  TTree* tr=(TTree*)fIn->Get("pid/ANAB");
105  std::vector<anab::MVAPIDResult>* mvares=0;
106  tr->SetBranchAddress("MVAResult",&mvares);
107 
108  for(int iEntry=0;iEntry<tr->GetEntries();++iEntry){
109  tr->GetEntry(iEntry);
110  if(!mvares->size()) continue;
111 
112  anab::MVAPIDResult* biggestTrack=&((*mvares)[0]);
113  for(unsigned int iRes=0;iRes!=mvares->size();++iRes){
114  if((((*mvares)[iRes]).nSpacePoints)>(biggestTrack->nSpacePoints)){
115  biggestTrack=&((*mvares)[iRes]);
116  }
117  }
118 
119  std::cout<<biggestTrack->mvaOutput.at(string("ANN"))<<std::endl;
120  }
121 
122  fIn->Close();
123 }
124  /*
125  std::vector<std::string> methods;
126  methods.push_back("TMlp_ANN");
127  methods.push_back("BDT");
128  //methods.push_back("Cuts");
129  //methods.push_back("kNN");
130  //methods.push_back("Likelihood");
131  gStyle->SetOptStat("0000");
132 
133  TFile fOut(outFileName.c_str(),"RECREATE");
134 
135  TSystemDirectory dir("./","./");
136  TIter nextFile(dir.GetListOfFiles());
137  TSystemFile* file;
138  while(file = (TSystemFile*)nextFile.Next()){
139  std::string fName=file->GetName();
140  std::cout<<fName<<std::endl;
141  if(fName.find("mva_")!=std::string::npos){
142  fOut.cd();
143  std::string caName=fName.substr(0,fName.size()-5);
144  TCanvas* ca=new TCanvas(caName.c_str(),caName.c_str());
145  TLegend* le=new TLegend(0.2,0.1,0.4,0.3);
146  TFile f(fName.c_str());
147  int iMethod=0;
148  for(std::vector<std::string>::iterator mIter=methods.begin();mIter!=methods.end();++mIter){
149  std::string methodFolder;
150  if(mIter->find("ANN")!=std::string::npos){
151  methodFolder="TMlpANN";
152  }
153  else{
154  methodFolder=*mIter;
155  }
156  std::cout<<(std::string("Method_")+methodFolder+"/"+*mIter+"/MVA_"+*mIter+"_rejBvsS").c_str()<<std::endl;
157  TH1D* h=(TH1D*)f.Get((std::string("Method_")+methodFolder+"/"+*mIter+"/MVA_"+*mIter+"_rejBvsS").c_str());
158  le->AddEntry(h);
159  if(iMethod==0){
160  h->SetTitle(caName.c_str());
161  h->SetLineColor(kRed);
162  h->Draw();
163  h->GetYaxis().SetRangeUser(0,1.2);
164  firstMethod=false;
165  }
166  else{
167  h->SetLineColor(iMethod==1?kBlue:kGreen);
168  h->Draw("same");
169  }
170  ++iMethod;
171  }
172  le->Draw();
173  fOut.cd();
174  std::cout<<caName.c_str()<<std::endl;
175  ca->Write(caName.c_str());
176  ca->SaveAs((std::string("./plots/")+caName+".png").c_str());
177  }
178  }
179  fOut.Close();
180 }
181 */
intermediate_table::iterator iterator
void BuildTree(std::string inFile, std::string outFile)
Definition: TrainMVA.C:12
void PrintRes(std::string inFile)
Definition: TrainMVA.C:101
void TrainMVA(std::vector< std::string > signalFiles, std::vector< std::string > backgroundFiles, std::string outputFile, std::string jobName)
Definition: TrainMVA.C:66
std::map< std::string, double > mvaOutput
Definition: MVAPIDResult.h:27