Belle II Software development
MVAExpertModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9
10#include <mva/modules/MVAExpert/MVAExpertModule.h>
11
12#include <analysis/DecayDescriptor/DecayDescriptorParticle.h>
13
14#include <analysis/dataobjects/Particle.h>
15#include <analysis/dataobjects/ParticleList.h>
16#include <analysis/dataobjects/ParticleExtraInfoMap.h>
17
18#include <mva/interface/Interface.h>
19
20#include <framework/logging/Logger.h>
21
22
23using namespace Belle2;
24
25REG_MODULE(MVAExpert);
26
28{
29 setDescription("Adds an ExtraInfo to the Particle objects in given ParticleLists which is calculated by an expert defined by a weightfile.");
31
32 std::vector<std::string> empty;
33 addParam("listNames", m_listNames,
34 "Particles from these ParticleLists are used as input. If no name is given the expert is applied to every event once, and one can only use variables which accept nullptr as Particle*. Decay descriptor functionality is supported, which allows to run the module on daughter particles, e.g. Lambda0:my_list -> ^p+ pi-. One has to provide full name of the mother particle list and only one selected daughter is supported.",
35 empty);
36 addParam("extraInfoName", m_extraInfoName,
37 "Name under which the output of the expert is stored in the ExtraInfo of the Particle object. If the expert returns multiple values, the index of the value is appended to the name in the form '_0', '_1', ...");
38 addParam("identifier", m_identifier, "The database identifier which is used to load the weights during the training.");
39 addParam("signalFraction", m_signal_fraction_override,
40 "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
41 addParam("overwriteExistingExtraInfo", m_overwriteExistingExtraInfo,
42 "-1/0/1/2: Overwrite if lower / don't overwrite / overwrite if higher / always overwrite, in case the extra info with given name already exists",
43 2);
44}
45
47{
48 // All specified ParticleLists are required to exist
49 for (auto& name : m_listNames) {
51 bool valid = dd.init(name);
52 if (!valid) {
53 B2ERROR("Decay string " << name << " is invalid.");
54 }
55 const DecayDescriptorParticle* mother = dd.getMother();
56 unsigned int nSelectedDaughters = dd.getSelectionNames().size();
57 if (nSelectedDaughters > 1) {
58 B2ERROR("More than one daughter is selected in the decay string " << name << ".");
59 }
61 list.isRequired();
62 m_targetListNames.push_back(mother->getFullName());
63 m_decaydescriptors.insert(std::make_pair(mother->getFullName(), dd));
64 }
65
66 if (m_listNames.empty()) {
68 extraInfo.isRequired();
69 } else {
71 extraInfo.isRequired();
72 }
73
74 if (not(m_identifier.ends_with(".root") or m_identifier.ends_with(".xml"))) {
75 m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
76 MVA::makeSaveForDatabase(m_identifier));
77 }
79
81}
82
84{
85
87 if (m_weightfile_representation->hasChanged()) {
88 std::stringstream ss((*m_weightfile_representation)->m_data);
89 auto weightfile = MVA::Weightfile::loadFromStream(ss);
90 init_mva(weightfile);
91 }
92 } else {
94 init_mva(weightfile);
95 }
96
97}
98
100{
101
102 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
103 MVA::GeneralOptions general_options;
104 weightfile.getOptions(general_options);
105
106 // Overwrite signal fraction from training
108 weightfile.addSignalFraction(m_signal_fraction_override);
109
110 m_expert = supported_interfaces[general_options.m_method]->getExpert();
111 m_expert->load(weightfile);
112
114 m_feature_variables = manager.getVariables(general_options.m_variables);
115 if (m_feature_variables.size() != general_options.m_variables.size()) {
116 B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
117 }
118
119 std::vector<float> dummy;
120 dummy.resize(m_feature_variables.size(), 0);
121 m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
122 m_nClasses = general_options.m_nClasses;
123}
124
126{
127 for (unsigned int i = 0; i < m_feature_variables.size(); ++i) {
128 auto var_result = m_feature_variables[i]->function(particle);
129 if (std::holds_alternative<double>(var_result)) {
130 m_dataset->m_input[i] = std::get<double>(var_result);
131 } else if (std::holds_alternative<int>(var_result)) {
132 m_dataset->m_input[i] = std::get<int>(var_result);
133 } else if (std::holds_alternative<bool>(var_result)) {
134 m_dataset->m_input[i] = std::get<bool>(var_result);
135 }
136 }
137}
138
139float MVAExpertModule::analyse(const Particle* particle)
140{
141 if (not m_expert) {
142 B2ERROR("MVA Expert is not loaded! I will return 0");
143 return 0.0;
144 }
145 fillDataset(particle);
146 return m_expert->apply(*m_dataset)[0];
147}
148
149std::vector<float> MVAExpertModule::analyseMulticlass(const Particle* particle)
150{
151 if (not m_expert) {
152 B2ERROR("MVA Expert is not loaded! I will return 0");
153 return std::vector<float>(m_nClasses, 0.0);
154 }
155 fillDataset(particle);
156 return m_expert->applyMulticlass(*m_dataset)[0];
157}
158
159void MVAExpertModule::setExtraInfoField(Particle* particle, std::string extraInfoName, float responseValue)
160{
161 if (particle->hasExtraInfo(extraInfoName)) {
162 if (particle->getExtraInfo(extraInfoName) != responseValue) {
164 double current = particle->getExtraInfo(extraInfoName);
166 if (responseValue < current) particle->setExtraInfo(extraInfoName, responseValue);
167 } else if (m_overwriteExistingExtraInfo == 0) {
168 // don't overwrite!
169 } else if (m_overwriteExistingExtraInfo == 1) {
170 if (responseValue > current) particle->setExtraInfo(extraInfoName, responseValue);
171 } else if (m_overwriteExistingExtraInfo == 2) {
172 particle->setExtraInfo(extraInfoName, responseValue);
173 } else {
174 B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo << "'.");
175 }
176 }
177 } else {
178 particle->addExtraInfo(extraInfoName, responseValue);
179 }
180}
181
182void MVAExpertModule::setEventExtraInfoField(StoreObjPtr<EventExtraInfo> eventExtraInfo, std::string extraInfoName,
183 float responseValue)
184{
185 if (eventExtraInfo->hasExtraInfo(extraInfoName)) {
187 double current = eventExtraInfo->getExtraInfo(extraInfoName);
189 if (responseValue < current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
190 } else if (m_overwriteExistingExtraInfo == 0) {
191 // don't overwrite!
192 } else if (m_overwriteExistingExtraInfo == 1) {
193 if (responseValue > current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
194 } else if (m_overwriteExistingExtraInfo == 2) {
195 eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
196 } else {
197 B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo << "'.");
198 }
199 } else {
200 eventExtraInfo->addExtraInfo(extraInfoName, responseValue);
201 }
202}
203
205{
206 for (auto& listName : m_targetListNames) {
207 StoreObjPtr<ParticleList> list(listName);
208 // Calculate target Value for Particles
209 for (unsigned i = 0; i < list->getListSize(); ++i) {
210 auto dd = m_decaydescriptors[listName];
211 unsigned int nSelectedDaughters = dd.getSelectionNames().size();
212 const Particle* particle = (nSelectedDaughters > 0) ? dd.getSelectionParticles(list->getParticle(i))[0] : list->getParticle(i);
213 if (m_nClasses == 2) {
214 float responseValue = analyse(particle);
215 setExtraInfoField(m_particles[particle->getArrayIndex()], m_extraInfoName, responseValue);
216 } else if (m_nClasses > 2) {
217 std::vector<float> responseValues = analyseMulticlass(particle);
218 if (responseValues.size() != m_nClasses) {
219 B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues.size() <<
220 ") does not match the declared number of classes (" << m_nClasses << ").");
221 }
222 for (unsigned int iClass = 0; iClass < m_nClasses; iClass++) {
223 setExtraInfoField(m_particles[particle->getArrayIndex()], m_extraInfoName + "_" + std::to_string(iClass), responseValues[iClass]);
224 }
225 } else {
226 B2ERROR("Received a value of " << m_nClasses <<
227 " for the number of classes considered by the MVA Expert. This value should be >=2.");
228 }
229 }
230 }
231 if (m_listNames.empty()) {
232 StoreObjPtr<EventExtraInfo> eventExtraInfo;
233 if (not eventExtraInfo.isValid())
234 eventExtraInfo.create();
235
236 if (m_nClasses == 2) {
237 float responseValue = analyse(nullptr);
238 setEventExtraInfoField(eventExtraInfo, m_extraInfoName, responseValue);
239 } else if (m_nClasses > 2) {
240 std::vector<float> responseValues = analyseMulticlass(nullptr);
241 if (responseValues.size() != m_nClasses) {
242 B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues.size() <<
243 ") does not match the declared number of classes (" << m_nClasses << ").");
244 }
245 for (unsigned int iClass = 0; iClass < m_nClasses; iClass++) {
246 setEventExtraInfoField(eventExtraInfo, m_extraInfoName + "_" + std::to_string(iClass), responseValues[iClass]);
247 }
248 } else {
249 B2ERROR("Received a value of " << m_nClasses <<
250 " for the number of classes considered by the MVA Expert. This value should be >=2.");
251 }
252 }
253}
254
256{
257 m_expert.reset();
258 m_dataset.reset();
259
262 B2WARNING("The extraInfo " << m_extraInfoName <<
263 " has already been set! It was overwritten by this module if the new value was lower than the previous!");
264 } else if (m_overwriteExistingExtraInfo == 0) {
265 B2WARNING("The extraInfo " << m_extraInfoName <<
266 " has already been set! The original value was kept and this module did not overwrite it!");
267 } else if (m_overwriteExistingExtraInfo == 1) {
268 B2WARNING("The extraInfo " << m_extraInfoName <<
269 " has already been set! It was overwritten by this module if the new value was higher than the previous!");
270 } else if (m_overwriteExistingExtraInfo == 2) {
271 B2WARNING("The extraInfo " << m_extraInfoName << " has already been set! It was overwritten by this module!");
272 }
273 }
274}
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition DataStore.h:59
Represents a particle in the DecayDescriptor.
std::string getFullName() const
returns the full name of the particle full_name = name:label
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
bool init(const std::string &str)
Initialise the DecayDescriptor from given string.
std::vector< std::string > getSelectionNames()
Return list of human readable names of selected particles.
const DecayDescriptorParticle * getMother() const
return mother.
std::unordered_map< std::string, DecayDescriptor > m_decaydescriptors
Decay descriptor of decays to look for.
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
std::vector< const Variable::Manager::Var * > m_feature_variables
Pointers to the feature variables.
virtual void initialize() override
Initialize the module.
virtual void event() override
Called for each event.
void setEventExtraInfoField(StoreObjPtr< EventExtraInfo >, std::string, float)
Set the event extra info field.
StoreArray< Particle > m_particles
StoreArray of Particles.
virtual void terminate() override
Called at the end of the event processing.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::vector< std::string > m_targetListNames
input particle list names after decay descriptor
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
std::vector< std::string > m_listNames
input particle list names
virtual void beginRun() override
Called at the beginning of a new run.
float analyse(const Particle *)
Calculates expert output for given Particle pointer.
bool m_existGivenExtraInfo
check if the given extraInfo is already defined.
int m_overwriteExistingExtraInfo
-1/0/1/2: overwrite if lower/ don't overwrite / overwrite if higher/ always overwrite,...
std::vector< float > analyseMulticlass(const Particle *)
Calculates expert output for given Particle pointer.
void init_mva(MVA::Weightfile &weightfile)
Initialize mva expert, dataset and features. Called every time the weightfile in the database changes ...
std::string m_extraInfoName
Name under which the SignalProbability is stored in the extraInfo of the Particle object.
unsigned int m_nClasses
number of classes (~outputs) of the current MVA Expert.
void setExtraInfoField(Particle *, std::string, float)
Set the extra info field.
void fillDataset(const Particle *)
Evaluate the variables and fill the Dataset to be used by the expert.
std::string m_identifier
weight-file
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupported...
Definition Interface.cc:46
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition Interface.h:53
General options which are shared by all MVA trainings.
Definition Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
Class to store reconstructed particles.
Definition Particle.h:76
Type-safe access to single objects in the data store.
Definition StoreObjPtr.h:96
Global list of available variables.
Definition Manager.h:100
static Manager & Instance()
get singleton instance.
Definition Manager.cc:26
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.