Belle II Software  release-06-01-15
SelectorMVA.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
#include <analysis/modules/CurlTagger/SelectorMVA.h>

#include <analysis/variables/TrackVariables.h>
#include <analysis/variables/MCTruthVariables.h>
#include <analysis/variables/Variables.h>

//Root includes
#include "TVector3.h"

#include <cmath>
17 
18 using namespace Belle2;
19 using namespace CurlTagger;
20 
21 SelectorMVA::SelectorMVA(bool belleFlag, bool trainFlag)
22 {
23  m_TrainFlag = trainFlag;
24 
25  if (belleFlag) {
26  m_TFileName = "CurlTagger_TrainingData_Belle.root";
27  m_identifier = "CurlTagger_FastBDT_Belle";
28  } else {
29  m_TFileName = "CurlTagger_TrainingData_BelleII.root";
30  m_identifier = "CurlTagger_FastBDT_BelleII";
31  }
32 }
33 
/**
 * Destructor. Defaulted, so no explicit cleanup is performed here.
 * NOTE(review): m_TFile and m_TTree are raw pointers and are not released by this
 * destructor - confirm that ownership is handled elsewhere.
 */
34 SelectorMVA::~SelectorMVA() = default;
35 
37 {
38  if (m_TrainFlag) {
39  m_IsCurl = (Variable::genParticleIndex(iPart) == Variable::genParticleIndex(jPart) ? 1 : 0);
40  }
41  m_ChargeProduct = iPart->getCharge() * jPart->getCharge();
42  m_PPhi = iPart->getMomentum().Angle(jPart->getMomentum());
43  m_PtDiffEW = abs(Variable::particlePt(iPart) - Variable::particlePt(jPart)) / (Variable::particlePtErr(
44  iPart) + Variable::particlePtErr(jPart));
45  m_PzDiffEW = abs(Variable::particlePz(iPart) - Variable::particlePz(jPart)) / (Variable::particlePzErr(
46  iPart) + Variable::particlePzErr(jPart));
47  m_TrackD0DiffEW = abs(Variable::trackD0(iPart) - Variable::trackD0(jPart)) / (Variable::trackD0Error(
48  iPart) + Variable::trackD0Error(jPart));
49  m_TrackZ0DiffEW = abs(Variable::trackZ0(iPart) - Variable::trackZ0(jPart)) / (Variable::trackZ0Error(
50  iPart) + Variable::trackZ0Error(jPart));
51  m_TrackTanLambdaDiffEW = abs(Variable::trackTanLambda(iPart) - Variable::trackTanLambda(jPart)) / (Variable::trackTanLambdaError(
52  iPart) + Variable::trackTanLambdaError(jPart));
53  m_TrackPhi0DiffEW = abs(Variable::trackPhi0(iPart) - Variable::trackPhi0(jPart)) / (Variable::trackPhi0Error(
54  iPart) + Variable::trackPhi0Error(jPart));
55  m_TrackOmegaDiffEW = abs(Variable::trackOmega(iPart) - Variable::trackOmega(jPart)) / (Variable::trackOmegaError(
56  iPart) + Variable::trackOmegaError(jPart));
57 }
58 
59 std::vector<float> SelectorMVA::getVariables(Particle* iPart, Particle* jPart)
60 {
61  updateVariables(iPart, jPart);
63 }
64 
66 {
67  updateVariables(iPart, jPart);
68  m_TTree -> Fill();
69 }
70 
72 {
73  if (m_TrainFlag) { //make training data
74  m_TFile = TFile::Open(m_TFileName.c_str(), "RECREATE");
75  m_TTree = new TTree("ntuple", "Training Data for the Curl Tagger MVA");
76 
77  m_TTree -> Branch("PPhi" , &m_PPhi, "PPhi/F");
78  m_TTree -> Branch("ChargeProduct", &m_ChargeProduct, "ChargeProduct/F");
79  m_TTree -> Branch("PtDiffEW", &m_PtDiffEW, "PtDiffEW/F");
80  m_TTree -> Branch("PzDiffEW", &m_PzDiffEW, "PzDiffEW/F");
81  m_TTree -> Branch("TrackD0DiffEW", &m_TrackD0DiffEW, "TrackD0DiffEW/F");
82  m_TTree -> Branch("TrackZ0DiffEW", &m_TrackZ0DiffEW, "TrackZ0DiffEW/F");
83  m_TTree -> Branch("TrackTanLambdaDiffEW", &m_TrackTanLambdaDiffEW, "TrackTanLambdaDiffEW/F");
84  m_TTree -> Branch("TrackPhi0DiffEW", &m_TrackPhi0DiffEW, "TrackPhi0DiffEW/F");
85  m_TTree -> Branch("TrackOmegaDiffEW", &m_TrackOmegaDiffEW, "TrackOmegaDiffEW/F");
86 
87  m_TTree -> Branch("IsCurl", &m_IsCurl, "IsCurl/F");
88 
89  m_target_variable = "IsCurl";
90  m_variables = {"PPhi", "ChargeProduct", "PtDiffEW", "PzDiffEW", "TrackD0DiffEW", "TrackZ0DiffEW", "TrackTanLambdaDiffEW", "TrackPhi0DiffEW", "TrackOmegaDiffEW"};
91 
92  } else { // normal application
93  //load MVA
95  weightfile.getOptions(m_generalOptions);
96  m_expert.load(weightfile);
97  }
98 }
99 
101 {
102  if (m_TrainFlag) {
103  m_TFile -> cd();
104  m_TTree -> Write();
105  m_TFile -> Write();
106  m_TFile -> Close();
107 
108  //train MVA
109  MVA::GeneralOptions generalOptions;
110  generalOptions.m_datafiles = {m_TFileName};
111  generalOptions.m_identifier = m_identifier;
112  generalOptions.m_variables = m_variables;
113  generalOptions.m_target_variable = m_target_variable;
114  generalOptions.m_signal_class = 1;
115  generalOptions.m_weight_variable = ""; // sets all weights to 1 if blank
116 
117  MVA::ROOTDataset dataset(generalOptions);
118 
119  MVA::FastBDTOptions specificOptions;
120  specificOptions.m_nTrees = 1000;
121  //specificOptions.m_shrinkage = 0.10;
122  specificOptions.m_nCuts = 16;
123  specificOptions. m_nLevels = 4;
124 
125  auto teacher = new MVA::FastBDTTeacher(generalOptions, specificOptions);
126  auto weightfile = teacher->train(dataset);
128  //MVA::Weightfile::saveToXMLFile(weightfile, "test.xml");
129  }
130 }
131 
133 {
134  MVA::SingleDataset dataset(m_generalOptions, getVariables(iPart, jPart));
135  return m_expert.apply(dataset)[0];
136 }
void updateVariables(Particle *iPart, Particle *jPart)
updates the value of the MVA variable
Definition: SelectorMVA.cc:36
virtual std::vector< float > getVariables(Particle *iPart, Particle *jPart) override
returns vector of variables used by this selector.
Definition: SelectorMVA.cc:59
Float_t m_TrackTanLambdaDiffEW
error weighted track tan lambda difference
Definition: SelectorMVA.h:113
std::vector< std::string > m_variables
names of variables used by mva
Definition: SelectorMVA.h:87
Float_t m_TrackZ0DiffEW
error weighted track Z0 difference
Definition: SelectorMVA.h:110
virtual void initialize() override
initialize whatever needs to be initialized (root file etc)
Definition: SelectorMVA.cc:71
Float_t m_IsCurl
isCurl Truth
Definition: SelectorMVA.h:122
Float_t m_PtDiffEW
error weighted particle Pt difference
Definition: SelectorMVA.h:101
Float_t m_PPhi
angle between particle momentum vectors
Definition: SelectorMVA.h:95
Float_t m_TrackPhi0DiffEW
error weighted track Phi0 difference
Definition: SelectorMVA.h:116
TFile * m_TFile
output file for training data
Definition: SelectorMVA.h:67
Float_t m_PzDiffEW
error weighted particle Pz difference
Definition: SelectorMVA.h:104
Float_t m_TrackD0DiffEW
error weighted track D0 difference
Definition: SelectorMVA.h:107
SelectorMVA(bool belleFlag, bool trainFlag)
Constructor.
Definition: SelectorMVA.cc:21
virtual void finalize() override
finalize whatever needs to be finalized (train the MVA)
Definition: SelectorMVA.cc:100
Float_t m_ChargeProduct
charge(p1) * charge(p2)
Definition: SelectorMVA.h:98
bool m_TrainFlag
applying mva or training it
Definition: SelectorMVA.h:61
MVA::FastBDTExpert m_expert
mva expert
Definition: SelectorMVA.h:76
virtual void collectTrainingInfo(Particle *iPart, Particle *jPart) override
collect training data and save to a root file
Definition: SelectorMVA.cc:65
std::string m_TFileName
name of output file for training data
Definition: SelectorMVA.h:64
virtual float getResponse(Particle *iPart, Particle *jPart) override
Selector response that this pair of particles come from the same mc/actual particle.
Definition: SelectorMVA.cc:132
MVA::GeneralOptions m_generalOptions
mva general options (for the expert)
Definition: SelectorMVA.h:73
std::string m_target_variable
name of target variable (isCurl)
Definition: SelectorMVA.h:90
Float_t m_TrackOmegaDiffEW
error weighted track Omega difference
Definition: SelectorMVA.h:119
TTree * m_TTree
training data tree
Definition: SelectorMVA.h:70
std::string m_identifier
mva identifier
Definition: SelectorMVA.h:81
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FastBDT.cc:413
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FastBDT.cc:320
Options for the FastBDT MVA method.
Definition: FastBDT.h:53
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
Definition: FastBDT.h:79
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:78
Teacher for the FastBDT MVA method.
Definition: FastBDT.h:98
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
Definition: Options.h:84
int m_signal_class
Signal class which is used as signal in a classification problem.
Definition: Options.h:88
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:90
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:89
std::string m_identifier
Identifier containing the finished training.
Definition: Options.h:83
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
Definition: Dataset.h:347
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:133
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:280
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:257
Class to store reconstructed particles.
Definition: Particle.h:74
TVector3 getMomentum() const
Returns momentum vector.
Definition: Particle.h:488
float getCharge(void) const
Returns particle charge.
Definition: Particle.cc:630
Abstract base class for different kinds of events.