Belle II Software light-2406-ragdoll
MVAMultipleExpertsModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9
10#include <mva/modules/MVAExpert/MVAMultipleExpertsModule.h>
11
12#include <analysis/dataobjects/Particle.h>
13#include <analysis/dataobjects/ParticleList.h>
14#include <analysis/dataobjects/ParticleExtraInfoMap.h>
15#include <framework/dataobjects/EventExtraInfo.h>
16
17#include <mva/interface/Interface.h>
18
19#include <boost/algorithm/string/predicate.hpp>
20
21#include <framework/logging/Logger.h>
22
23
24using namespace Belle2;
25
26REG_MODULE(MVAMultipleExperts);
27
29{
 // Module setup: set the description and declare the steering parameters.
 // NOTE(review): the description string below misspells "calculated" as "calcuated"; it is a runtime
 // string and is intentionally left unchanged here.
30 setDescription("Adds ExtraInfos to the Particle objects in given ParticleLists which is calcuated by multiple experts defined by the given weightfiles.");
32
 // An empty listNames (the default) makes event() evaluate all experts once per event (with a nullptr
 // Particle) and store the outputs in the EventExtraInfo instead of on particles.
33 std::vector<std::string> empty;
34 addParam("listNames", m_listNames,
35 "Particles from these ParticleLists are used as input. If no name is given the experts are applied to every event once, and one can only use variables which accept nullptr as Particle*",
36 empty);
 // The output of the i-th identifier is stored under the i-th extraInfo name; initialize() aborts if
 // the two lists differ in length. Identifiers ending in ".root" or ".xml" are treated as weightfile
 // paths, all others as database identifiers (see initialize()).
37 addParam("extraInfoNames", m_extraInfoNames,
38 "Names under which the output of the experts is stored in the ExtraInfo of the Particle object.");
39 addParam("identifiers", m_identifiers, "The database identifiers which is used to load the weights during the training.");
40 addParam("signalFraction", m_signal_fraction_override,
41 "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
 // One integer per identifier, evaluated by setExtraInfoField()/setEventExtraInfoField() when the
 // extraInfo already exists: -1 keep the lower value, 0 keep the existing value, 1 keep the higher
 // value, 2 always overwrite; any other value is fatal.
 // NOTE(review): the help text below still describes a true/false flag and should be updated.
42 std::vector<int> empty_vec;
43 addParam("overwriteExistingExtraInfo", m_overwriteExistingExtraInfo,
44 "If true, when the given extraInfo has already defined, the old extraInfo value is overwritten. If false, the original value is kept.",
45 empty_vec);
46}
47
49{
 // Check the inputs, register the outputs, validate the parameter consistency and size all
 // per-expert containers (index i is shared by every container).
50 // All specified ParticleLists are required to exist
51 for (auto& name : m_listNames) {
53 list.isRequired();
54 }
55
 // NOTE(review): the declarations of the two `extraInfo` objects (original lines 57/60) are missing
 // from this listing; presumably an EventExtraInfo object in event-level mode and a per-particle
 // extra-info map otherwise -- confirm against the source.
56 if (m_listNames.empty()) {
58 extraInfo.registerInDataStore();
59 } else {
61 extraInfo.registerInDataStore();
62 }
63
 // Output i of m_identifiers is written under name i of m_extraInfoNames, so the lengths must match.
64 if (m_extraInfoNames.size() != m_identifiers.size()) {
65 B2FATAL("The number of given m_extraInfoNames is not equal to the number of m_identifiers. The output the ith method in m_identifiers is saved as extraInfo under the ith name in m_extraInfoNames! Set also different names for each method!");
66 }
67
69 m_experts.resize(m_identifiers.size());
71 m_datasets.resize(m_identifiers.size());
72 m_nClasses.resize(m_identifiers.size());
73 // if the size of m_overwriteExistingExtraInfo is smaller than that of m_identifiers, 2 will be filled.
75 m_existGivenExtraInfo.resize(m_identifiers.size(), false);
76
 // Identifiers that are not ".root"/".xml" file names get a conditions-database pointer. File-based
 // identifiers get none here; they are presumably read from disk by the loadFromFile branch of beginRun().
77 for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
78 if (not(boost::ends_with(m_identifiers[i], ".root") or boost::ends_with(m_identifiers[i], ".xml"))) {
79 m_weightfile_representations[i] = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
80 MVA::makeSaveForDatabase(m_identifiers[i]));
81 }
82 }
83
85}
86
88{
 // (Re)load the weightfile of every expert and rebuild the expert via init_mva().
89
90 if (!m_weightfile_representations.empty()) {
 // NOTE(review): loops over m_weightfile_representations but indexes m_identifiers with the same i;
 // this assumes both have the same size (presumably resized together in initialize()).
91 for (unsigned int i = 0; i < m_weightfile_representations.size(); ++i) {
 // NOTE(review): the condition opening this branch (original line 92) is missing from this listing;
 // presumably it checks whether identifier i has a database representation.
 // Database weightfiles are only reloaded when the payload has changed...
93 if (m_weightfile_representations[i]->hasChanged()) {
94 std::stringstream ss((*m_weightfile_representations[i])->m_data);
95 auto weightfile = MVA::Weightfile::loadFromStream(ss);
96 init_mva(weightfile, i);
97 }
 // ...whereas this branch re-reads the weightfile from the path m_identifiers[i] on every call.
98 } else {
99 auto weightfile = MVA::Weightfile::loadFromFile(m_identifiers[i]);
100 init_mva(weightfile, i);
101 }
102 }
103
104 } else B2FATAL("No m_identifiers given. At least one is needed!");
105}
106
108{
 // Build expert i from `weightfile`:
 //  - create the method-specific expert,
 //  - resolve its feature variables,
 //  - prepare its single-row input dataset and remember its number of classes.
109
110 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
 // NOTE(review): getSupportedInterfaces() returns a std::map. The operator[] used further down returns
 // nullptr for an unknown general_options.m_method, and that pointer is dereferenced without a check.
 // Consider a find() with an explicit B2FATAL.
112
113
114 MVA::GeneralOptions general_options;
115 weightfile.getOptions(general_options);
116
117 // Overwrite signal fraction from training
120
121 m_experts[i] = supported_interfaces[general_options.m_method]->getExpert();
122 m_experts[i]->load(weightfile);
123
124
 // Every variable name in the weightfile must resolve via the Variable::Manager; otherwise abort.
125 m_individual_feature_variables[i] = manager.getVariables(general_options.m_variables);
126 if (m_individual_feature_variables[i].size() != general_options.m_variables.size()) {
127 B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
128 }
129
 // m_feature_variables collects the union of all experts' variables (initial value 0), so that
 // fillDatasets() evaluates each variable only once per particle. Nothing here removes entries.
130 for (auto& iVariable : m_individual_feature_variables[i]) {
131 if (m_feature_variables.find(iVariable) == m_feature_variables.end()) {
132 m_feature_variables.insert(std::pair<const Variable::Manager::Var*, float>(iVariable, 0));
133 }
134 }
135
 // Zero-filled placeholder row with one entry per feature of this expert; fillDatasets() updates it later.
136 std::vector<float> dummy;
137 dummy.resize(m_individual_feature_variables[i].size(), 0);
138 m_datasets[i] = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
139
140 m_nClasses[i] = general_options.m_nClasses;
141
142}
143
145{
 // Evaluate every distinct feature variable once for `particle` and distribute the values to the
 // per-expert input datasets. `particle` is nullptr in event-level mode (see event()).
146 for (auto const& iVariable : m_feature_variables) {
 // Values are cached as float; only double, int and bool variables are handled.
 // NOTE(review): a variable of any other data type silently keeps its previous cached value.
147 if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_double) {
148 m_feature_variables[iVariable.first] = std::get<double>(iVariable.first->function(particle));
149 } else if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_int) {
150 m_feature_variables[iVariable.first] = std::get<int>(iVariable.first->function(particle));
151 } else if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_bool) {
152 m_feature_variables[iVariable.first] = std::get<bool>(iVariable.first->function(particle));
153 }
154 }
155
 // NOTE(review): the statement inside this loop (original line 158) is missing from this listing.
 // Presumably it copies the cached value of feature j of expert i into m_datasets[i] -- confirm.
156 for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
157 for (unsigned int j = 0; j < m_individual_feature_variables[i].size(); ++j) {
159 }
160 }
161}
162
163std::vector<std::vector<float>> MVAMultipleExpertsModule::analyse(Particle* particle)
164{
165 std::vector<std::vector<float>> responseValues;
166 fillDatasets(particle);
167
168 for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
169 if (m_nClasses[i] == 2) {
170 responseValues.push_back({m_experts[i]->apply(*m_datasets[i])[0],});
171 } else if (m_nClasses[i] > 2) {
172 responseValues.push_back(m_experts[i]->applyMulticlass(*m_datasets[i])[0]);
173 } else {
174 B2ERROR("Received a value of " << m_nClasses[i] <<
175 " for the number of classes considered by the MVA Expert. This value should be >=2.");
176 }
177 }
178 return responseValues;
179}
180
181void MVAMultipleExpertsModule::setExtraInfoField(Particle* particle, std::string extraInfoName, float responseValue, unsigned int i)
182{
183 if (particle->hasExtraInfo(extraInfoName)) {
184 if (particle->getExtraInfo(extraInfoName) != responseValue) {
185 m_existGivenExtraInfo[i] = true;
186 double current = particle->getExtraInfo(extraInfoName);
187 if (m_overwriteExistingExtraInfo[i] == -1) {
188 if (responseValue < current) particle->setExtraInfo(extraInfoName, responseValue);
189 } else if (m_overwriteExistingExtraInfo[i] == 0) {
190 // don't overwrite!
191 } else if (m_overwriteExistingExtraInfo[i] == 1) {
192 if (responseValue > current) particle->setExtraInfo(extraInfoName, responseValue);
193 } else if (m_overwriteExistingExtraInfo[i] == 2) {
194 particle->setExtraInfo(extraInfoName, responseValue);
195 } else {
196 B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo[i] << "'.");
197 }
198 }
199 } else {
200 particle->addExtraInfo(extraInfoName, responseValue);
201 }
202}
203
205 float responseValue, unsigned int i)
206{
207 if (eventExtraInfo->hasExtraInfo(extraInfoName)) {
208 m_existGivenExtraInfo[i] = true;
209 double current = eventExtraInfo->getExtraInfo(extraInfoName);
210 if (m_overwriteExistingExtraInfo[i] == -1) {
211 if (responseValue < current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
212 } else if (m_overwriteExistingExtraInfo[i] == 0) {
213 // don't overwrite!
214 } else if (m_overwriteExistingExtraInfo[i] == 1) {
215 if (responseValue > current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
216 } else if (m_overwriteExistingExtraInfo[i] == 2) {
217 eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
218 } else {
219 B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo[i] << "'.");
220 }
221 } else {
222 eventExtraInfo->addExtraInfo(extraInfoName, responseValue);
223 }
224}
225
227{
228 for (auto& listName : m_listNames) {
229 StoreObjPtr<ParticleList> list(listName);
230 // Calculate target Value for Particles
231 for (unsigned i = 0; i < list->getListSize(); ++i) {
232 Particle* particle = list->getParticle(i);
233 std::vector<std::vector<float>> responseValues = analyse(particle);
234 for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
235 if (m_nClasses[j] == 2) {
236 setExtraInfoField(particle, m_extraInfoNames[j], responseValues[j][0], j);
237 } else if (m_nClasses[j] > 2) {
238 if (responseValues[j].size() != m_nClasses[j]) {
239 B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues[j].size() <<
240 ") does not match the declared number of classes (" << m_nClasses[j] << ").");
241 }
242 for (unsigned int iClass = 0; iClass < m_nClasses[j]; iClass++) {
243 setExtraInfoField(particle, m_extraInfoNames[j] + "_" + std::to_string(iClass), responseValues[j][iClass], j);
244 }
245 } else {
246 B2ERROR("Received a value of " << m_nClasses[j] <<
247 " for the number of classes considered by the MVA Expert. This value should be >=2.");
248 }
249 } //identifiers
250 }
251 } // listnames
252 if (m_listNames.empty()) {
253 StoreObjPtr<EventExtraInfo> eventExtraInfo;
254 if (not eventExtraInfo.isValid())
255 eventExtraInfo.create();
256 std::vector<std::vector<float>> responseValues = analyse(nullptr);
257 for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
258 if (m_nClasses[j] == 2) {
259 setEventExtraInfoField(eventExtraInfo, m_extraInfoNames[j], responseValues[j][0], j);
260 } else if (m_nClasses[j] > 2) {
261 if (responseValues[j].size() != m_nClasses[j]) {
262 B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues[j].size() <<
263 ") does not match the declared number of classes (" << m_nClasses[j] << ").");
264 }
265 for (unsigned int iClass = 0; iClass < m_nClasses[j]; iClass++) {
266 setEventExtraInfoField(eventExtraInfo, m_extraInfoNames[j] + "_" + std::to_string(iClass), responseValues[j][iClass], j);
267 }
268 } else {
269 B2ERROR("Received a value of " << m_nClasses[j] <<
270 " for the number of classes considered by the MVA Expert. This value should be >=2.");
271 }
272 } //identifiers
273 }
274}
275
277{
278 for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
279 m_experts[i].reset();
280 m_datasets[i].reset();
281
282 if (m_existGivenExtraInfo[i]) {
283 if (m_overwriteExistingExtraInfo[i] == -1) {
284 B2WARNING("The extraInfo " << m_extraInfoNames[i] <<
285 " has already been set! It was overwritten by this module if the new value was lower than the previous!");
286 } else if (m_overwriteExistingExtraInfo[i] == 0) {
287 B2WARNING("The extraInfo " << m_extraInfoNames[i] <<
288 " has already been set! The original value was kept and this module did not overwrite it!");
289 } else if (m_overwriteExistingExtraInfo[i] == 1) {
290 B2WARNING("The extraInfo " << m_extraInfoNames[i] <<
291 " has already been set! It was overwritten by this module if the new value was higher than the previous!");
292 } else if (m_overwriteExistingExtraInfo[i] == 2) {
293 B2WARNING("The extraInfo " << m_extraInfoNames[i] << " has already been set! It was overwritten by this module!");
294 }
295 }
296 }
297}
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been called.
Definition: DataStore.h:59
void init_mva(MVA::Weightfile &weightfile, unsigned int i)
Initialize the MVA expert, dataset and features. Called from beginRun() every time the weightfile in the database changes.
std::vector< int > m_overwriteExistingExtraInfo
Vector of -1/0/1/2 (one per expert), applied when the extraInfo already exists: overwrite if lower / don't overwrite / overwrite if higher / always overwrite.
std::vector< std::unique_ptr< MVA::Expert > > m_experts
Vector of pointers to the current MVA Experts.
void setExtraInfoField(Particle *, std::string, float, unsigned int)
Set the extra info field.
std::vector< unsigned int > m_nClasses
number of classes (~outputs) of the MVA Experts.
virtual void initialize() override
Initialize the module.
std::vector< std::unique_ptr< MVA::SingleDataset > > m_datasets
Vector of pointers to the current input datasets.
virtual void event() override
Called for each event.
std::vector< bool > m_existGivenExtraInfo
Per expert: whether the given extraInfo was found to be already defined when this module tried to set it.
std::vector< std::vector< const Variable::Manager::Var * > > m_individual_feature_variables
Vector of pointers to the feature variables for each expert.
virtual void terminate() override
Called at the end of the event processing.
double m_signal_fraction_override
Signal Fraction which should be used.
std::vector< std::string > m_identifiers
Weightfile identifiers: database identifiers, or paths to .root/.xml weightfiles.
std::vector< std::string > m_listNames
input particle list names
void fillDatasets(Particle *)
Evaluate the variables and fill the Datasets to be used by the experts.
virtual void beginRun() override
Called at the beginning of a new run.
std::map< const Variable::Manager::Var *, float > m_feature_variables
Map containing the values of all needed feature variables.
std::vector< std::vector< float > > analyse(Particle *)
Calculates expert output for given Particle pointer.
void setEventExtraInfoField(StoreObjPtr< EventExtraInfo >, std::string, float, unsigned int)
Set the event extra info field.
std::vector< std::string > m_extraInfoNames
Names under which the expert outputs are stored in the extraInfo of the Particle object (or in the EventExtraInfo if no ParticleLists are given).
std::vector< std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > > m_weightfile_representations
Vector of database pointers to the Database representation of the weightfile.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (all I/O must be done through the data store).
Definition: Module.h:80
Class to store reconstructed particles.
Definition: Particle.h:75
void setExtraInfo(const std::string &name, double value)
Sets the user-defined data of given name to the given value.
Definition: Particle.cc:1317
bool hasExtraInfo(const std::string &name) const
Return whether the extra info with the given name is set.
Definition: Particle.cc:1266
void addExtraInfo(const std::string &name, double value)
Sets the user-defined data of given name to the given value.
Definition: Particle.cc:1336
double getExtraInfo(const std::string &name) const
Return given value if set.
Definition: Particle.cc:1289
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:96
Global list of available variables.
Definition: Manager.h:101
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24