Belle II Software  release-08-01-10
MVAExpertModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 
10 #include <mva/modules/MVAExpert/MVAExpertModule.h>
11 
12 #include <analysis/DecayDescriptor/DecayDescriptorParticle.h>
13 
14 #include <analysis/dataobjects/Particle.h>
15 #include <analysis/dataobjects/ParticleList.h>
16 #include <analysis/dataobjects/ParticleExtraInfoMap.h>
17 #include <framework/dataobjects/EventExtraInfo.h>
18 
19 #include <mva/interface/Interface.h>
20 
21 #include <boost/algorithm/string/predicate.hpp>
22 
23 #include <framework/logging/Logger.h>
24 
25 
26 using namespace Belle2;
27 
29 
31 {
32  setDescription("Adds an ExtraInfo to the Particle objects in given ParticleLists which is calcuated by an expert defined by a weightfile.");
34 
35  std::vector<std::string> empty;
36  addParam("listNames", m_listNames,
37  "Particles from these ParticleLists are used as input. If no name is given the expert is applied to every event once, and one can only use variables which accept nullptr as Particle*. Decay descriptor functionality is supported, which allows to run the module on daughter particles, e.g. Lambda0:my_list -> ^p+ pi-. One has to provide full name of the mother particle list and only one selected daughter is supported.",
38  empty);
39  addParam("extraInfoName", m_extraInfoName,
40  "Name under which the output of the expert is stored in the ExtraInfo of the Particle object. If the expert returns multiple values, the index of the value is appended to the name in the form '_0', '_1', ...");
41  addParam("identifier", m_identifier, "The database identifier which is used to load the weights during the training.");
42  addParam("signalFraction", m_signal_fraction_override,
43  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
44  addParam("overwriteExistingExtraInfo", m_overwriteExistingExtraInfo,
45  "-1/0/1/2: Overwrite if lower / don't overwrite / overwrite if higher / always overwrite, in case the extra info with given name already exists",
46  2);
47 }
48 
50 {
51  // All specified ParticleLists are required to exist
52  for (auto& name : m_listNames) {
53  DecayDescriptor dd;
54  bool valid = dd.init(name);
55  if (!valid) {
56  B2ERROR("Decay string " << name << " is invalid.");
57  }
58  const DecayDescriptorParticle* mother = dd.getMother();
59  unsigned int nSelectedDaughters = dd.getSelectionNames().size();
60  if (nSelectedDaughters > 1) {
61  B2ERROR("More than one daughter is selected in the decay string " << name << ".");
62  }
63  StoreObjPtr<ParticleList> list(mother->getFullName());
64  list.isRequired();
65  m_targetListNames.push_back(mother->getFullName());
66  m_decaydescriptors.insert(std::make_pair(mother->getFullName(), dd));
67  }
68 
69  if (m_listNames.empty()) {
71  extraInfo.registerInDataStore();
72  } else {
74  extraInfo.registerInDataStore();
75  }
76 
77  if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
78  m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
79  MVA::makeSaveForDatabase(m_identifier));
80  }
82 
83  m_existGivenExtraInfo = false;
84 }
85 
87 {
88 
90  if (m_weightfile_representation->hasChanged()) {
91  std::stringstream ss((*m_weightfile_representation)->m_data);
92  auto weightfile = MVA::Weightfile::loadFromStream(ss);
93  init_mva(weightfile);
94  }
95  } else {
97  init_mva(weightfile);
98  }
99 
100 }
101 
103 {
104 
105  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
106  MVA::GeneralOptions general_options;
107  weightfile.getOptions(general_options);
108 
109  // Overwrite signal fraction from training
112 
113  m_expert = supported_interfaces[general_options.m_method]->getExpert();
114  m_expert->load(weightfile);
115 
117  m_feature_variables = manager.getVariables(general_options.m_variables);
118  if (m_feature_variables.size() != general_options.m_variables.size()) {
119  B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
120  }
121 
122  std::vector<float> dummy;
123  dummy.resize(m_feature_variables.size(), 0);
124  m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
125  m_nClasses = general_options.m_nClasses;
126 }
127 
129 {
130  for (unsigned int i = 0; i < m_feature_variables.size(); ++i) {
131  auto var_result = m_feature_variables[i]->function(particle);
132  if (std::holds_alternative<double>(var_result)) {
133  m_dataset->m_input[i] = std::get<double>(var_result);
134  } else if (std::holds_alternative<int>(var_result)) {
135  m_dataset->m_input[i] = std::get<int>(var_result);
136  } else if (std::holds_alternative<bool>(var_result)) {
137  m_dataset->m_input[i] = std::get<bool>(var_result);
138  }
139  }
140 }
141 
142 float MVAExpertModule::analyse(const Particle* particle)
143 {
144  if (not m_expert) {
145  B2ERROR("MVA Expert is not loaded! I will return 0");
146  return 0.0;
147  }
148  fillDataset(particle);
149  return m_expert->apply(*m_dataset)[0];
150 }
151 
152 std::vector<float> MVAExpertModule::analyseMulticlass(const Particle* particle)
153 {
154  if (not m_expert) {
155  B2ERROR("MVA Expert is not loaded! I will return 0");
156  return std::vector<float>(m_nClasses, 0.0);
157  }
158  fillDataset(particle);
159  return m_expert->applyMulticlass(*m_dataset)[0];
160 }
161 
162 void MVAExpertModule::setExtraInfoField(Particle* particle, std::string extraInfoName, float responseValue)
163 {
164  if (particle->hasExtraInfo(extraInfoName)) {
165  if (particle->getExtraInfo(extraInfoName) != responseValue) {
166  m_existGivenExtraInfo = true;
167  double current = particle->getExtraInfo(extraInfoName);
168  if (m_overwriteExistingExtraInfo == -1) {
169  if (responseValue < current) particle->setExtraInfo(extraInfoName, responseValue);
170  } else if (m_overwriteExistingExtraInfo == 0) {
171  // don't overwrite!
172  } else if (m_overwriteExistingExtraInfo == 1) {
173  if (responseValue > current) particle->setExtraInfo(extraInfoName, responseValue);
174  } else if (m_overwriteExistingExtraInfo == 2) {
175  particle->setExtraInfo(extraInfoName, responseValue);
176  } else {
177  B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo << "'.");
178  }
179  }
180  } else {
181  particle->addExtraInfo(extraInfoName, responseValue);
182  }
183 }
184 
185 void MVAExpertModule::setEventExtraInfoField(StoreObjPtr<EventExtraInfo> eventExtraInfo, std::string extraInfoName,
186  float responseValue)
187 {
188  if (eventExtraInfo->hasExtraInfo(extraInfoName)) {
189  m_existGivenExtraInfo = true;
190  double current = eventExtraInfo->getExtraInfo(extraInfoName);
191  if (m_overwriteExistingExtraInfo == -1) {
192  if (responseValue < current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
193  } else if (m_overwriteExistingExtraInfo == 0) {
194  // don't overwrite!
195  } else if (m_overwriteExistingExtraInfo == 1) {
196  if (responseValue > current) eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
197  } else if (m_overwriteExistingExtraInfo == 2) {
198  eventExtraInfo->setExtraInfo(extraInfoName, responseValue);
199  } else {
200  B2FATAL("m_overwriteExistingExtraInfo must be one of {-1,0,1,2}. Received '" << m_overwriteExistingExtraInfo << "'.");
201  }
202  } else {
203  eventExtraInfo->addExtraInfo(extraInfoName, responseValue);
204  }
205 }
206 
208 {
209  for (auto& listName : m_targetListNames) {
210  StoreObjPtr<ParticleList> list(listName);
211  // Calculate target Value for Particles
212  for (unsigned i = 0; i < list->getListSize(); ++i) {
213  auto dd = m_decaydescriptors[listName];
214  unsigned int nSelectedDaughters = dd.getSelectionNames().size();
215  const Particle* particle = (nSelectedDaughters > 0) ? dd.getSelectionParticles(list->getParticle(i))[0] : list->getParticle(i);
216  if (m_nClasses == 2) {
217  float responseValue = analyse(particle);
218  setExtraInfoField(m_particles[particle->getArrayIndex()], m_extraInfoName, responseValue);
219  } else if (m_nClasses > 2) {
220  std::vector<float> responseValues = analyseMulticlass(particle);
221  if (responseValues.size() != m_nClasses) {
222  B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues.size() <<
223  ") does not match the declared number of classes (" << m_nClasses << ").");
224  }
225  for (unsigned int iClass = 0; iClass < m_nClasses; iClass++) {
226  setExtraInfoField(m_particles[particle->getArrayIndex()], m_extraInfoName + "_" + std::to_string(iClass), responseValues[iClass]);
227  }
228  } else {
229  B2ERROR("Received a value of " << m_nClasses <<
230  " for the number of classes considered by the MVA Expert. This value should be >=2.");
231  }
232  }
233  }
234  if (m_listNames.empty()) {
235  StoreObjPtr<EventExtraInfo> eventExtraInfo;
236  if (not eventExtraInfo.isValid())
237  eventExtraInfo.create();
238 
239  if (m_nClasses == 2) {
240  float responseValue = analyse(nullptr);
241  setEventExtraInfoField(eventExtraInfo, m_extraInfoName, responseValue);
242  } else if (m_nClasses > 2) {
243  std::vector<float> responseValues = analyseMulticlass(nullptr);
244  if (responseValues.size() != m_nClasses) {
245  B2ERROR("Size of results returned by MVA Expert applyMulticlass (" << responseValues.size() <<
246  ") does not match the declared number of classes (" << m_nClasses << ").");
247  }
248  for (unsigned int iClass = 0; iClass < m_nClasses; iClass++) {
249  setEventExtraInfoField(eventExtraInfo, m_extraInfoName + "_" + std::to_string(iClass), responseValues[iClass]);
250  }
251  } else {
252  B2ERROR("Received a value of " << m_nClasses <<
253  " for the number of classes considered by the MVA Expert. This value should be >=2.");
254  }
255  }
256 }
257 
259 {
260  m_expert.reset();
261  m_dataset.reset();
262 
263  if (m_existGivenExtraInfo) {
264  if (m_overwriteExistingExtraInfo == -1) {
265  B2WARNING("The extraInfo " << m_extraInfoName <<
266  " has already been set! It was overwritten by this module if the new value was lower than the previous!");
267  } else if (m_overwriteExistingExtraInfo == 0) {
268  B2WARNING("The extraInfo " << m_extraInfoName <<
269  " has already been set! The original value was kept and this module did not overwrite it!");
270  } else if (m_overwriteExistingExtraInfo == 1) {
271  B2WARNING("The extraInfo " << m_extraInfoName <<
272  " has already been set! It was overwritten by this module if the new value was higher than the previous!");
273  } else if (m_overwriteExistingExtraInfo == 2) {
274  B2WARNING("The extraInfo " << m_extraInfoName << " has already been set! It was overwritten by this module!");
275  }
276  }
277 }
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition: DataStore.h:59
Represents a particle in the DecayDescriptor.
The DecayDescriptor stores information about a decay tree or parts of a decay tree.
bool init(const std::string &str)
Initialise the DecayDescriptor from given string.
const DecayDescriptorParticle * getMother() const
return mother.
std::vector< std::string > getSelectionNames()
Return list of human readable names of selected particles.
std::unordered_map< std::string, DecayDescriptor > m_decaydescriptors
Decay descriptor of decays to look for.
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
std::vector< const Variable::Manager::Var * > m_feature_variables
Pointers to the feature variables.
virtual void initialize() override
Initialize the module.
virtual void event() override
Called for each event.
void setEventExtraInfoField(StoreObjPtr< EventExtraInfo >, std::string, float)
Set the event extra info field.
StoreArray< Particle > m_particles
StoreArray of Particles.
virtual void terminate() override
Called at the end of the event processing.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::vector< std::string > m_targetListNames
input particle list names after decay descriptor
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
std::vector< std::string > m_listNames
input particle list names
virtual void beginRun() override
Called at the beginning of a new run.
MVAExpertModule()
Constructor.
float analyse(const Particle *)
Calculates expert output for given Particle pointer.
bool m_existGivenExtraInfo
check if the given extraInfo is already defined.
int m_overwriteExistingExtraInfo
-1/0/1/2: overwrite if lower / don't overwrite / overwrite if higher / always overwrite,...
std::vector< float > analyseMulticlass(const Particle *)
Calculates expert output for given Particle pointer.
void init_mva(MVA::Weightfile &weightfile)
Initialize mva expert, dataset and features. Called every time the weightfile in the database changes ...
std::string m_extraInfoName
Name under which the SignalProbability is stored in the extraInfo of the Particle object.
unsigned int m_nClasses
number of classes (~outputs) of the current MVA Expert.
void setExtraInfoField(Particle *, std::string, float)
Set the extra info field.
void fillDataset(const Particle *)
Evaluate the variables and fill the Dataset to be used by the expert.
std::string m_identifier
weight-file
Class to interact with the MVA package, based on class with same name in CDC package.
Definition: MVAExpert.h:33
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
Class to store reconstructed particles.
Definition: Particle.h:75
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:96
Global list of available variables.
Definition: Manager.h:101
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
REG_MODULE(arichBtest)
Register the Module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
Abstract base class for different kinds of events.