Belle II Software  light-2212-foldex
MVAExpertModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 
10 #include <mva/modules/MVAExpert/MVAExpertModule.h>
11 
12 #include <analysis/dataobjects/Particle.h>
13 #include <analysis/dataobjects/ParticleList.h>
14 #include <analysis/dataobjects/ParticleExtraInfoMap.h>
15 #include <framework/dataobjects/EventExtraInfo.h>
16 
17 #include <mva/interface/Interface.h>
18 
19 #include <boost/algorithm/string/predicate.hpp>
20 
21 #include <framework/logging/Logger.h>
22 
23 
24 using namespace Belle2;
25 
// Register the MVAExpert module with the framework's module registry.
26 REG_MODULE(MVAExpert);
27 
// Constructor: sets the module description and declares the steerable parameters.
// NOTE(review): source line 31 (and the constructor signature) is missing from this listing;
// presumably it sets the module property flags (c_ParallelProcessingCertified appears in the
// symbol list) -- TODO confirm against the real file.
// NOTE(review): the description string below contains the typo "calcuated"; it is runtime text
// and is intentionally left unchanged here.
29 {
30  setDescription("Adds an ExtraInfo to the Particle objects in given ParticleLists which is calcuated by an expert defined by a weightfile.");
32 
// Default for listNames is an empty list: event() then runs the expert once per event with a
// nullptr Particle and stores the result as event-level extra info.
33  std::vector<std::string> empty;
34  addParam("listNames", m_listNames,
35  "Particles from these ParticleLists are used as input. If no name is given the expert is applied to every event once, and one can only use variables which accept nullptr as Particle*",
36  empty);
37  addParam("extraInfoName", m_extraInfoName,
38  "Name under which the output of the expert is stored in the ExtraInfo of the Particle object. If the expert returns multiple values, the index of the value is appended to the name in the form '_0', '_1', ...");
39  addParam("identifier", m_identifier, "The database identifier which is used to load the weights during the training.");
40  addParam("signalFraction", m_signal_fraction_override,
41  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
// Default 2 = always overwrite. Values outside {-1,0,1,2} are not checked here; they are only
// rejected (B2FATAL) once a colliding extra info is actually encountered while writing.
42  addParam("overwriteExistingExtraInfo", m_overwriteExistingExtraInfo,
43  "-1/0/1/2: Overwrite if lower / don't overwrite / overwrite if higher / always overwrite, in case the extra info with given name already exists",
44  2);
45 }
46 
// initialize(): declares the module's data-store dependencies and prepares weightfile access.
// NOTE(review): source lines 56, 59 and 67 (and the signature) are missing from this listing.
// The two `extraInfo` declarations presumably are StoreObjPtr<EventExtraInfo> (no input lists)
// and StoreObjPtr<ParticleExtraInfoMap> (with input lists), and line 67 presumably calls
// MVA::AbstractInterface::initSupportedInterfaces() -- TODO confirm against the real file.
48 {
49  // All specified ParticleLists are required to exist
50  for (auto& name : m_listNames) {
51  StoreObjPtr<ParticleList> list(name);
52  list.isRequired();
53  }
54 
// Register the store object that receives the expert output: event-level when no input lists are
// given, particle-level otherwise (see the NOTE above about the missing declarations).
55  if (m_listNames.empty()) {
57  extraInfo.registerInDataStore();
58  } else {
60  extraInfo.registerInDataStore();
61  }
62 
// Identifiers not ending in ".root" or ".xml" are treated as database payload names: a DBObjPtr is
// created so beginRun() can check it with hasChanged(). File identifiers leave
// m_weightfile_representation empty.
63  if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
64  m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
65  MVA::makeSaveForDatabase(m_identifier));
66  }
68 
// Reset the collision flag that terminate() reports on.
69  m_existGivenExtraInfo = false;
70 }
71 
// beginRun(): (re)builds the expert from the weightfile via init_mva().
// NOTE(review): source lines 75 and 82 (and the signature) are missing from this listing. From the
// brace structure, line 75 presumably opens `if (m_weightfile_representation) {` and line 82
// presumably declares `weightfile`, e.g. via MVA::Weightfile::loadFromFile(m_identifier)
// -- TODO confirm against the real file.
73 {
74 
// Database payload: only rebuild the expert when the payload changed since it was last read.
76  if (m_weightfile_representation->hasChanged()) {
77  std::stringstream ss((*m_weightfile_representation)->m_data);
78  auto weightfile = MVA::Weightfile::loadFromStream(ss);
79  init_mva(weightfile);
80  }
// Otherwise (presumably a file-based identifier) the expert is rebuilt unconditionally every run.
81  } else {
83  init_mva(weightfile);
84  }
85 
86 }
87 
// init_mva(weightfile): builds the expert, resolves the feature variables and allocates the
// single-entry input dataset from `weightfile`; called from beginRun().
// NOTE(review): source lines 96-97 and 102 (and the signature) are missing from this listing.
// Lines 96-97 presumably apply m_signal_fraction_override via weightfile.addSignalFraction() when
// an override is requested, and line 102 presumably binds `manager` to
// Variable::Manager::Instance() -- TODO confirm against the real file.
89 {
90 
91  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
92  MVA::GeneralOptions general_options;
93  weightfile.getOptions(general_options);
94 
95  // Overwrite signal fraction from training
98 
// NOTE(review): supported_interfaces is a std::map<std::string, AbstractInterface*>; operator[]
// inserts a null entry for an unknown method name, so an unsupported general_options.m_method
// would dereference nullptr in getExpert(). Consider find() plus a B2FATAL naming the method.
99  m_expert = supported_interfaces[general_options.m_method]->getExpert();
100  m_expert->load(weightfile);
101 
// Every feature name stored in the weightfile must resolve to a known variable.
103  m_feature_variables = manager.getVariables(general_options.m_variables);
104  if (m_feature_variables.size() != general_options.m_variables.size()) {
105  B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
106  }
107 
// Single-entry dataset (zero-filled inputs) that fillDataset() overwrites before every evaluation.
108  std::vector<float> dummy;
109  dummy.resize(m_feature_variables.size(), 0);
110  m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
111  m_nClasses = general_options.m_nClasses;
112 }
113 
115 {
116  for (unsigned int i = 0; i < m_feature_variables.size(); ++i) {
117  auto var_result = m_feature_variables[i]->function(particle);
118  if (std::holds_alternative<double>(var_result)) {
119  m_dataset->m_input[i] = std::get<double>(var_result);
120  } else if (std::holds_alternative<int>(var_result)) {
121  m_dataset->m_input[i] = std::get<int>(var_result);
122  } else if (std::holds_alternative<bool>(var_result)) {
123  m_dataset->m_input[i] = std::get<bool>(var_result);
124  }
125  }
126 }
127 
129 {
130  if (not m_expert) {
131  B2ERROR("MVA Expert is not loaded! I will return 0");
132  return 0.0;
133  }
134  fillDataset(particle);
135  return m_expert->apply(*m_dataset)[0];
136 }
137 
138 std::vector<float> MVAExpertModule::analyseMulticlass(Particle* particle)
139 {
140  if (not m_expert) {
141  B2ERROR("MVA Expert is not loaded! I will return 0");
142  return std::vector<float>(m_nClasses, 0.0);
143  }
144  fillDataset(particle);
145  return m_expert->applyMulticlass(*m_dataset)[0];
146 }
147 
148 void MVAExpertModule::setExtraInfoField(Particle* particle, std::string extraInfoName, float responseValue)
149 {
150  if (particle->hasExtraInfo(extraInfoName)) {
151  if (particle->getExtraInfo(extraInfoName) != responseValue) {
152  m_existGivenExtraInfo = true;
153  double current = particle->getExtraInfo(extraInfoName);
154  if (m_overwriteExistingExtraInfo == -1) {
155  if (responseValue < current) particle->setExtraInfo(extraInfoName, responseValue);
156  } else if (m_overwriteExistingExtraInfo == 0) {
157  // don't overwrite!
158  } else if (m_overwriteExistingExtraInfo == 1) {
159  if (responseValue > current) particle->setExtraInfo(extraInfoName, responseValue);
160  } else if (m_overwriteExistingExtraInfo == 2) {
161  particle->setExtraInfo(extraInfoName, responseValue);
162  } else {
163  B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo << "'.");
164  }
165  }
166  } else {
167  particle->addExtraInfo(extraInfoName, responseValue);
168  }
169 }
170 
171 void MVAExpertModule::setEventExtraInfoField(StoreObjPtr<EventExtraInfo> eventExtraInfo, std::string extraInfoName,
172  float responseValue)
173 {
174  if (eventExtraInfo->hasExtraInfo(extraInfoName)) {
175  m_existGivenExtraInfo = true;
176  double current = eventExtraInfo->getExtraInfo(extraInfoName);
177  if (m_overwriteExistingExtraInfo == -1) {
178  if (responseValue < current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
179  } else if (m_overwriteExistingExtraInfo == 0) {
180  // don't overwrite!
181  } else if (m_overwriteExistingExtraInfo == 1) {
182  if (responseValue > current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
183  } else if (m_overwriteExistingExtraInfo == 2) {
184  eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
185  } else {
186  B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo << "'.");
187  }
188  } else {
189  eventExtraInfo->addExtraInfo(extraInfoName, responseValue);
190  }
191 }
192 
194 {
195  for (auto& listName : m_listNames) {
196  StoreObjPtr<ParticleList> list(listName);
197  // Calculate target Value for Particles
198  for (unsigned i = 0; i < list->getListSize(); ++i) {
199  Particle* particle = list->getParticle(i);
200  if (m_nClasses == 2) {
201  float responseValue = analyse(particle);
202  setExtraInfoField(particle, m_extraInfoName, responseValue);
203  } else if (m_nClasses > 2) {
204  std::vector<float> responseValues = analyseMulticlass(particle);
205  if (responseValues.size() != m_nClasses) {
206  B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues.size() <<
207  ") does not match the declared number of classes (" << m_nClasses << ").");
208  }
209  for (unsigned int iClass = 0; iClass < m_nClasses; iClass++) {
210  setExtraInfoField(particle, m_extraInfoName + "_" + std::to_string(iClass), responseValues[iClass]);
211  }
212  } else {
213  B2ERROR("Received a value of " << m_nClasses <<
214  " for the number of classes considered by the MVA Expert. This value should be >=2.");
215  }
216  }
217  }
218  if (m_listNames.empty()) {
219  StoreObjPtr<EventExtraInfo> eventExtraInfo;
220  if (not eventExtraInfo.isValid())
221  eventExtraInfo.create();
222 
223  if (m_nClasses == 2) {
224  float responseValue = analyse(nullptr);
225  setEventExtraInfoField(eventExtraInfo, m_extraInfoName, responseValue);
226  } else if (m_nClasses > 2) {
227  std::vector<float> responseValues = analyseMulticlass(nullptr);
228  if (responseValues.size() != m_nClasses) {
229  B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues.size() <<
230  ") does not match the declared number of classes (" << m_nClasses << ").");
231  }
232  for (unsigned int iClass = 0; iClass < m_nClasses; iClass++) {
233  setEventExtraInfoField(eventExtraInfo, m_extraInfoName + "_" + std::to_string(iClass), responseValues[iClass]);
234  }
235  } else {
236  B2ERROR("Received a value of " << m_nClasses <<
237  " for the number of classes considered by the MVA Expert. This value should be >=2.");
238  }
239  }
240 }
241 
243 {
244  m_expert.reset();
245  m_dataset.reset();
246 
247  if (m_existGivenExtraInfo) {
248  if (m_overwriteExistingExtraInfo == -1) {
249  B2WARNING("The extraInfo " << m_extraInfoName <<
250  " has already been set! It was overwritten by this module if the new value was lower than the previous!");
251  } else if (m_overwriteExistingExtraInfo == 0) {
252  B2WARNING("The extraInfo " << m_extraInfoName <<
253  " has already been set! The original value was kept and this module did not overwrite it!");
254  } else if (m_overwriteExistingExtraInfo == 1) {
255  B2WARNING("The extraInfo " << m_extraInfoName <<
256  " has already been set! It was overwritten by this module if the new value was higher than the previous!");
257  } else if (m_overwriteExistingExtraInfo == 2) {
258  B2WARNING("The extraInfo " << m_extraInfoName << " has already been set! It was overwritten by this module!");
259  }
260  }
261 }
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition: DataStore.h:59
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
std::vector< const Variable::Manager::Var * > m_feature_variables
Pointers to the feature variables.
virtual void initialize() override
Initialize the module.
void fillDataset(Particle *)
Evaluate the variables and fill the Dataset to be used by the expert.
virtual void event() override
Called for each event.
void setEventExtraInfoField(StoreObjPtr< EventExtraInfo >, std::string, float)
Set the event extra info field.
virtual void terminate() override
Called at the end of the event processing.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
std::vector< std::string > m_listNames
input particle list names
virtual void beginRun() override
Called at the beginning of a new run.
MVAExpertModule()
Constructor.
bool m_existGivenExtraInfo
check if the given extraInfo is already defined.
int m_overwriteExistingExtraInfo
-1/0/1/2: overwrite if lower/ don't overwrite / overwrite if higher/ always overwrite,...
float analyse(Particle *)
Calculates expert output for given Particle pointer.
std::vector< float > analyseMulticlass(Particle *)
Calculates expert output for given Particle pointer.
void init_mva(MVA::Weightfile &weightfile)
Initialize the MVA expert, dataset and features. Called every time the weightfile in the database changes ...
std::string m_extraInfoName
Name under which the SignalProbability is stored in the extraInfo of the Particle object.
unsigned int m_nClasses
number of classes (~outputs) of the current MVA Expert.
void setExtraInfoField(Particle *, std::string, float)
Set the extra info field.
std::string m_identifier
weight-file
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:66
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:205
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:94
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
Class to store reconstructed particles.
Definition: Particle.h:74
void setExtraInfo(const std::string &name, double value)
Sets the user-defined data of given name to the given value.
Definition: Particle.cc:1344
bool hasExtraInfo(const std::string &name) const
Return whether the extra info with the given name is set.
Definition: Particle.cc:1293
void addExtraInfo(const std::string &name, double value)
Sets the user-defined data of given name to the given value.
Definition: Particle.cc:1363
double getExtraInfo(const std::string &name) const
Return given value if set.
Definition: Particle.cc:1316
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:95
Global list of available variables.
Definition: Manager.h:101
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
REG_MODULE(B2BIIConvertBeamParams)
Register the module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:23