Belle II Software  release-08-01-10
SelectorMVA.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
#include <analysis/modules/CurlTagger/SelectorMVA.h>

#include <analysis/variables/TrackVariables.h>
#include <analysis/variables/MCTruthVariables.h>
#include <analysis/variables/Variables.h>

#include <algorithm>
#include <cmath>
#include <string>
#include <utility>
14 
15 using namespace Belle2;
16 using namespace CurlTagger;
17 
18 SelectorMVA::SelectorMVA(bool belleFlag, bool trainFlag, std::string tFileName)
19 {
20  m_TrainFlag = trainFlag;
21  m_TFileName = tFileName;
22  if (belleFlag) {
23  m_identifier = "CurlTagger_FastBDT_Belle";
24  } else {
25  m_identifier = "CurlTagger_FastBDT_BelleII";
26  }
27 }
28 
/**
 * Destructor (defaulted).
 * NOTE(review): m_TFile is a raw pointer obtained from TFile::Open() in
 * initialize(); this destructor does not release it.
 */
SelectorMVA::~SelectorMVA() = default;
30 
32 {
33  if (m_TrainFlag) {
34  m_IsCurl = Variable::genParticleIndex(iPart) == Variable::genParticleIndex(jPart);
35  }
36  m_ChargeProduct = iPart->getCharge() * jPart->getCharge();
37 
38  m_PPhi = acos(iPart->getMomentum().Unit().Dot(jPart->getMomentum().Unit()));
39 
40  m_PtDiffEW = abs(Variable::particlePt(iPart) - Variable::particlePt(jPart)) / sqrt(pow(Variable::particlePtErr(
41  iPart), 2) + pow(Variable::particlePtErr(jPart), 2));
42 
43  m_PzDiffEW = abs(Variable::particlePz(iPart) - Variable::particlePz(jPart)) / sqrt(pow(Variable::particlePzErr(
44  iPart), 2) + pow(Variable::particlePzErr(jPart), 2));
45 
46  m_TrackD0DiffEW = abs(Variable::trackD0(iPart) - Variable::trackD0(jPart)) / sqrt(pow(Variable::trackD0Error(
47  iPart), 2) + pow(Variable::trackD0Error(jPart), 2));
48 
49  m_TrackZ0DiffEW = abs(Variable::trackZ0(iPart) - Variable::trackZ0(jPart)) / sqrt(pow(Variable::trackZ0Error(
50  iPart), 2) + pow(Variable::trackZ0Error(jPart), 2));
51 
52  m_TrackTanLambdaDiffEW = abs(Variable::trackTanLambda(iPart) - Variable::trackTanLambda(jPart)) / sqrt(pow(
53  Variable::trackTanLambdaError(
54  iPart), 2) + pow(Variable::trackTanLambdaError(jPart), 2));
55 
56  m_TrackPhi0DiffEW = abs(Variable::trackPhi0(iPart) - Variable::trackPhi0(jPart)) / sqrt(pow(Variable::trackPhi0Error(
57  iPart), 2) + pow(Variable::trackPhi0Error(jPart), 2));
58 
59  m_TrackOmegaDiffEW = abs(Variable::trackOmega(iPart) - Variable::trackOmega(jPart)) / sqrt(pow(Variable::trackOmegaError(
60  iPart), 2) + pow(Variable::trackOmegaError(jPart), 2));
61 }
62 
63 std::vector<float> SelectorMVA::getVariables(Particle* iPart, Particle* jPart)
64 {
65  updateVariables(iPart, jPart);
69 }
70 
72 {
73  updateVariables(iPart, jPart);
74  m_TTree -> Fill();
75 }
76 
78 {
79  if (m_TrainFlag) {
80  //make training data
81  m_TFile = TFile::Open(m_TFileName.c_str(), "RECREATE");
82  m_TTree = new TTree("ntuple", "Training Data for the Curl Tagger MVA");
83 
84  m_TTree -> Branch("PPhi", &m_PPhi, "PPhi/F");
85  m_TTree -> Branch("ChargeProduct", &m_ChargeProduct, "ChargeProduct/F");
86  m_TTree -> Branch("PtDiffEW", &m_PtDiffEW, "PtDiffEW/F");
87  m_TTree -> Branch("PzDiffEW", &m_PzDiffEW, "PzDiffEW/F");
88  m_TTree -> Branch("TrackD0DiffEW", &m_TrackD0DiffEW, "TrackD0DiffEW/F");
89  m_TTree -> Branch("TrackZ0DiffEW", &m_TrackZ0DiffEW, "TrackZ0DiffEW/F");
90  m_TTree -> Branch("TrackTanLambdaDiffEW", &m_TrackTanLambdaDiffEW, "TrackTanLambdaDiffEW/F");
91  m_TTree -> Branch("TrackPhi0DiffEW", &m_TrackPhi0DiffEW, "TrackPhi0DiffEW/F");
92  m_TTree -> Branch("TrackOmegaDiffEW", &m_TrackOmegaDiffEW, "TrackOmegaDiffEW/F");
93 
94  m_TTree -> Branch("IsCurl", &m_IsCurl, "IsCurl/O");
95 
96  m_target_variable = "IsCurl";
97  m_variables = {"PPhi", "ChargeProduct", "PtDiffEW",
98  "PzDiffEW", "TrackD0DiffEW", "TrackZ0DiffEW",
99  "TrackTanLambdaDiffEW", "TrackPhi0DiffEW", "TrackOmegaDiffEW"
100  };
101 
102  } else {
103  // normal application
104  m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
105  MVA::makeSaveForDatabase(m_identifier));
106  (*m_weightfile_representation.get()).addCallback([this]() { initializeMVA();});
107  initializeMVA();
108  }
109 }
110 
112 {
113  std::stringstream ss((*m_weightfile_representation)->m_data);
117 }
118 
119 
121 {
122  if (m_TrainFlag) {
123  return 0.5;
124  }
125  std::string elementIdentfier = "optimal_cut";
126  if (!m_weightfile.containsElement(elementIdentfier)) {
127  B2FATAL("No optimal cut stored in curlTagger MVA payload!");
128  }
129  // require the default value for the compiler to deduce the template class
130  return m_weightfile.getElement(elementIdentfier, 0.5);
131 }
132 
134 {
135  if (m_TrainFlag) {
136  m_TFile -> cd();
137  m_TTree -> Write();
138  m_TFile -> Write();
139  m_TFile -> Close();
140  }
141 }
142 
144 {
145  MVA::SingleDataset dataset(m_generalOptions, getVariables(iPart, jPart));
146  return m_expert.apply(dataset)[0];
147 }
void updateVariables(Particle *iPart, Particle *jPart)
updates the values of the MVA variables
Definition: SelectorMVA.cc:31
virtual std::vector< float > getVariables(Particle *iPart, Particle *jPart) override
returns vector of variables used by this selector.
Definition: SelectorMVA.cc:63
Float_t m_TrackTanLambdaDiffEW
error weighted track tan lambda difference
Definition: SelectorMVA.h:127
virtual float getOptimalResponseCut() override
returns optimal cut to use with selector
Definition: SelectorMVA.cc:120
std::vector< std::string > m_variables
names of variables used by mva
Definition: SelectorMVA.h:101
Float_t m_TrackZ0DiffEW
error weighted track Z0 difference
Definition: SelectorMVA.h:124
virtual void initialize() override
initialize whatever needs to be initialized (root file etc)
Definition: SelectorMVA.cc:77
Float_t m_PtDiffEW
error weighted particle Pt difference
Definition: SelectorMVA.h:115
Float_t m_PPhi
angle between particle momentum vectors
Definition: SelectorMVA.h:109
Float_t m_TrackPhi0DiffEW
error weighted track Phi0 difference
Definition: SelectorMVA.h:130
TFile * m_TFile
output file for training data
Definition: SelectorMVA.h:76
Float_t m_PzDiffEW
error weighted particle Pz difference
Definition: SelectorMVA.h:118
MVA::Weightfile m_weightfile
mva weightfile
Definition: SelectorMVA.h:85
Float_t m_TrackD0DiffEW
error weighted track D0 difference
Definition: SelectorMVA.h:121
virtual void finalize() override
finalize whatever needs to be finalized (write the training data to the ROOT file)
Definition: SelectorMVA.cc:133
Float_t m_ChargeProduct
charge(p1) * charge(p2)
Definition: SelectorMVA.h:112
void initializeMVA()
initialize the MVA Expert
Definition: SelectorMVA.cc:111
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
Definition: SelectorMVA.h:82
bool m_TrainFlag
applying mva or training it
Definition: SelectorMVA.h:70
MVA::FastBDTExpert m_expert
mva expert
Definition: SelectorMVA.h:91
virtual void collectTrainingInfo(Particle *iPart, Particle *jPart) override
collect training data and save to a root file
Definition: SelectorMVA.cc:71
std::string m_TFileName
name of output file for training data
Definition: SelectorMVA.h:73
virtual float getResponse(Particle *iPart, Particle *jPart) override
Selector response that this pair of particles comes from the same MC/actual particle.
Definition: SelectorMVA.cc:143
Bool_t m_IsCurl
isCurl Truth
Definition: SelectorMVA.h:136
MVA::GeneralOptions m_generalOptions
mva general options (for the expert)
Definition: SelectorMVA.h:88
std::string m_target_variable
name of target variable (isCurl)
Definition: SelectorMVA.h:104
SelectorMVA(bool belleFlag, bool trainFlag, std::string tFileName)
Constructor.
Definition: SelectorMVA.cc:18
Float_t m_TrackOmegaDiffEW
error weighted track Omega difference
Definition: SelectorMVA.h:133
TTree * m_TTree
training data tree
Definition: SelectorMVA.h:79
std::string m_identifier
mva identifier
Definition: SelectorMVA.h:95
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FastBDT.cc:415
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FastBDT.cc:322
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:135
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:151
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
Definition: Weightfile.h:160
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
Class to store reconstructed particles.
Definition: Particle.h:75
double getCharge(void) const
Returns particle charge.
Definition: Particle.cc:626
ROOT::Math::XYZVector getMomentum() const
Returns momentum vector.
Definition: Particle.h:526
double sqrt(double a)
sqrt for double
Definition: beamHelpers.h:28
Abstract base class for different kinds of events.