Belle II Software development
MVAMultipleExpertsModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9
10#include <mva/modules/MVAExpert/MVAMultipleExpertsModule.h>
11
12#include <analysis/dataobjects/Particle.h>
13#include <analysis/dataobjects/ParticleList.h>
14#include <analysis/dataobjects/ParticleExtraInfoMap.h>
15
16#include <mva/interface/Interface.h>
17
18#include <framework/logging/Logger.h>
19
20
21using namespace Belle2;
22
// Register this module with the framework; REG_MODULE takes the module name without the
// "Module" suffix, so MVAMultipleExpertsModule is registered as "MVAMultipleExperts".
23REG_MODULE(MVAMultipleExperts);
24
26{
27 setDescription("Adds ExtraInfos to the Particle objects in given ParticleLists which is calculated by multiple experts defined by the given weightfiles.");
29
30 std::vector<std::string> empty;
31 addParam("listNames", m_listNames,
32 "Particles from these ParticleLists are used as input. If no name is given the experts are applied to every event once, and one can only use variables which accept nullptr as Particle*",
33 empty);
34 addParam("extraInfoNames", m_extraInfoNames,
35 "Names under which the output of the experts is stored in the ExtraInfo of the Particle object.");
36 addParam("identifiers", m_identifiers, "The database identifiers which is used to load the weights during the training.");
37 addParam("signalFraction", m_signal_fraction_override,
38 "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
39 std::vector<int> empty_vec;
40 addParam("overwriteExistingExtraInfo", m_overwriteExistingExtraInfo,
41 "If true, when the given extraInfo has already defined, the old extraInfo value is overwritten. If false, the original value is kept.",
42 empty_vec);
43}
44
46{
// Module initialisation: declares the required DataStore inputs, checks that extraInfoNames and
// identifiers pair up one-to-one, sizes the per-expert containers, and creates a database handle
// for every identifier that is not a local .root/.xml weightfile.
// NOTE(review): source lines 45 (signature), 49, 54, 57, 65, 67, 71 and 81 are missing from this
// listing, so the declarations of `list` / `extraInfo` and some resize() calls are not shown.
47 // All specified ParticleLists are required to exist
48 for (auto& name : m_listNames) {
50 list.isRequired();
51 }
52
// Without lists the output goes to an event-level object, with lists to particle extraInfos.
// Presumably `extraInfo` is the EventExtraInfo resp. the ParticleExtraInfoMap (declared on the
// missing lines 54 / 57) — confirm against the full source.
53 if (m_listNames.empty()) {
55 extraInfo.isRequired();
56 } else {
58 extraInfo.isRequired();
59 }
60
// The output of the i-th identifier is stored under the i-th extraInfo name, so the sizes must match.
61 if (m_extraInfoNames.size() != m_identifiers.size()) {
62 B2FATAL("The number of given m_extraInfoNames is not equal to the number of m_identifiers. The output the ith method in m_identifiers is saved as extraInfo under the ith name in m_extraInfoNames! Set also different names for each method!");
63 }
64
66 m_experts.resize(m_identifiers.size());
68 m_datasets.resize(m_identifiers.size());
69 m_nClasses.resize(m_identifiers.size());
70 // if the size of m_overwriteExistingExtraInfo is smaller than that of m_identifiers, 2 will be filled.
// NOTE(review): the policy codes are only validated lazily (B2FATAL in setExtraInfoField /
// setEventExtraInfoField once a clash occurs); checking them against {-1,0,1,2} here would fail fast.
72 m_existGivenExtraInfo.resize(m_identifiers.size(), false);
73
// Identifiers ending in .root/.xml get no database handle (beginRun presumably reads them from
// file); every other identifier gets a DBObjPtr to MVA::makeSaveForDatabase(identifier).
74 for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
75 if (not(m_identifiers[i].ends_with(".root") or m_identifiers[i].ends_with(".xml"))) {
76 m_weightfile_representations[i] = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
77 MVA::makeSaveForDatabase(m_identifiers[i]));
78 }
79 }
80
// NOTE(review): line 81 is missing here; presumably it calls
// MVA::AbstractInterface::initSupportedInterfaces(), which has to run once before getSupportedInterfaces().
82}
83
85{
// Called at the start of each run: (re)initialises every expert from its weightfile via init_mva().
// NOTE(review): source lines 84 (signature) and 89 are missing from this listing. Line 89
// presumably opens the `if` that the `} else {` below closes, selecting the identifiers that got
// a DBObjPtr in initialize() — confirm against the full source.
86
// Empty only if no handles were allocated, i.e. (presumably) when no identifiers were given.
87 if (!m_weightfile_representations.empty()) {
88 for (unsigned int i = 0; i < m_weightfile_representations.size(); ++i) {
// Database payload: only reloaded when hasChanged() reports a new payload.
90 if (m_weightfile_representations[i]->hasChanged()) {
91 std::stringstream ss((*m_weightfile_representations[i])->m_data);
92 auto weightfile = MVA::Weightfile::loadFromStream(ss);
93 init_mva(weightfile, i);
94 }
95 } else {
// File-based weightfile: re-read and re-initialised on every run, with no change check.
96 auto weightfile = MVA::Weightfile::loadFromFile(m_identifiers[i]);
97 init_mva(weightfile, i);
98 }
99 }
100
// NOTE(review): std::stringstream is used above but <sstream> is not among the visible includes;
// it presumably arrives transitively — an explicit include would be more robust.
101 } else B2FATAL("No m_identifiers given. At least one is needed!");
102}
103
105{
// (Re)builds expert i from a loaded weightfile: creates and loads the expert for the trained
// method, registers the feature variables it needs, and prepares a single-row dataset for it.
// NOTE(review): source lines 104 (signature), 108 and 115 are missing from this listing; `manager`
// below is presumably Variable::Manager::Instance() declared on line 108.
106
107 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
109
110
111 MVA::GeneralOptions general_options;
112 weightfile.getOptions(general_options);
113
114 // Overwrite signal fraction from training
// NOTE(review): line 115 is not shown; given the signalFraction help text (-1 keeps the training
// value) it presumably guards the call below — confirm.
116 weightfile.addSignalFraction(m_signal_fraction_override);
117
// NOTE(review): getSupportedInterfaces() returns a std::map<std::string, AbstractInterface*>, so
// operator[] inserts a nullptr for an unknown m_method and ->getExpert() then dereferences it.
// A find() followed by a B2FATAL naming the method would fail cleanly instead.
118 m_experts[i] = supported_interfaces[general_options.m_method]->getExpert();
119 m_experts[i]->load(weightfile);
120
121
122 m_individual_feature_variables[i] = manager.getVariables(general_options.m_variables);
// A size mismatch means at least one requested variable name could not be resolved.
123 if (m_individual_feature_variables[i].size() != general_options.m_variables.size()) {
124 B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
125 }
126
// Register each variable once in the shared cache used by fillDatasets(), so a variable needed by
// several experts is evaluated only once per call. Entries are not removed here, also not when
// the weightfile changes.
127 for (auto& iVariable : m_individual_feature_variables[i]) {
128 if (m_feature_variables.find(iVariable) == m_feature_variables.end()) {
129 m_feature_variables.insert(std::pair<const Variable::Manager::Var*, float>(iVariable, 0));
130 }
131 }
132
// Placeholder single-row dataset with one zero per feature; its inputs are presumably refreshed
// per call by fillDatasets() (that assignment sits on a line missing from this listing).
133 std::vector<float> dummy;
134 dummy.resize(m_individual_feature_variables[i].size(), 0);
135 m_datasets[i] = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
136
// Decides in analyse()/event() between binary (2) and multiclass (>2) evaluation.
137 m_nClasses[i] = general_options.m_nClasses;
138
139}
140
142{
// Evaluates every cached feature variable once for `particle` (nullptr for event-level use, see
// event()) and then distributes the values to the per-expert datasets.
// Writing through operator[] while iterating is safe here: the key already exists, so nothing is inserted.
143 for (auto const& iVariable : m_feature_variables) {
// Results are stored as float (the map's value type). A variable of any other data type keeps its
// previous value, because this if/else-if chain has no final else.
144 if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_double) {
145 m_feature_variables[iVariable.first] = std::get<double>(iVariable.first->function(particle));
146 } else if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_int) {
147 m_feature_variables[iVariable.first] = std::get<int>(iVariable.first->function(particle));
148 } else if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_bool) {
149 m_feature_variables[iVariable.first] = std::get<bool>(iVariable.first->function(particle));
150 }
151 }
152
153 for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
154 for (unsigned int j = 0; j < m_individual_feature_variables[i].size(); ++j) {
// NOTE(review): source line 155 is missing from this listing; it presumably copies the cached value
// of m_individual_feature_variables[i][j] into input j of m_datasets[i].
156 }
157 }
158}
159
160std::vector<std::vector<float>> MVAMultipleExpertsModule::analyse(Particle* particle)
161{
162 std::vector<std::vector<float>> responseValues;
163 fillDatasets(particle);
164
165 for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
166 if (m_nClasses[i] == 2) {
167 responseValues.push_back({m_experts[i]->apply(*m_datasets[i])[0],});
168 } else if (m_nClasses[i] > 2) {
169 responseValues.push_back(m_experts[i]->applyMulticlass(*m_datasets[i])[0]);
170 } else {
171 B2ERROR("Received a value of " << m_nClasses[i] <<
172 " for the number of classes considered by the MVA Expert. This value should be >=2.");
173 }
174 }
175 return responseValues;
176}
177
178void MVAMultipleExpertsModule::setExtraInfoField(Particle* particle, std::string extraInfoName, float responseValue, unsigned int i)
179{
180 if (particle->hasExtraInfo(extraInfoName)) {
181 if (particle->getExtraInfo(extraInfoName) != responseValue) {
182 m_existGivenExtraInfo[i] = true;
183 double current = particle->getExtraInfo(extraInfoName);
184 if (m_overwriteExistingExtraInfo[i] == -1) {
185 if (responseValue < current) particle->setExtraInfo(extraInfoName, responseValue);
186 } else if (m_overwriteExistingExtraInfo[i] == 0) {
187 // don't overwrite!
188 } else if (m_overwriteExistingExtraInfo[i] == 1) {
189 if (responseValue > current) particle->setExtraInfo(extraInfoName, responseValue);
190 } else if (m_overwriteExistingExtraInfo[i] == 2) {
191 particle->setExtraInfo(extraInfoName, responseValue);
192 } else {
193 B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo[i] << "'.");
194 }
195 }
196 } else {
197 particle->addExtraInfo(extraInfoName, responseValue);
198 }
199}
200
202 float responseValue, unsigned int i)
203{
204 if (eventExtraInfo->hasExtraInfo(extraInfoName)) {
205 m_existGivenExtraInfo[i] = true;
206 double current = eventExtraInfo->getExtraInfo(extraInfoName);
207 if (m_overwriteExistingExtraInfo[i] == -1) {
208 if (responseValue < current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
209 } else if (m_overwriteExistingExtraInfo[i] == 0) {
210 // don't overwrite!
211 } else if (m_overwriteExistingExtraInfo[i] == 1) {
212 if (responseValue > current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
213 } else if (m_overwriteExistingExtraInfo[i] == 2) {
214 eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
215 } else {
216 B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo[i] << "'.");
217 }
218 } else {
219 eventExtraInfo->addExtraInfo(extraInfoName, responseValue);
220 }
221}
222
224{
225 for (auto& listName : m_listNames) {
226 StoreObjPtr<ParticleList> list(listName);
227 // Calculate target Value for Particles
228 for (unsigned i = 0; i < list->getListSize(); ++i) {
229 Particle* particle = list->getParticle(i);
230 std::vector<std::vector<float>> responseValues = analyse(particle);
231 for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
232 if (m_nClasses[j] == 2) {
233 setExtraInfoField(particle, m_extraInfoNames[j], responseValues[j][0], j);
234 } else if (m_nClasses[j] > 2) {
235 if (responseValues[j].size() != m_nClasses[j]) {
236 B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues[j].size() <<
237 ") does not match the declared number of classes (" << m_nClasses[j] << ").");
238 }
239 for (unsigned int iClass = 0; iClass < m_nClasses[j]; iClass++) {
240 setExtraInfoField(particle, m_extraInfoNames[j] + "_" + std::to_string(iClass), responseValues[j][iClass], j);
241 }
242 } else {
243 B2ERROR("Received a value of " << m_nClasses[j] <<
244 " for the number of classes considered by the MVA Expert. This value should be >=2.");
245 }
246 } //identifiers
247 }
248 } // listnames
249 if (m_listNames.empty()) {
250 StoreObjPtr<EventExtraInfo> eventExtraInfo;
251 if (not eventExtraInfo.isValid())
252 eventExtraInfo.create();
253 std::vector<std::vector<float>> responseValues = analyse(nullptr);
254 for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
255 if (m_nClasses[j] == 2) {
256 setEventExtraInfoField(eventExtraInfo, m_extraInfoNames[j], responseValues[j][0], j);
257 } else if (m_nClasses[j] > 2) {
258 if (responseValues[j].size() != m_nClasses[j]) {
259 B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues[j].size() <<
260 ") does not match the declared number of classes (" << m_nClasses[j] << ").");
261 }
262 for (unsigned int iClass = 0; iClass < m_nClasses[j]; iClass++) {
263 setEventExtraInfoField(eventExtraInfo, m_extraInfoNames[j] + "_" + std::to_string(iClass), responseValues[j][iClass], j);
264 }
265 } else {
266 B2ERROR("Received a value of " << m_nClasses[j] <<
267 " for the number of classes considered by the MVA Expert. This value should be >=2.");
268 }
269 } //identifiers
270 }
271}
272
274{
275 for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
276 m_experts[i].reset();
277 m_datasets[i].reset();
278
279 if (m_existGivenExtraInfo[i]) {
280 if (m_overwriteExistingExtraInfo[i] == -1) {
281 B2WARNING("The extraInfo " << m_extraInfoNames[i] <<
282 " has already been set! It was overwritten by this module if the new value was lower than the previous!");
283 } else if (m_overwriteExistingExtraInfo[i] == 0) {
284 B2WARNING("The extraInfo " << m_extraInfoNames[i] <<
285 " has already been set! The original value was kept and this module did not overwrite it!");
286 } else if (m_overwriteExistingExtraInfo[i] == 1) {
287 B2WARNING("The extraInfo " << m_extraInfoNames[i] <<
288 " has already been set! It was overwritten by this module if the new value was higher than the previous!");
289 } else if (m_overwriteExistingExtraInfo[i] == 2) {
290 B2WARNING("The extraInfo " << m_extraInfoNames[i] << " has already been set! It was overwritten by this module!");
291 }
292 }
293 }
294}
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition DataStore.h:59
void init_mva(MVA::Weightfile &weightfile, unsigned int i)
Initialize mva expert, dataset and features Called every time the weightfile in the database changes ...
std::vector< int > m_overwriteExistingExtraInfo
vector of -1/0/1/2: overwrite if lower/ don't overwrite / overwrite if higher/ always overwrite,...
std::vector< std::unique_ptr< MVA::Expert > > m_experts
Vector of pointers to the current MVA Experts.
void setExtraInfoField(Particle *, std::string, float, unsigned int)
Set the extra info field.
std::vector< unsigned int > m_nClasses
number of classes (~outputs) of the MVA Experts.
virtual void initialize() override
Initialize the module.
std::vector< std::unique_ptr< MVA::SingleDataset > > m_datasets
Vector of pointers to the current input datasets.
virtual void event() override
Called for each event.
std::vector< bool > m_existGivenExtraInfo
check if the given extraInfo is already defined.
std::vector< std::vector< const Variable::Manager::Var * > > m_individual_feature_variables
Vector of pointers to the feature variables for each expert.
virtual void terminate() override
Called at the end of the event processing.
double m_signal_fraction_override
Signal Fraction which should be used.
std::vector< std::string > m_identifiers
identifiers of the weightfiles: database payload names, or paths of local .root/.xml files
std::vector< std::string > m_listNames
input particle list names
void fillDatasets(Particle *)
Evaluate the variables and fill the Datasets to be used by the experts.
virtual void beginRun() override
Called at the beginning of a new run.
std::map< const Variable::Manager::Var *, float > m_feature_variables
Map containing the values of all needed feature variables.
std::vector< std::vector< float > > analyse(Particle *)
Calculates expert output for given Particle pointer.
void setEventExtraInfoField(StoreObjPtr< EventExtraInfo >, std::string, float, unsigned int)
Set the event extra info field.
std::vector< std::string > m_extraInfoNames
Names under which the output of each expert is stored in the extraInfo of the Particle object (or in the EventExtraInfo when no list is given).
std::vector< std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > > m_weightfile_representations
Vector of database pointers to the Database representation of the weightfile.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupported...
Definition Interface.cc:46
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition Interface.h:53
General options which are shared by all MVA trainings.
Definition Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
Class to store reconstructed particles.
Definition Particle.h:76
Type-safe access to single objects in the data store.
Definition StoreObjPtr.h:96
Global list of available variables.
Definition Manager.h:100
static Manager & Instance()
get singleton instance.
Definition Manager.cc:26
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.