Belle II Software  light-2205-abys
MVAMultipleExpertsModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 
10 #include <mva/modules/MVAExpert/MVAMultipleExpertsModule.h>
11 
12 #include <analysis/dataobjects/Particle.h>
13 #include <analysis/dataobjects/ParticleList.h>
14 #include <analysis/dataobjects/ParticleExtraInfoMap.h>
15 #include <analysis/dataobjects/EventExtraInfo.h>
16 
17 #include <mva/interface/Interface.h>
18 
19 #include <boost/algorithm/string/predicate.hpp>
20 #include <memory>
21 
22 #include <framework/logging/Logger.h>
23 
24 
25 using namespace Belle2;
26 
27 REG_MODULE(MVAMultipleExperts);
28 
30 {
31  setDescription("Adds ExtraInfos to the Particle objects in given ParticleLists which is calcuated by multiple experts defined by the given weightfiles.");
33 
34  std::vector<std::string> empty;
35  addParam("listNames", m_listNames,
36  "Particles from these ParticleLists are used as input. If no name is given the experts are applied to every event once, and one can only use variables which accept nullptr as Particle*",
37  empty);
38  addParam("extraInfoNames", m_extraInfoNames,
39  "Names under which the output of the experts is stored in the ExtraInfo of the Particle object.");
40  addParam("identifiers", m_identifiers, "The database identifiers which is used to load the weights during the training.");
41  addParam("signalFraction", m_signal_fraction_override,
42  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
43  std::vector<bool> empty_bool;
44  addParam("overwriteExistingExtraInfo", m_overwriteExistingExtraInfo,
45  "If true, when the given extraInfo has already defined, the old extraInfo value is overwritten. If false, the original value is kept.",
46  empty_bool);
47 }
48 
50 {
51  // All specified ParticleLists are required to exist
52  for (auto& name : m_listNames) {
53  StoreObjPtr<ParticleList> list(name);
54  list.isRequired();
55  }
56 
57  if (m_listNames.empty()) {
59  extraInfo.registerInDataStore();
60  } else {
62  extraInfo.registerInDataStore();
63  }
64 
65  if (m_extraInfoNames.size() != m_identifiers.size()) {
66  B2FATAL("The number of given m_extraInfoNames is not equal to the number of m_identifiers. The output the ith method in m_identifiers is saved as extraInfo under the ith name in m_extraInfoNames! Set also different names for each method!");
67  }
68 
70  m_experts.resize(m_identifiers.size());
72  m_datasets.resize(m_identifiers.size());
73  // if the size of m_overwriteExistingExtraInfo is smaller than that of m_identifiers, true will be filled.
74  m_overwriteExistingExtraInfo.resize(m_identifiers.size(), true);
75  m_existGivenExtraInfo.resize(m_identifiers.size(), false);
76 
77  for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
78  if (not(boost::ends_with(m_identifiers[i], ".root") or boost::ends_with(m_identifiers[i], ".xml"))) {
79  m_weightfile_representations[i] = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
80  MVA::makeSaveForDatabase(m_identifiers[i]));
81  }
82  }
83 
85 }
86 
88 {
89 
90  if (!m_weightfile_representations.empty()) {
91  for (unsigned int i = 0; i < m_weightfile_representations.size(); ++i) {
93  if (m_weightfile_representations[i]->hasChanged()) {
94  std::stringstream ss((*m_weightfile_representations[i])->m_data);
95  auto weightfile = MVA::Weightfile::loadFromStream(ss);
96  init_mva(weightfile, i);
97  }
98  } else {
99  auto weightfile = MVA::Weightfile::loadFromFile(m_identifiers[i]);
100  init_mva(weightfile, i);
101  }
102  }
103 
104  } else B2FATAL("No m_identifiers given. At least one is needed!");
105 }
106 
107 void MVAMultipleExpertsModule::init_mva(MVA::Weightfile& weightfile, unsigned int i)
108 {
109 
110  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
112 
113 
114  MVA::GeneralOptions general_options;
115  weightfile.getOptions(general_options);
116 
117  // Overwrite signal fraction from training
120 
121  m_experts[i] = supported_interfaces[general_options.m_method]->getExpert();
122  m_experts[i]->load(weightfile);
123 
124 
125  m_individual_feature_variables[i] = manager.getVariables(general_options.m_variables);
126  if (m_individual_feature_variables[i].size() != general_options.m_variables.size()) {
127  B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
128  }
129 
130  for (auto& iVariable : m_individual_feature_variables[i]) {
131  if (m_feature_variables.find(iVariable) == m_feature_variables.end()) {
132  m_feature_variables.insert(std::pair<const Variable::Manager::Var*, float>(iVariable, 0));
133  }
134  }
135 
136  std::vector<float> dummy;
137  dummy.resize(m_individual_feature_variables[i].size(), 0);
138  m_datasets[i] = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
139 
140 }
141 
142 std::vector<float> MVAMultipleExpertsModule::analyse(Particle* particle)
143 {
144 
145  std::vector<float> targetValues;
146  targetValues.resize(m_identifiers.size());
147  for (auto const& iVariable : m_feature_variables) {
148  if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_double) {
149  m_feature_variables[iVariable.first] = std::get<double>(iVariable.first->function(particle));
150  } else if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_int) {
151  m_feature_variables[iVariable.first] = std::get<int>(iVariable.first->function(particle));
152  } else if (iVariable.first->variabletype == Variable::Manager::VariableDataType::c_bool) {
153  m_feature_variables[iVariable.first] = std::get<bool>(iVariable.first->function(particle));
154  }
155  }
156 
157  for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
158  for (unsigned int j = 0; j < m_individual_feature_variables[i].size(); ++j) {
160  }
161  targetValues[i] = m_experts[i]->apply(*m_datasets[i])[0];
162  }
163 
164  return targetValues;
165 }
166 
167 
169 {
170  for (auto& listName : m_listNames) {
171  StoreObjPtr<ParticleList> list(listName);
172  // Calculate target Value for Particles
173  for (unsigned i = 0; i < list->getListSize(); ++i) {
174  Particle* particle = list->getParticle(i);
175  std::vector<float> targetValues = analyse(particle);
176 
177  for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
178  if (particle->hasExtraInfo(m_extraInfoNames[j])) {
179  if (particle->getExtraInfo(m_extraInfoNames[j]) != targetValues[j]) {
180  m_existGivenExtraInfo[j] = true;
182  particle->setExtraInfo(m_extraInfoNames[j], targetValues[j]);
183  }
184  } else {
185  particle->addExtraInfo(m_extraInfoNames[j], targetValues[j]);
186  }
187  }
188  }
189  }
190  if (m_listNames.empty()) {
191  StoreObjPtr<EventExtraInfo> eventExtraInfo;
192  if (not eventExtraInfo.isValid())
193  eventExtraInfo.create();
194  std::vector<float> targetValues = analyse(nullptr);
195  for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
196  if (eventExtraInfo->hasExtraInfo(m_extraInfoNames[j])) {
197  m_existGivenExtraInfo[j] = true;
199  eventExtraInfo->setExtraInfo(m_extraInfoNames[j], targetValues[j]);
200  } else {
201  eventExtraInfo->addExtraInfo(m_extraInfoNames[j], targetValues[j]);
202  }
203  }
204  }
205 }
206 
208 {
209  for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
210  m_experts[i].reset();
211  m_datasets[i].reset();
212 
213  if (m_existGivenExtraInfo[i]) {
215  B2WARNING("The extraInfo " << m_extraInfoNames[i] << " has already been set! It was overwritten by this module!");
216  else
217  B2WARNING("The extraInfo " << m_extraInfoNames[i] << " has already been set! "
218  << "The original value was kept and this module did not overwrite it!");
219  }
220  }
221 
222 }
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition: DataStore.h:59
void init_mva(MVA::Weightfile &weightfile, unsigned int i)
Initialize mva expert, dataset and features. Called every time the weightfile in the database changes ...
std::vector< std::unique_ptr< MVA::Expert > > m_experts
Vector of pointers to the current MVA Experts.
virtual void initialize() override
Initialize the module.
std::vector< std::unique_ptr< MVA::SingleDataset > > m_datasets
Vector of pointers to the current input datasets.
std::vector< float > analyse(Particle *)
Calculates expert output for given Particle pointer.
virtual void event() override
Called for each event.
std::vector< bool > m_existGivenExtraInfo
check if the given extraInfo is already defined.
std::vector< std::vector< const Variable::Manager::Var * > > m_individual_feature_variables
Vector of pointers to the feature variables for each expert.
virtual void terminate() override
Called at the end of the event processing.
double m_signal_fraction_override
Signal Fraction which should be used.
std::vector< std::string > m_identifiers
weight-files
std::vector< std::string > m_listNames
input particle list names
virtual void beginRun() override
Called at the beginning of a new run.
std::map< const Variable::Manager::Var *, float > m_feature_variables
Map containing the values of all needed feature variables.
std::vector< std::string > m_extraInfoNames
Names under which the SignalProbability is stored in the extraInfo of the Particle object.
std::vector< std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > > m_weightfile_representations
Vector of database pointers to the Database representation of the weightfile.
std::vector< bool > m_overwriteExistingExtraInfo
if true, when the given extraInfo is already defined, the old extraInfo value is overwritten
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedI...
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:66
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:205
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:94
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
Class to store reconstructed particles.
Definition: Particle.h:74
void setExtraInfo(const std::string &name, double value)
Sets the user-defined data of given name to the given value.
Definition: Particle.cc:1306
bool hasExtraInfo(const std::string &name) const
Return whether the extra info with the given name is set.
Definition: Particle.cc:1255
void addExtraInfo(const std::string &name, double value)
Sets the user-defined data of given name to the given value.
Definition: Particle.cc:1325
double getExtraInfo(const std::string &name) const
Return given value if set.
Definition: Particle.cc:1278
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:95
Global list of available variables.
Definition: Manager.h:101
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
REG_MODULE(B2BIIConvertBeamParams)
Register the module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:23