Belle II Software light-2601-hyperion
FlavorTaggerInfoFillerModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <analysis/modules/FlavorTaggerInfoFiller/FlavorTaggerInfoFillerModule.h>
10#include <framework/core/ModuleParam.templateDetails.h>
11#include <analysis/dataobjects/ParticleList.h>
12#include <analysis/dataobjects/FlavorTaggerInfo.h>
13#include <analysis/dataobjects/FlavorTaggerInfoMap.h>
14#include <analysis/VariableManager/Manager.h>
15
16using namespace std;
17using namespace Belle2;
18
19// Register module in the framework
20REG_MODULE(FlavorTaggerInfoFiller);
21
23{
24 //Set module properties
25 setDescription("Creates a new flavorTaggerInfoMap DataObject for the specific methods. Saves there all the relevant information of the flavorTagger.");
27 //Parameter definition
28 addParam("trackLevelParticleLists", m_trackLevelParticleLists, "Used Flavor Tagger trackLevel Categories of the lists ",
29 vector<tuple<string, string>>());
30 addParam("eventLevelParticleLists", m_eventLevelParticleLists, "Used Flavor Tagger eventLevel Categories of the lists ",
31 vector<tuple<string, string, string>>());
32 addParam("FANNmlp", m_FANNmlp, "Sets if FANN Combiner output will be saved or not", false);
33 addParam("TMVAfbdt", m_TMVAfbdt, "Sets if FANN Combiner output will be saved or not", false);
34 addParam("DNNmlp", m_DNNmlp, "Sets if DNN Tagger output will be saved or not", false);
35 addParam("TFLATnn", m_TFLATnn, "Sets if TFLAT Tagger output will be saved or not", false);
36 addParam("qpCategories", m_qpCategories, "Sets if individual categories output will be saved or not", false);
37 addParam("istrueCategories", m_istrueCategories, "Sets if individual MC truth for each category is saved or not", false);
38 addParam("targetProb", m_targetProb, "Sets if individual Categories output will be saved or not", false);
39 addParam("trackPointers", m_trackPointers, "Sets if track pointers to target tracks are saved or not", false);
40
41}
42
47
49{
50 auto* flavorTaggerInfo = m_roe->getRelatedTo<FlavorTaggerInfo>();
51
53
54
55 if (flavorTaggerInfo == nullptr) {
56 B2ERROR("flavorTaggerInfoFiller: FlavorTaggerInfo does not exist");
57 return;
58 }
59
60 flavorTaggerInfo -> setUseModeFlavorTagger("Expert");
61
62 if (m_FANNmlp) {
63 FlavorTaggerInfoMap* infoMapsFANN = flavorTaggerInfo -> getMethodMap("FANN");
64 // For FANN, the output is mapped to be qr
65 float qrCombined = m_eventExtraInfo->getExtraInfo("qrCombinedFANN");
66 if (qrCombined < 1.1 && qrCombined > 1.0) qrCombined = 1.0;
67 if (qrCombined > - 1.1 && qrCombined < -1.0) qrCombined = -1.0;
68 float B0Probability = qrCombined / 2 + 0.5;
69 float B0barProbability = 1 - B0Probability;
70 infoMapsFANN->setQrCombined(qrCombined);
71 infoMapsFANN->setB0Probability(B0Probability);
72 infoMapsFANN->setB0barProbability(B0barProbability);
73 }
74
75 FlavorTaggerInfoMap* infoMapsFBDT = flavorTaggerInfo -> getMethodMap("FBDT");
76
77 if (m_TMVAfbdt) {
78 float B0Probability = m_eventExtraInfo->getExtraInfo("qrCombinedFBDT");
79 float B0barProbability = 1 - B0Probability;
80 float qrCombined = 2 * (B0Probability - 0.5);
81 if (qrCombined < 1.1 && qrCombined > 1.0) qrCombined = 1.0;
82 if (qrCombined > - 1.1 && qrCombined < -1.0) qrCombined = -1.0;
83 infoMapsFBDT->setQrCombined(qrCombined);
84 infoMapsFBDT->setB0Probability(B0Probability);
85 infoMapsFBDT->setB0barProbability(B0barProbability);
86 }
87
88 if (m_DNNmlp) {
89 FlavorTaggerInfoMap* infoMapsDNN = flavorTaggerInfo -> getMethodMap("DNN");
90 const Particle* particle = m_roe->getRelatedFrom<Particle>();
91 float B0Probability = particle->getExtraInfo("dnn_output");
92 float B0barProbability = 1 - B0Probability;
93 float qrCombined = 2 * (B0Probability - 0.5);
94 if (qrCombined < 1.1 && qrCombined > 1.0) qrCombined = 1.0;
95 if (qrCombined > - 1.1 && qrCombined < -1.0) qrCombined = -1.0;
96 infoMapsDNN->setQrCombined(qrCombined);
97 infoMapsDNN->setB0Probability(B0Probability);
98 infoMapsDNN->setB0barProbability(B0barProbability);
99 }
100
101 if (m_TFLATnn) {
102 FlavorTaggerInfoMap* infoMapsTFLAT = flavorTaggerInfo -> getMethodMap("TFLAT");
103 const Particle* particle = m_roe->getRelatedFrom<Particle>();
104 float B0Probability = particle->getExtraInfo("tflat_output");
105 float B0barProbability = 1 - B0Probability;
106 float qrCombined = 2 * (B0Probability - 0.5);
107 if (qrCombined < 1.1 && qrCombined > 1.0) qrCombined = 1.0;
108 if (qrCombined > - 1.1 && qrCombined < -1.0) qrCombined = -1.0;
109 infoMapsTFLAT->setQrCombined(qrCombined);
110 infoMapsTFLAT->setB0Probability(B0Probability);
111 infoMapsTFLAT->setB0barProbability(B0barProbability);
112 }
113
114
115 if (m_targetProb) {
116 for (auto& iTuple : m_trackLevelParticleLists) {
117 string particleListName = get<0>(iTuple);
118 string category = get<1>(iTuple);
119 StoreObjPtr<ParticleList> particleList(particleListName);
120
121
122 if (!particleList.isValid()) {
123 B2INFO("ParticleList " << particleListName << " not found");
124 } else {
125 if (particleList -> getListSize() == 0) {
126 infoMapsFBDT -> setProbTrackLevel(category, 0);
127 if (m_trackPointers) infoMapsFBDT -> setTargetTrackLevel(category, nullptr);
128 } else {
129
130 for (unsigned int i = 0; i < particleList->getListSize(); ++i) {
131 Particle* iParticle = particleList ->getParticle(i);
132 bool hasMaxProb = std::get<bool>(manager.getVariable("hasHighestProbInCat(" + particleListName + "," + "isRightTrack(" + category +
133 "))")->function(iParticle));
134 if (hasMaxProb == 1) {
135 float targetProb = iParticle -> getExtraInfo("isRightTrack(" + category + ")");
136 infoMapsFBDT->setProbTrackLevel(category, targetProb);
137 if (m_trackPointers) {
138 infoMapsFBDT-> setTargetTrackLevel(category, iParticle -> getTrack());
139 }
140 break;
141 }
142 }
143 }
144 }
145 }
146 }
147
148 if (m_qpCategories) {
149
150 for (auto& iTuple : m_eventLevelParticleLists) {
151 string particleListName = get<0>(iTuple);
152 string category = get<1>(iTuple);
153 string qpCategoryVariable = get<2>(iTuple);
154 StoreObjPtr<ParticleList> particleList(particleListName);
155
156 if (!particleList.isValid()) {
157 B2INFO("ParticleList " << particleListName << " not found");
158 } else {
159 if (particleList -> getListSize() == 0) {
160 infoMapsFBDT -> setProbEventLevel(category, 0);
161 infoMapsFBDT -> setQpCategory(category, 0);
162 if (m_istrueCategories and m_mcparticles.isValid()) {
163 infoMapsFBDT -> setHasTrueTarget(category, 0);
164 infoMapsFBDT -> setIsTrueCategory(category, 0);
165 }
166 if (m_trackPointers) infoMapsFBDT -> setTargetEventLevel(category, nullptr);
167 } else {
168
169 for (unsigned int i = 0; i < particleList->getListSize(); ++i) {
170 Particle* iParticle = particleList ->getParticle(i);
171 bool hasMaxProb = std::get<bool>(manager.getVariable("hasHighestProbInCat(" + particleListName + "," + "isRightCategory(" + category
172 + "))")-> function(iParticle));
173 if (hasMaxProb == 1) {
174 float categoryProb = iParticle -> getExtraInfo("isRightCategory(" + category + ")");
175 float qpCategory = std::get<double>(manager.getVariable(qpCategoryVariable)-> function(iParticle));
176 infoMapsFBDT->setProbEventLevel(category, categoryProb);
177 infoMapsFBDT -> setQpCategory(category, qpCategory);
178 if (m_istrueCategories and m_mcparticles.isValid()) {
179 float isTrueTarget = std::get<double>(manager.getVariable("hasTrueTarget(" + category + ")")-> function(nullptr));
180 infoMapsFBDT -> setHasTrueTarget(category, isTrueTarget);
181 float isTrueCategory = std::get<double>(manager.getVariable("isTrueCategory(" + category + ")")-> function(nullptr));
182 infoMapsFBDT -> setIsTrueCategory(category, isTrueCategory);
183 }
184 if (m_trackPointers) {
185 infoMapsFBDT-> setTargetEventLevel(category, iParticle -> getTrack());
186 }
187 break;
188 }
189 }
190 }
191 }
192 }
193 }
194
195}
196
200
std::vector< std::tuple< std::string, std::string > > m_trackLevelParticleLists
Used Flavor Tagger trackLevel Categories of the lists.
StoreObjPtr< RestOfEvent > m_roe
ROE object pointer.
virtual void initialize() override
Initialises the module.
bool m_TFLATnn
Sets if TFLAT tagger output will be saved or not.
virtual void event() override
Method called for each event.
bool m_targetProb
Sets if the target track probability of each individual trackLevel category will be saved or not.
StoreArray< MCParticle > m_mcparticles
StoreArray of MCParticles.
virtual void terminate() override
Write TTree to file, and close file if necessary.
bool m_TMVAfbdt
Sets if FastBDT Combiner output will be saved or not.
bool m_FANNmlp
Sets if FANN Combiner output will be saved or not.
bool m_DNNmlp
Sets if DNN tagger output will be saved or not.
bool m_qpCategories
Sets if individual Categories output will be saved or not.
bool m_istrueCategories
Sets if individual MC truth for each Category is saved or not.
StoreObjPtr< EventExtraInfo > m_eventExtraInfo
event extra info object pointer
std::vector< std::tuple< std::string, std::string, std::string > > m_eventLevelParticleLists
Used Flavor Tagger eventLevel Categories of the lists.
bool m_trackPointers
Sets if track pointers to target tracks are saved or not.
This class stores the Flavor Tagger information for a specific method and particle filled in the Flav...
void setProbEventLevel(const std::string &category, float probability)
Map filler: Set the category name and the highest category probability for the corresponding category...
void setProbTrackLevel(const std::string &category, float probability)
Map filler: Set the category name and the corresponding highest target track probability.
void setB0Probability(float B0Probability)
Saves the B0Probability output of the Combiner.
void setB0barProbability(float B0barProbability)
Saves the B0barProbability output of the Combiner.
void setQrCombined(float qr)
Saves qr Output of the Combiner.
This class stores the relevant information for the TagV vertex fit, extracted mainly from the Flavor ...
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
Class to store reconstructed particles.
Definition Particle.h:76
double getExtraInfo(const std::string &name) const
Return given value if set.
Definition Particle.cc:1374
Type-safe access to single objects in the data store.
Definition StoreObjPtr.h:96
Global list of available variables.
Definition Manager.h:100
static Manager & Instance()
get singleton instance.
Definition Manager.cc:26
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.
STL namespace.