Belle II Software development
MVAExpertModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9
10#include <mva/modules/MVAExpert/MVAExpertModule.h>
11
12#include <analysis/DecayDescriptor/DecayDescriptorParticle.h>
13
14#include <analysis/dataobjects/Particle.h>
15#include <analysis/dataobjects/ParticleList.h>
16#include <analysis/dataobjects/ParticleExtraInfoMap.h>
17#include <framework/dataobjects/EventExtraInfo.h>
18
19#include <mva/interface/Interface.h>
20
21#include <boost/algorithm/string/predicate.hpp>
22
23#include <framework/logging/Logger.h>
24
25
26using namespace Belle2;
27
29
31{
32 setDescription("Adds an ExtraInfo to the Particle objects in given ParticleLists which is calcuated by an expert defined by a weightfile.");
34
35 std::vector<std::string> empty;
36 addParam("listNames", m_listNames,
37 "Particles from these ParticleLists are used as input. If no name is given the expert is applied to every event once, and one can only use variables which accept nullptr as Particle*. Decay descriptor functionality is supported, which allows to run the module on daughter particles, e.g. Lambda0:my_list -> ^p+ pi-. One has to provide full name of the mother particle list and only one selected daughter is supported.",
38 empty);
39 addParam("extraInfoName", m_extraInfoName,
40 "Name under which the output of the expert is stored in the ExtraInfo of the Particle object. If the expert returns multiple values, the index of the value is appended to the name in the form '_0', '_1', ...");
41 addParam("identifier", m_identifier, "The database identifier which is used to load the weights during the training.");
42 addParam("signalFraction", m_signal_fraction_override,
43 "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
44 addParam("overwriteExistingExtraInfo", m_overwriteExistingExtraInfo,
45 "-1/0/1/2: Overwrite if lower / don't overwrite / overwrite if higher / always overwrite, in case the extra info with given name already exists",
46 2);
47}
48
50{
51 // All specified ParticleLists are required to exist
52 for (auto& name : m_listNames) {
54 bool valid = dd.init(name);
55 if (!valid) {
56 B2ERROR("Decay string " << name << " is invalid.");
57 }
58 const DecayDescriptorParticle* mother = dd.getMother();
59 unsigned int nSelectedDaughters = dd.getSelectionNames().size();
60 if (nSelectedDaughters > 1) {
61 B2ERROR("More than one daughter is selected in the decay string " << name << ".");
62 }
64 list.isRequired();
65 m_targetListNames.push_back(mother->getFullName());
66 m_decaydescriptors.insert(std::make_pair(mother->getFullName(), dd));
67 }
68
69 if (m_listNames.empty()) {
71 extraInfo.registerInDataStore();
72 } else {
74 extraInfo.registerInDataStore();
75 }
76
77 if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
78 m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
79 MVA::makeSaveForDatabase(m_identifier));
80 }
82
84}
85
87{
88
90 if (m_weightfile_representation->hasChanged()) {
91 std::stringstream ss((*m_weightfile_representation)->m_data);
92 auto weightfile = MVA::Weightfile::loadFromStream(ss);
93 init_mva(weightfile);
94 }
95 } else {
97 init_mva(weightfile);
98 }
99
100}
101
103{
104
105 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
106 MVA::GeneralOptions general_options;
107 weightfile.getOptions(general_options);
108
109 // Overwrite signal fraction from training
112
113 m_expert = supported_interfaces[general_options.m_method]->getExpert();
114 m_expert->load(weightfile);
115
117 m_feature_variables = manager.getVariables(general_options.m_variables);
118 if (m_feature_variables.size() != general_options.m_variables.size()) {
119 B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
120 }
121
122 std::vector<float> dummy;
123 dummy.resize(m_feature_variables.size(), 0);
124 m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
125 m_nClasses = general_options.m_nClasses;
126}
127
129{
130 for (unsigned int i = 0; i < m_feature_variables.size(); ++i) {
131 auto var_result = m_feature_variables[i]->function(particle);
132 if (std::holds_alternative<double>(var_result)) {
133 m_dataset->m_input[i] = std::get<double>(var_result);
134 } else if (std::holds_alternative<int>(var_result)) {
135 m_dataset->m_input[i] = std::get<int>(var_result);
136 } else if (std::holds_alternative<bool>(var_result)) {
137 m_dataset->m_input[i] = std::get<bool>(var_result);
138 }
139 }
140}
141
142float MVAExpertModule::analyse(const Particle* particle)
143{
144 if (not m_expert) {
145 B2ERROR("MVA Expert is not loaded! I will return 0");
146 return 0.0;
147 }
148 fillDataset(particle);
149 return m_expert->apply(*m_dataset)[0];
150}
151
152std::vector<float> MVAExpertModule::analyseMulticlass(const Particle* particle)
153{
154 if (not m_expert) {
155 B2ERROR("MVA Expert is not loaded! I will return 0");
156 return std::vector<float>(m_nClasses, 0.0);
157 }
158 fillDataset(particle);
159 return m_expert->applyMulticlass(*m_dataset)[0];
160}
161
162void MVAExpertModule::setExtraInfoField(Particle* particle, std::string extraInfoName, float responseValue)
163{
164 if (particle->hasExtraInfo(extraInfoName)) {
165 if (particle->getExtraInfo(extraInfoName) != responseValue) {
167 double current = particle->getExtraInfo(extraInfoName);
169 if (responseValue < current) particle->setExtraInfo(extraInfoName, responseValue);
170 } else if (m_overwriteExistingExtraInfo == 0) {
171 // don't overwrite!
172 } else if (m_overwriteExistingExtraInfo == 1) {
173 if (responseValue > current) particle->setExtraInfo(extraInfoName, responseValue);
174 } else if (m_overwriteExistingExtraInfo == 2) {
175 particle->setExtraInfo(extraInfoName, responseValue);
176 } else {
177 B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo << "'.");
178 }
179 }
180 } else {
181 particle->addExtraInfo(extraInfoName, responseValue);
182 }
183}
184
185void MVAExpertModule::setEventExtraInfoField(StoreObjPtr<EventExtraInfo> eventExtraInfo, std::string extraInfoName,
186 float responseValue)
187{
188 if (eventExtraInfo->hasExtraInfo(extraInfoName)) {
190 double current = eventExtraInfo->getExtraInfo(extraInfoName);
192 if (responseValue < current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
193 } else if (m_overwriteExistingExtraInfo == 0) {
194 // don't overwrite!
195 } else if (m_overwriteExistingExtraInfo == 1) {
196 if (responseValue > current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
197 } else if (m_overwriteExistingExtraInfo == 2) {
198 eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
199 } else {
200 B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo << "'.");
201 }
202 } else {
203 eventExtraInfo->addExtraInfo(extraInfoName, responseValue);
204 }
205}
206
208{
209 for (auto& listName : m_targetListNames) {
210 StoreObjPtr<ParticleList> list(listName);
211 // Calculate target Value for Particles
212 for (unsigned i = 0; i < list->getListSize(); ++i) {
213 auto dd = m_decaydescriptors[listName];
214 unsigned int nSelectedDaughters = dd.getSelectionNames().size();
215 const Particle* particle = (nSelectedDaughters > 0) ? dd.getSelectionParticles(list->getParticle(i))[0] : list->getParticle(i);
216 if (m_nClasses == 2) {
217 float responseValue = analyse(particle);
218 setExtraInfoField(m_particles[particle->getArrayIndex()], m_extraInfoName, responseValue);
219 } else if (m_nClasses > 2) {
220 std::vector<float> responseValues = analyseMulticlass(particle);
221 if (responseValues.size() != m_nClasses) {
222 B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues.size() <<
223 ") does not match the declared number of classes (" << m_nClasses << ").");
224 }
225 for (unsigned int iClass = 0; iClass < m_nClasses; iClass++) {
226 setExtraInfoField(m_particles[particle->getArrayIndex()], m_extraInfoName + "_" + std::to_string(iClass), responseValues[iClass]);
227 }
228 } else {
229 B2ERROR("Received a value of " << m_nClasses <<
230 " for the number of classes considered by the MVA Expert. This value should be >=2.");
231 }
232 }
233 }
234 if (m_listNames.empty()) {
235 StoreObjPtr<EventExtraInfo> eventExtraInfo;
236 if (not eventExtraInfo.isValid())
237 eventExtraInfo.create();
238
239 if (m_nClasses == 2) {
240 float responseValue = analyse(nullptr);
241 setEventExtraInfoField(eventExtraInfo, m_extraInfoName, responseValue);
242 } else if (m_nClasses > 2) {
243 std::vector<float> responseValues = analyseMulticlass(nullptr);
244 if (responseValues.size() != m_nClasses) {
245 B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues.size() <<
246 ") does not match the declared number of classes (" << m_nClasses << ").");
247 }
248 for (unsigned int iClass = 0; iClass < m_nClasses; iClass++) {
249 setEventExtraInfoField(eventExtraInfo, m_extraInfoName + "_" + std::to_string(iClass), responseValues[iClass]);
250 }
251 } else {
252 B2ERROR("Received a value of " << m_nClasses <<
253 " for the number of classes considered by the MVA Expert. This value should be >=2.");
254 }
255 }
256}
257
259{
260 m_expert.reset();
261 m_dataset.reset();
262
265 B2WARNING("The extraInfo " << m_extraInfoName <<
266 " has already been set! It was overwritten by this module if the new value was lower than the previous!");
267 } else if (m_overwriteExistingExtraInfo == 0) {
268 B2WARNING("The extraInfo " << m_extraInfoName <<
269 " has already been set! The original value was kept and this module did not overwrite it!");
270 } else if (m_overwriteExistingExtraInfo == 1) {
271 B2WARNING("The extraInfo " << m_extraInfoName <<
272 " has already been set! It was overwritten by this module if the new value was higher than the previous!");
273 } else if (m_overwriteExistingExtraInfo == 2) {
274 B2WARNING("The extraInfo " << m_extraInfoName << " has already been set! It was overwritten by this module!");
275 }
276 }
277}
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition: DataStore.h:59
Represents a particle in the DecayDescriptor.
std::string getFullName() const
returns the full name of the particle full_name = name:label
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
bool init(const std::string &str)
Initialise the DecayDescriptor from given string.
std::vector< std::string > getSelectionNames()
Return list of human readable names of selected particles.
const DecayDescriptorParticle * getMother() const
return mother.
std::unordered_map< std::string, DecayDescriptor > m_decaydescriptors
Decay descriptor of decays to look for.
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
std::vector< const Variable::Manager::Var * > m_feature_variables
Pointers to the feature variables.
virtual void initialize() override
Initialize the module.
virtual void event() override
Called for each event.
void setEventExtraInfoField(StoreObjPtr< EventExtraInfo >, std::string, float)
Set the event extra info field.
StoreArray< Particle > m_particles
StoreArray of Particles.
virtual void terminate() override
Called at the end of the event processing.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::vector< std::string > m_targetListNames
input particle list names after decay descriptor
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
std::vector< std::string > m_listNames
input particle list names
virtual void beginRun() override
Called at the beginning of a new run.
MVAExpertModule()
Constructor.
float analyse(const Particle *)
Calculates expert output for given Particle pointer.
bool m_existGivenExtraInfo
check if the given extraInfo is already defined.
int m_overwriteExistingExtraInfo
-1/0/1/2: overwrite if lower / don't overwrite / overwrite if higher / always overwrite,...
std::vector< float > analyseMulticlass(const Particle *)
Calculates expert output for given Particle pointer.
void init_mva(MVA::Weightfile &weightfile)
Initialize the MVA expert, dataset and features. Called every time the weightfile in the database changes ...
std::string m_extraInfoName
Name under which the SignalProbability is stored in the extraInfo of the Particle object.
unsigned int m_nClasses
number of classes (~outputs) of the current MVA Expert.
void setExtraInfoField(Particle *, std::string, float)
Set the extra info field.
void fillDataset(const Particle *)
Evaluate the variables and fill the Dataset to be used by the expert.
std::string m_identifier
weight-file
Class to interact with the MVA package, based on class with same name in CDC package.
Definition: MVAExpert.h:33
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedI...
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
Class to store reconstructed particles.
Definition: Particle.h:75
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:96
Global list of available variables.
Definition: Manager.h:101
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Abstract base class for different kinds of events.