Belle II Software  release-05-01-25
SelectorMVA.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2018 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Marcel Hohmann *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
#include <analysis/modules/CurlTagger/SelectorMVA.h>

#include <analysis/variables/TrackVariables.h>
#include <analysis/variables/MCTruthVariables.h>
#include <analysis/variables/Variables.h>

//Root includes
#include "TVector3.h"

//C++ includes
#include <cmath>
#include <stdexcept>
20 using namespace Belle2;
21 using namespace CurlTagger;
22 
23 SelectorMVA::SelectorMVA(bool belleFlag, bool trainFlag)
24 {
25  m_TrainFlag = trainFlag;
26 
27  if (belleFlag) {
28  m_TFileName = "CurlTagger_TrainingData_Belle.root";
29  m_identifier = "CurlTagger_FastBDT_Belle";
30  } else {
31  m_TFileName = "CurlTagger_TrainingData_BelleII.root";
32  m_identifier = "CurlTagger_FastBDT_BelleII";
33  }
34 }
35 
36 SelectorMVA::~SelectorMVA() = default;
37 
39 {
40  if (m_TrainFlag) {
41  m_IsCurl = (Variable::genParticleIndex(iPart) == Variable::genParticleIndex(jPart) ? 1 : 0);
42  }
43  m_ChargeProduct = iPart->getCharge() * jPart->getCharge();
44  m_PPhi = iPart->getMomentum().Angle(jPart->getMomentum());
45  m_PtDiffEW = abs(Variable::particlePt(iPart) - Variable::particlePt(jPart)) / (Variable::particlePtErr(
46  iPart) + Variable::particlePtErr(jPart));
47  m_PzDiffEW = abs(Variable::particlePz(iPart) - Variable::particlePz(jPart)) / (Variable::particlePzErr(
48  iPart) + Variable::particlePzErr(jPart));
49  m_TrackD0DiffEW = abs(Variable::trackD0(iPart) - Variable::trackD0(jPart)) / (Variable::trackD0Error(
50  iPart) + Variable::trackD0Error(jPart));
51  m_TrackZ0DiffEW = abs(Variable::trackZ0(iPart) - Variable::trackZ0(jPart)) / (Variable::trackZ0Error(
52  iPart) + Variable::trackZ0Error(jPart));
53  m_TrackTanLambdaDiffEW = abs(Variable::trackTanLambda(iPart) - Variable::trackTanLambda(jPart)) / (Variable::trackTanLambdaError(
54  iPart) + Variable::trackTanLambdaError(jPart));
55  m_TrackPhi0DiffEW = abs(Variable::trackPhi0(iPart) - Variable::trackPhi0(jPart)) / (Variable::trackPhi0Error(
56  iPart) + Variable::trackPhi0Error(jPart));
57  m_TrackOmegaDiffEW = abs(Variable::trackOmega(iPart) - Variable::trackOmega(jPart)) / (Variable::trackOmegaError(
58  iPart) + Variable::trackOmegaError(jPart));
59 }
60 
61 std::vector<float> SelectorMVA::getVariables(Particle* iPart, Particle* jPart)
62 {
63  updateVariables(iPart, jPart);
65 }
66 
68 {
69  updateVariables(iPart, jPart);
70  m_TTree -> Fill();
71 }
72 
74 {
75  if (m_TrainFlag) { //make training data
76  m_TFile = TFile::Open(m_TFileName.c_str(), "RECREATE");
77  m_TTree = new TTree("ntuple", "Training Data for the Curl Tagger MVA");
78 
79  m_TTree -> Branch("PPhi" , &m_PPhi, "PPhi/F");
80  m_TTree -> Branch("ChargeProduct", &m_ChargeProduct, "ChargeProduct/F");
81  m_TTree -> Branch("PtDiffEW", &m_PtDiffEW, "PtDiffEW/F");
82  m_TTree -> Branch("PzDiffEW", &m_PzDiffEW, "PzDiffEW/F");
83  m_TTree -> Branch("TrackD0DiffEW", &m_TrackD0DiffEW, "TrackD0DiffEW/F");
84  m_TTree -> Branch("TrackZ0DiffEW", &m_TrackZ0DiffEW, "TrackZ0DiffEW/F");
85  m_TTree -> Branch("TrackTanLambdaDiffEW", &m_TrackTanLambdaDiffEW, "TrackTanLambdaDiffEW/F");
86  m_TTree -> Branch("TrackPhi0DiffEW", &m_TrackPhi0DiffEW, "TrackPhi0DiffEW/F");
87  m_TTree -> Branch("TrackOmegaDiffEW", &m_TrackOmegaDiffEW, "TrackOmegaDiffEW/F");
88 
89  m_TTree -> Branch("IsCurl", &m_IsCurl, "IsCurl/F");
90 
91  m_target_variable = "IsCurl";
92  m_variables = {"PPhi", "ChargeProduct", "PtDiffEW", "PzDiffEW", "TrackD0DiffEW", "TrackZ0DiffEW", "TrackTanLambdaDiffEW", "TrackPhi0DiffEW", "TrackOmegaDiffEW"};
93 
94  } else { // normal application
95  //load MVA
97  weightfile.getOptions(m_generalOptions);
98  m_expert.load(weightfile);
99  }
100 }
101 
103 {
104  if (m_TrainFlag) {
105  m_TFile -> cd();
106  m_TTree -> Write();
107  m_TFile -> Write();
108  m_TFile -> Close();
109 
110  //train MVA
111  MVA::GeneralOptions generalOptions;
112  generalOptions.m_datafiles = {m_TFileName};
113  generalOptions.m_identifier = m_identifier;
114  generalOptions.m_variables = m_variables;
115  generalOptions.m_target_variable = m_target_variable;
116  generalOptions.m_signal_class = 1;
117  generalOptions.m_weight_variable = ""; // sets all weights to 1 if blank
118 
119  MVA::ROOTDataset dataset(generalOptions);
120 
121  MVA::FastBDTOptions specificOptions;
122  specificOptions.m_nTrees = 1000;
123  //specificOptions.m_shrinkage = 0.10;
124  specificOptions.m_nCuts = 16;
125  specificOptions. m_nLevels = 4;
126 
127  auto teacher = new MVA::FastBDTTeacher(generalOptions, specificOptions);
128  auto weightfile = teacher->train(dataset);
130  //MVA::Weightfile::saveToXMLFile(weightfile, "test.xml");
131  }
132 }
133 
135 {
136  MVA::SingleDataset dataset(m_generalOptions, getVariables(iPart, jPart));
137  return m_expert.apply(dataset)[0];
138 }
Belle2::MVA::FastBDTOptions
Options for the FastBDT MVA method.
Definition: FastBDT.h:55
Belle2::MVA::GeneralOptions::m_identifier
std::string m_identifier
Identifier containing the finished training.
Definition: Options.h:85
Belle2::CurlTagger::SelectorMVA::m_ChargeProduct
Float_t m_ChargeProduct
charge(p1) * charge(p2)
Definition: SelectorMVA.h:108
Belle2::CurlTagger::SelectorMVA::SelectorMVA
SelectorMVA(bool belleFlag, bool trainFlag)
Constructor.
Definition: SelectorMVA.cc:23
Belle2::CurlTagger::SelectorMVA::m_TFileName
std::string m_TFileName
name of output file for training data
Definition: SelectorMVA.h:74
Belle2::MVA::FastBDTTeacher
Teacher for the FastBDT MVA method.
Definition: FastBDT.h:100
Belle2::CurlTagger::SelectorMVA::initialize
virtual void initialize() override
initialize whatever needs to be initialized (ROOT file etc.)
Definition: SelectorMVA.cc:73
Belle2::MVA::Weightfile::saveToDatabase
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:267
Belle2::MVA::FastBDTExpert::apply
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FastBDT.cc:414
Belle2::MVA::FastBDTExpert::load
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FastBDT.cc:321
Belle2::MVA::GeneralOptions::m_weight_variable
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:92
Belle2::CurlTagger::SelectorMVA::m_TrackPhi0DiffEW
Float_t m_TrackPhi0DiffEW
error weighted track Phi0 difference
Definition: SelectorMVA.h:126
Belle2::CurlTagger::SelectorMVA::finalize
virtual void finalize() override
finalize whatever needs to be finalized (train the MVA)
Definition: SelectorMVA.cc:102
Belle2::CurlTagger::SelectorMVA::m_TrackD0DiffEW
Float_t m_TrackD0DiffEW
error weighted track D0 difference
Definition: SelectorMVA.h:117
Belle2::CurlTagger::SelectorMVA::~SelectorMVA
~SelectorMVA()
Destructor.
Belle2::CurlTagger::SelectorMVA::m_TrackZ0DiffEW
Float_t m_TrackZ0DiffEW
error weighted track Z0 difference
Definition: SelectorMVA.h:120
Belle2::CurlTagger::SelectorMVA::m_TFile
TFile * m_TFile
output file for training data
Definition: SelectorMVA.h:77
Belle2::Particle::getCharge
float getCharge(void) const
Returns particle charge.
Definition: Particle.cc:586
Belle2::MVA::FastBDTOptions::m_nCuts
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
Definition: FastBDT.h:81
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::CurlTagger::SelectorMVA::updateVariables
void updateVariables(Particle *iPart, Particle *jPart)
updates the values of the MVA variables
Definition: SelectorMVA.cc:38
Belle2::MVA::GeneralOptions::m_target_variable
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:91
Belle2::MVA::FastBDTOptions::m_nTrees
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:80
Belle2::MVA::GeneralOptions::m_signal_class
int m_signal_class
Signal class which is used as signal in a classification problem.
Definition: Options.h:90
Belle2::CurlTagger::SelectorMVA::m_expert
MVA::FastBDTExpert m_expert
mva expert
Definition: SelectorMVA.h:86
Belle2::CurlTagger::SelectorMVA::m_PzDiffEW
Float_t m_PzDiffEW
error weighted particle Pz difference
Definition: SelectorMVA.h:114
Belle2::CurlTagger::SelectorMVA::m_TrackOmegaDiffEW
Float_t m_TrackOmegaDiffEW
error weighted track Omega difference
Definition: SelectorMVA.h:129
Belle2::MVA::GeneralOptions::m_variables
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:88
Belle2::CurlTagger::SelectorMVA::collectTrainingInfo
virtual void collectTrainingInfo(Particle *iPart, Particle *jPart) override
collect training data and save to a root file
Definition: SelectorMVA.cc:67
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::Particle::getMomentum
TVector3 getMomentum() const
Returns momentum vector.
Definition: Particle.h:475
Belle2::MVA::SingleDataset
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:136
Belle2::CurlTagger::SelectorMVA::m_PPhi
Float_t m_PPhi
angle between particle momentum vectors
Definition: SelectorMVA.h:105
Belle2::MVA::Weightfile::loadFromDatabase
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:290
Belle2::CurlTagger::SelectorMVA::m_TrainFlag
bool m_TrainFlag
applying mva or training it
Definition: SelectorMVA.h:71
Belle2::Particle
Class to store reconstructed particles.
Definition: Particle.h:77
Belle2::MVA::GeneralOptions::m_datafiles
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
Definition: Options.h:86
Belle2::CurlTagger::SelectorMVA::m_generalOptions
MVA::GeneralOptions m_generalOptions
mva general options (for the expert)
Definition: SelectorMVA.h:83
Belle2::CurlTagger::SelectorMVA::m_identifier
std::string m_identifier
mva identifier
Definition: SelectorMVA.h:91
Belle2::CurlTagger::SelectorMVA::getResponse
virtual float getResponse(Particle *iPart, Particle *jPart) override
Selector response that this pair of particles come from the same mc/actual particle.
Definition: SelectorMVA.cc:134
Belle2::MVA::ROOTDataset
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
Definition: Dataset.h:349
Belle2::CurlTagger::SelectorMVA::m_PtDiffEW
Float_t m_PtDiffEW
error weighted particle Pt difference
Definition: SelectorMVA.h:111
Belle2::CurlTagger::SelectorMVA::m_variables
std::vector< std::string > m_variables
names of variables used by mva
Definition: SelectorMVA.h:97
Belle2::CurlTagger::SelectorMVA::m_TTree
TTree * m_TTree
training data tree
Definition: SelectorMVA.h:80
Belle2::CurlTagger::SelectorMVA::m_TrackTanLambdaDiffEW
Float_t m_TrackTanLambdaDiffEW
error weighted track tan lambda difference
Definition: SelectorMVA.h:123
Belle2::CurlTagger::SelectorMVA::getVariables
virtual std::vector< float > getVariables(Particle *iPart, Particle *jPart) override
returns vector of variables used by this selector.
Definition: SelectorMVA.cc:61
Belle2::CurlTagger::SelectorMVA::m_target_variable
std::string m_target_variable
name of target variable (isCurl)
Definition: SelectorMVA.h:100
Belle2::CurlTagger::SelectorMVA::m_IsCurl
Float_t m_IsCurl
isCurl Truth
Definition: SelectorMVA.h:132