Belle II Software development
SelectorMVA.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <analysis/modules/CurlTagger/SelectorMVA.h>
10
11#include <analysis/dataobjects/Particle.h>
12
13#include <analysis/variables/TrackVariables.h>
14#include <analysis/variables/MCTruthVariables.h>
15#include <analysis/variables/Variables.h>
16
17#include <cmath>
18
19using namespace Belle2;
20using namespace CurlTagger;
21
22SelectorMVA::SelectorMVA(bool belleFlag, bool trainFlag, std::string tFileName)
23{
24 m_TrainFlag = trainFlag;
25 m_TFileName = tFileName;
26 if (belleFlag) {
27 m_identifier = "CurlTagger_FastBDT_Belle";
28 } else {
29 m_identifier = "CurlTagger_FastBDT_BelleII";
30 }
31}
32
34
36{
37 if (m_TrainFlag) {
38 m_IsCurl = Variable::genParticleIndex(iPart) == Variable::genParticleIndex(jPart);
39 }
40 m_ChargeProduct = iPart->getCharge() * jPart->getCharge();
41
42 m_PPhi = acos(iPart->getMomentum().Unit().Dot(jPart->getMomentum().Unit()));
43
44 m_PtDiffEW = std::abs(Variable::particlePt(iPart) - Variable::particlePt(jPart)) / sqrt(pow(Variable::particlePtErr(
45 iPart), 2) + pow(Variable::particlePtErr(jPart), 2));
46
47 m_PzDiffEW = std::abs(Variable::particlePz(iPart) - Variable::particlePz(jPart)) / sqrt(pow(Variable::particlePzErr(
48 iPart), 2) + pow(Variable::particlePzErr(jPart), 2));
49
50 m_TrackD0DiffEW = std::abs(Variable::trackD0(iPart) - Variable::trackD0(jPart)) / sqrt(pow(Variable::trackD0Error(
51 iPart), 2) + pow(Variable::trackD0Error(jPart), 2));
52
53 m_TrackZ0DiffEW = std::abs(Variable::trackZ0(iPart) - Variable::trackZ0(jPart)) / sqrt(pow(Variable::trackZ0Error(
54 iPart), 2) + pow(Variable::trackZ0Error(jPart), 2));
55
56 m_TrackTanLambdaDiffEW = std::abs(Variable::trackTanLambda(iPart) - Variable::trackTanLambda(jPart)) / sqrt(pow(
57 Variable::trackTanLambdaError(
58 iPart), 2) + pow(Variable::trackTanLambdaError(jPart), 2));
59
60 m_TrackPhi0DiffEW = std::abs(Variable::trackPhi0(iPart) - Variable::trackPhi0(jPart)) / sqrt(pow(Variable::trackPhi0Error(
61 iPart), 2) + pow(Variable::trackPhi0Error(jPart), 2));
62
63 m_TrackOmegaDiffEW = std::abs(Variable::trackOmega(iPart) - Variable::trackOmega(jPart)) / sqrt(pow(Variable::trackOmegaError(
64 iPart), 2) + pow(Variable::trackOmegaError(jPart), 2));
65}
66
74
76{
77 updateVariables(iPart, jPart);
78 m_TTree -> Fill();
79}
80
82{
// initialize(): prepare the selector for the configured mode.
//  - training mode (m_TrainFlag): (re)create the output ROOT file m_TFileName, book the
//    "ntuple" TTree whose branches point at the feature members filled by updateVariables(),
//    and record the target / feature names in m_target_variable / m_variables.
//  - application mode: look up the weightfile payload named m_identifier in the database
//    and register initializeMVA() as a callback on it (presumably invoked when the payload
//    changes — confirm against DBObjPtr::addCallback).
83 if (m_TrainFlag) {
84 //make training data
// NOTE(review): TFile::Open can return nullptr (e.g. unwritable path); it is not checked here.
85 m_TFile = TFile::Open(m_TFileName.c_str(), "RECREATE");
// The tree is created while the new file is the current directory, so it is attached to it.
86 m_TTree = new TTree("ntuple", "Training Data for the Curl Tagger MVA");
87
// Leaf type codes: "/F" = Float_t, "/O" = bool.
88 m_TTree -> Branch("PPhi", &m_PPhi, "PPhi/F");
89 m_TTree -> Branch("ChargeProduct", &m_ChargeProduct, "ChargeProduct/F");
90 m_TTree -> Branch("PtDiffEW", &m_PtDiffEW, "PtDiffEW/F");
91 m_TTree -> Branch("PzDiffEW", &m_PzDiffEW, "PzDiffEW/F");
92 m_TTree -> Branch("TrackD0DiffEW", &m_TrackD0DiffEW, "TrackD0DiffEW/F");
93 m_TTree -> Branch("TrackZ0DiffEW", &m_TrackZ0DiffEW, "TrackZ0DiffEW/F");
94 m_TTree -> Branch("TrackTanLambdaDiffEW", &m_TrackTanLambdaDiffEW, "TrackTanLambdaDiffEW/F");
95 m_TTree -> Branch("TrackPhi0DiffEW", &m_TrackPhi0DiffEW, "TrackPhi0DiffEW/F");
96 m_TTree -> Branch("TrackOmegaDiffEW", &m_TrackOmegaDiffEW, "TrackOmegaDiffEW/F");
97
// Truth label; only filled by updateVariables() in training mode.
98 m_TTree -> Branch("IsCurl", &m_IsCurl, "IsCurl/O");
99
// Names must match the branch names booked above.
100 m_target_variable = "IsCurl";
101 m_variables = {"PPhi", "ChargeProduct", "PtDiffEW",
102 "PzDiffEW", "TrackD0DiffEW", "TrackZ0DiffEW",
103 "TrackTanLambdaDiffEW", "TrackPhi0DiffEW", "TrackOmegaDiffEW"
104 };
105
106 } else {
107 // normal application
108 m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
109 MVA::makeSaveForDatabase(m_identifier));
110 (*m_weightfile_representation.get()).addCallback([this]() { initializeMVA();});
// NOTE(review): original source line 111 is missing from this listing; other basf2 MVA
// users call the initialisation function once immediately after addCallback — confirm.
112 }
113}
114
116{
// initializeMVA(): rebuild the MVA state from the database weightfile payload.
// Wrap the serialized weightfile bytes (m_data) in a stream for deserialization.
117 std::stringstream ss((*m_weightfile_representation)->m_data);
// NOTE(review): original source line 118 is missing from this listing; presumably it
// deserializes ss into m_weightfile (e.g. MVA::Weightfile::loadFromStream) — confirm.
// Read the general options (needed to build SingleDataset inputs in getResponse()).
119 m_weightfile.getOptions(m_generalOptions);
// NOTE(review): original source line 120 is missing from this listing; presumably it loads
// m_weightfile into m_expert — confirm.
121}
122
123
125{
126 if (m_TrainFlag) {
127 return 0.5;
128 }
129 std::string elementIdentfier = "optimal_cut";
130 if (!m_weightfile.containsElement(elementIdentfier)) {
131 B2FATAL("No optimal cut stored in curlTagger MVA payload!");
132 }
133 // require the default value for the compiler to deduce the template class
134 return m_weightfile.getElement(elementIdentfier, 0.5);
135}
136
138{
139 if (m_TrainFlag) {
140 m_TFile -> cd();
141 m_TTree -> Write();
142 m_TFile -> Write();
143 m_TFile -> Close();
144 }
145}
146
148{
149 MVA::SingleDataset dataset(m_generalOptions, getVariables(iPart, jPart));
150 return m_expert.apply(dataset)[0];
151}
void updateVariables(Particle *iPart, Particle *jPart)
updates the values of the MVA variables
virtual std::vector< float > getVariables(Particle *iPart, Particle *jPart) override
returns vector of variables used by this selector.
Float_t m_TrackTanLambdaDiffEW
error weighted track tan lambda difference
virtual float getOptimalResponseCut() override
returns optimal cut to use with selector
std::vector< std::string > m_variables
names of variables used by mva
Float_t m_TrackZ0DiffEW
error weighted track Z0 difference
virtual void initialize() override
initialize whatever needs to be initialized (root file etc)
Float_t m_PtDiffEW
error weighted particle Pt difference
Float_t m_PPhi
angle between particle momentum vectors
Float_t m_TrackPhi0DiffEW
error weighted track Phi0 difference
TFile * m_TFile
output file for training data
Definition SelectorMVA.h:76
Float_t m_PzDiffEW
error weighted particle Pz difference
MVA::Weightfile m_weightfile
mva weightfile
Definition SelectorMVA.h:85
Float_t m_TrackD0DiffEW
error weighted track D0 difference
virtual void finalize() override
finalize whatever needs to be finalized (write the training data to the output ROOT file)
Float_t m_ChargeProduct
charge(p1) * charge(p2)
void initializeMVA()
initialize the MVA Expert
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
Definition SelectorMVA.h:82
bool m_TrainFlag
applying mva or training it
Definition SelectorMVA.h:70
MVA::FastBDTExpert m_expert
mva expert
Definition SelectorMVA.h:91
virtual void collectTrainingInfo(Particle *iPart, Particle *jPart) override
collect training data and save to a root file
std::string m_TFileName
name of output file for training data
Definition SelectorMVA.h:73
virtual float getResponse(Particle *iPart, Particle *jPart) override
Selector response indicating how likely it is that this pair of particles comes from the same MC/actual particle.
MVA::GeneralOptions m_generalOptions
mva general options (for the expert)
Definition SelectorMVA.h:88
std::string m_target_variable
name of target variable (IsCurl)
SelectorMVA(bool belleFlag, bool trainFlag, std::string tFileName)
Constructor.
Float_t m_TrackOmegaDiffEW
error weighted track Omega difference
TTree * m_TTree
training data tree
Definition SelectorMVA.h:79
std::string m_identifier
mva identifier
Definition SelectorMVA.h:95
Wraps the data of a single event into a Dataset.
Definition Dataset.h:135
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Class to store reconstructed particles.
Definition Particle.h:76
double getCharge(void) const
Returns particle charge.
Definition Particle.cc:653
ROOT::Math::XYZVector getMomentum() const
Returns momentum vector.
Definition Particle.h:580
double sqrt(double a)
sqrt for double
Definition beamHelpers.h:28
Abstract base class for different kinds of events.