Belle II Software  release-06-00-14
MVAMultipleExpertsModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 
10 #include <mva/modules/MVAExpert/MVAMultipleExpertsModule.h>
11 
12 #include <analysis/dataobjects/Particle.h>
13 #include <analysis/dataobjects/ParticleList.h>
14 #include <analysis/dataobjects/ParticleExtraInfoMap.h>
15 #include <analysis/dataobjects/EventExtraInfo.h>
16 
17 #include <mva/interface/Interface.h>
18 
19 #include <boost/algorithm/string/predicate.hpp>
20 #include <memory>
21 
22 #include <framework/logging/Logger.h>
23 
24 
25 namespace Belle2 {
// Register this module with the framework. REG_MODULE takes the module name
// without the 'Module' suffix (the class itself is MVAMultipleExpertsModule).
31  REG_MODULE(MVAMultipleExperts)
32 
// Body that sets the module description, its property flags and its four steering parameters.
// NOTE(review): the signature line (listing line 33) is missing from this extract. Since the body
// only configures description/parameters it is presumably the MVAMultipleExpertsModule
// constructor — confirm against the original file.
34  {
35  setDescription("Adds ExtraInfos to the Particle objects in given ParticleLists which is calcuated by multiple experts defined by the given weightfiles.");
36  setPropertyFlags(c_ParallelProcessingCertified);
37 
    // "listNames" is the only list-valued parameter that gets a trailing argument (an empty
    // vector) — presumably its default value, making the parameter optional; TODO confirm addParam semantics.
38  std::vector<std::string> empty;
39  addParam("listNames", m_listNames,
40  "Particles from these ParticleLists are used as input. If no name is given the experts are applied to every event once, and one can only use variables which accept nullptr as Particle*",
41  empty);
    // "extraInfoNames" and "identifiers" are passed without a trailing argument.
    // initialize() later requires both vectors to have the same length (i-th name <-> i-th identifier).
42  addParam("extraInfoNames", m_extraInfoNames,
43  "Names under which the output of the experts is stored in the ExtraInfo of the Particle object.");
44  addParam("identifiers", m_identifiers, "The database identifiers which is used to load the weights during the training.");
    // -1.0 is passed as the trailing argument; per the help text, -1 means "use the signal
    // fraction of the training data".
45  addParam("signalFraction", m_signal_fraction_override,
46  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
47  }
48 
// Body that validates the configuration and sizes the per-expert containers.
// NOTE(review): the signature line (listing line 49) is missing from this extract — presumably
// MVAMultipleExpertsModule::initialize(); confirm against the original file.
// NOTE(review): listing lines 58, 61, 69, 71 and 81 are also missing, so this block is
// incomplete as shown; the individual gaps are flagged below.
50  {
51  // All specified ParticleLists are required to exist
52  for (auto& name : m_listNames) {
53  StoreObjPtr<ParticleList> list(name);
54  list.isRequired();
55  }
56 
    // Register the store object that will receive the expert outputs.
    // NOTE(review): `extraInfo` is not declared in the visible lines — its declarations
    // (listing lines 58 and 61) were dropped. The two branches presumably declare different
    // store-object types for event-level and per-particle mode; confirm against the original.
57  if (m_listNames.empty()) {
59  extraInfo.registerInDataStore();
60  } else {
62  extraInfo.registerInDataStore();
63  }
64 
    // The i-th identifier's output is stored under the i-th extra-info name, so both
    // vectors must have the same length.
65  if (m_extraInfoNames.size() != m_identifiers.size()) {
66  B2FATAL("The number of given m_extraInfoNames is not equal to the number of m_identifiers. The output the ith method in m_identifiers is saved as extraInfo under the ith name in m_extraInfoNames! Set also different names for each method!");
67  }
68 
    // One slot per identifier in each per-expert container.
    // NOTE(review): listing lines 69 and 71 are missing. m_weightfile_representations and
    // m_individual_feature_variables are indexed by i elsewhere, but no resize of them is
    // visible here — verify they are sized to m_identifiers.size() in the original.
70  m_experts.resize(m_identifiers.size());
72  m_datasets.resize(m_identifiers.size());
73 
    // Identifiers ending in ".root" or ".xml" are loaded from file in the run-start body below,
    // so they get no database pointer; every other identifier gets a DBObjPtr.
74  for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
75  if (not(boost::ends_with(m_identifiers[i], ".root") or boost::ends_with(m_identifiers[i], ".xml"))) {
76  m_weightfile_representations[i] = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
77  MVA::makeSaveForDatabase(m_identifiers[i]));
78  }
79  }
80 
    // NOTE(review): listing line 81 is missing here; content unknown from this extract.
82 
83  }
84 
// Body that (re)loads the weightfile of every expert and passes it to init_mva().
// NOTE(review): the signature line (listing line 85) is missing from this extract — presumably
// MVAMultipleExpertsModule::beginRun(); confirm against the original file.
// NOTE(review): std::stringstream is used below, but <sstream> is not among the includes
// visible in this listing — presumably it arrives transitively; confirm.
86  {
87 
88  if (!m_weightfile_representations.empty()) {
89  for (unsigned int i = 0; i < m_weightfile_representations.size(); ++i) {
    // NOTE(review): listing line 90 is missing, so the braces below do not balance as shown.
    // The later `} else {` branch loads from file, so the dropped line is presumably a check
    // that m_weightfile_representations[i] is non-null (database-backed) — confirm.
    // Database-backed expert: re-initialise only when the payload has changed.
91  if (m_weightfile_representations[i]->hasChanged()) {
92  std::stringstream ss((*m_weightfile_representations[i])->m_data);
93  auto weightfile = MVA::Weightfile::loadFromStream(ss);
94  init_mva(weightfile, i);
95  }
96  } else {
    // File-backed expert: loaded from disk and initialised on every call of this body.
97  auto weightfile = MVA::Weightfile::loadFromFile(m_identifiers[i]);
98  init_mva(weightfile, i);
99  }
100  }
101 
    // The condition tests m_weightfile_representations while the message names m_identifiers;
    // these match only if the former is sized to the latter (see the dropped listing line 69).
102  } else B2FATAL("No m_identifiers given. At least one is needed!");
103  }
104 
// Initialise expert number i from the given weightfile: create and load the expert, resolve
// its feature variables, and allocate its single-row input dataset.
//
// weightfile: the deserialised training; its general options are read here.
// i: index of the expert; used to index m_experts, m_individual_feature_variables and
//    m_datasets (no bounds check is done here).
//
// NOTE(review): listing lines 109, 116 and 117 are missing from this extract; the gaps are
// flagged where they occur.
105  void MVAMultipleExpertsModule::init_mva(MVA::Weightfile& weightfile, unsigned int i)
106  {
107 
108  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
    // NOTE(review): listing line 109 is missing. `manager` is used further down but not
    // declared in the visible lines — presumably a reference obtained from
    // Variable::Manager::Instance(); confirm against the original.
110 
111 
112  MVA::GeneralOptions general_options;
113  weightfile.getOptions(general_options);
114 
115  // Overwrite signal fraction from training
    // NOTE(review): the statement belonging to the comment above (listing lines 116-117) is
    // missing — presumably it applies m_signal_fraction_override to the weightfile when an
    // override is configured; confirm.
118 
    // NOTE(review): getSupportedInterfaces() returns a std::map<std::string, AbstractInterface*>,
    // so operator[] with an unregistered method name would default-insert a null pointer that
    // is dereferenced right away. A find() plus explicit B2FATAL may be safer — assumes
    // m_method is the string key; verify.
119  m_experts[i] = supported_interfaces[general_options.m_method]->getExpert();
120  m_experts[i]->load(weightfile);
121 
122 
    // Resolve the variable names into Var pointers; a size mismatch is treated as
    // "at least one name could not be resolved" and is fatal.
123  m_individual_feature_variables[i] = manager.getVariables(general_options.m_variables);
124  if (m_individual_feature_variables[i].size() != general_options.m_variables.size()) {
125  B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
126  }
127 
    // m_feature_variables is shared by all experts: add each variable once, with value 0,
    // if it is not already a key.
128  for (auto& iVariable : m_individual_feature_variables[i]) {
129  if (m_feature_variables.find(iVariable) == m_feature_variables.end()) {
130  m_feature_variables.insert(std::pair<const Variable::Manager::Var*, float>(iVariable, 0));
131  }
132  }
133 
    // Zero-filled vector with one entry per feature of this expert, passed to the dataset constructor.
134  std::vector<float> dummy;
135  dummy.resize(m_individual_feature_variables[i].size(), 0);
136  m_datasets[i] = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
137 
138  }
139 
// Evaluate every configured expert for one candidate.
//
// particle: forwarded unchanged to each variable's function(); event() passes nullptr in
//           event-level mode.
// Returns one value per entry of m_identifiers, in the same order; only element [0] of each
// expert's apply() result is kept.
140  std::vector<float> MVAMultipleExpertsModule::analyse(Particle* particle)
141  {
142 
143  std::vector<float> targetValues;
144  targetValues.resize(m_identifiers.size());
    // Evaluate each distinct feature variable once and cache the value in the shared map.
    // Writing via operator[] during iteration only touches existing keys, so no element is inserted.
145  for (auto const& iVariable : m_feature_variables) {
146  m_feature_variables[iVariable.first] = iVariable.first ->function(particle);
147  }
148 
149  for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
150  for (unsigned int j = 0; j < m_individual_feature_variables[i].size(); ++j) {
    // NOTE(review): the loop body (listing line 151) is missing from this extract —
    // presumably it copies the cached value of feature j into the input of m_datasets[i];
    // confirm against the original. As shown, the loop is empty.
152  }
153  targetValues[i] = m_experts[i]->apply(*m_datasets[i])[0];
154  }
155 
156  return targetValues;
157  }
158 
159 
// Body that applies all experts and stores their outputs as extra info, either per particle
// (when list names are configured) or once per event (when m_listNames is empty).
// NOTE(review): the signature line (listing line 160) is missing from this extract — presumably
// MVAMultipleExpertsModule::event(); confirm against the original file.
161  {
    // Per-particle mode: iterate over every particle of every configured list.
162  for (auto& listName : m_listNames) {
163  StoreObjPtr<ParticleList> list(listName);
164  // Calculate target Value for Particles
165  for (unsigned i = 0; i < list->getListSize(); ++i) {
166  Particle* particle = list->getParticle(i);
167  std::vector<float> targetValues = analyse(particle);
168 
    // Output j is stored under m_extraInfoNames[j]. An existing entry is overwritten (with a
    // warning) only when its value differs; an identical value is left alone silently.
169  for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
170  if (particle->hasExtraInfo(m_extraInfoNames[j])) {
171  if (particle->getExtraInfo(m_extraInfoNames[j]) != targetValues[j]) {
172  B2WARNING("Extra Info with given name is already set! Overwriting old value!");
173  particle->setExtraInfo(m_extraInfoNames[j], targetValues[j]);
174  }
175  } else {
176  particle->addExtraInfo(m_extraInfoNames[j], targetValues[j]);
177  }
178  }
179  }
180  }
    // Event-level mode: no lists configured, so the experts are evaluated once with a null
    // particle and the outputs go into the EventExtraInfo object (created here if not yet valid).
181  if (m_listNames.empty()) {
182  StoreObjPtr<EventExtraInfo> eventExtraInfo;
183  if (not eventExtraInfo.isValid())
184  eventExtraInfo.create();
185  std::vector<float> targetValues = analyse(nullptr);
    // Unlike the per-particle branch, an existing entry is never overwritten here; only a
    // warning is issued.
186  for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
187  if (eventExtraInfo->hasExtraInfo(m_extraInfoNames[j])) {
188  B2WARNING("Extra Info with given name is already set! I won't set it again!");
189  } else {
190  eventExtraInfo->addExtraInfo(m_extraInfoNames[j], targetValues[j]);
191  }
192  }
193  }
194  }
195 
197 } // Belle2 namespace
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition: DataStore.h:59
This module adds an ExtraInfo to the Particle objects in a given ParticleList.
std::vector< std::unique_ptr< MVA::Expert > > m_experts
Vector of pointers to the current MVA Experts.
std::vector< std::unique_ptr< MVA::SingleDataset > > m_datasets
Vector of pointers to the current input datasets.
std::vector< std::vector< const Variable::Manager::Var * > > m_individual_feature_variables
Vector of pointers to the feature variables for each expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::vector< std::string > m_identifiers
weight-files
std::vector< std::string > m_listNames
input particle list names
std::map< const Variable::Manager::Var *, float > m_feature_variables
Map containing the values of all needed feature variables.
std::vector< std::string > m_extraInfoNames
Names under which the SignalProbability is stored in the extraInfo of the Particle object.
std::vector< std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > > m_weightfile_representations
Vector of database pointers to the Database representation of the weightfile.
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:66
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:205
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:94
Base class for Modules.
Definition: Module.h:72
Class to store reconstructed particles.
Definition: Particle.h:74
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:95
Global list of available variables.
Definition: Manager.h:98
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
void init_mva(MVA::Weightfile &weightfile, unsigned int i)
Initialize mva expert, dataset and features. Called every time the weightfile in the database changes ...
virtual void initialize() override
Initialize the module.
std::vector< float > analyse(Particle *)
Calculates expert output for given Particle pointer.
virtual void event() override
Called for each event.
virtual void beginRun() override
Called at the beginning of a new run.
Abstract base class for different kinds of events.