Belle II Software  release-05-02-19
MVAMultipleExpertsModule.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck and Fernando Abudinen *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 
12 #include <mva/modules/MVAExpert/MVAMultipleExpertsModule.h>
13 
14 #include <analysis/dataobjects/Particle.h>
15 #include <analysis/dataobjects/ParticleList.h>
16 #include <analysis/dataobjects/ParticleExtraInfoMap.h>
17 #include <analysis/dataobjects/EventExtraInfo.h>
18 
19 #include <mva/interface/Interface.h>
20 
21 #include <boost/algorithm/string/predicate.hpp>
22 #include <memory>
23 
24 #include <framework/logging/Logger.h>
25 
26 
27 namespace Belle2 {
// Register this module with the framework; REG_MODULE takes the module name without the "Module" suffix.
33  REG_MODULE(MVAMultipleExperts)
34 
36  {
// Constructor: sets the module description and declares the steering parameters.
// NOTE(review): the signature line and original line 38 are not part of this listing.
// The tooltips reference setPropertyFlags / c_ParallelProcessingCertified, so line 38
// presumably marks the module as parallel-processing certified; confirm in the source.
37  setDescription("Adds ExtraInfos to the Particle objects in given ParticleLists which is calcuated by multiple experts defined by the given weightfiles.");
39 
// listNames defaults to an empty list. In that case the experts are evaluated once per
// event with a nullptr Particle* and the results go to the event-level extra info.
40  std::vector<std::string> empty;
41  addParam("listNames", m_listNames,
42  "Particles from these ParticleLists are used as input. If no name is given the experts are applied to every event once, and one can only use variables which accept nullptr as Particle*",
43  empty);
// extraInfoNames and identifiers are declared without a default value here.
// initialize() requires exactly one extraInfoName per identifier (matched by index).
44  addParam("extraInfoNames", m_extraInfoNames,
45  "Names under which the output of the experts is stored in the ExtraInfo of the Particle object.");
46  addParam("identifiers", m_identifiers, "The database identifiers which is used to load the weights during the training.");
// signalFraction defaults to -1.0, which keeps the signal fraction stored in the weightfile.
47  addParam("signalFraction", m_signal_fraction_override,
48  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
49  }
50 
52  {
// Module initialization: declares required inputs, registers the output object,
// validates the parameters and prepares one expert slot per identifier.
// NOTE(review): the signature line is not part of this listing.
53  // All specified ParticleLists are required to exist
54  for (auto& name : m_listNames) {
55  StoreObjPtr<ParticleList> list(name);
56  list.isRequired();
57  }
58 
// Without input lists the results are stored in the event-level extra info, otherwise
// in the particles' extra info. NOTE(review): the declarations of `extraInfo`
// (original lines 60 and 63) are not shown. Given the includes, they are presumably
// StoreObjPtr<EventExtraInfo> and StoreObjPtr<ParticleExtraInfoMap>; confirm.
59  if (m_listNames.empty()) {
61  extraInfo.registerInDataStore();
62  } else {
64  extraInfo.registerInDataStore();
65  }
66 
// The output of the i-th identifier is stored under the i-th extraInfo name.
67  if (m_extraInfoNames.size() != m_identifiers.size()) {
68  B2FATAL("The number of given m_extraInfoNames is not equal to the number of m_identifiers. The output the ith method in m_identifiers is saved as extraInfo under the ith name in m_extraInfoNames! Set also different names for each method!");
69  }
70 
// NOTE(review): original lines 71 and 73 are not shown. The loop below indexes
// m_weightfile_representations, and init_mva() indexes m_individual_feature_variables,
// so these lines presumably resize those two vectors as well; confirm.
72  m_experts.resize(m_identifiers.size());
74  m_datasets.resize(m_identifiers.size());
75 
// Identifiers ending in .root or .xml name local weightfiles (loaded in beginRun()).
// Any other identifier is looked up in the conditions database through a DBObjPtr.
76  for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
77  if (not(boost::ends_with(m_identifiers[i], ".root") or boost::ends_with(m_identifiers[i], ".xml"))) {
78  m_weightfile_representations[i] = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
79  MVA::makeSaveForDatabase(m_identifiers[i]));
80  }
81  }
82 
// NOTE(review): original line 83 is not shown. The tooltips reference
// MVA::AbstractInterface::initSupportedInterfaces(), which has to be called once
// before getSupportedInterfaces() is used in init_mva(); presumably that call lives here.
84 
85  }
86 
88  {
// Called at the beginning of each run: (re)initializes the experts from their weightfiles.
// NOTE(review): the signature line and original line 92 are not part of this listing.
// The braces below only balance with an extra `if (...) {` at line 92. That is presumably
// a check that m_weightfile_representations[i] is set (initialize() leaves it unset for
// .root/.xml identifiers), with the `else` branch loading the local file; confirm.
89 
90  if (!m_weightfile_representations.empty()) {
91  for (unsigned int i = 0; i < m_weightfile_representations.size(); ++i) {
// Database payload: rebuild the expert only if the payload changed since the last check.
93  if (m_weightfile_representations[i]->hasChanged()) {
94  std::stringstream ss((*m_weightfile_representations[i])->m_data);
95  auto weightfile = MVA::Weightfile::loadFromStream(ss);
96  init_mva(weightfile, i);
97  }
98  } else {
// Local weightfile: read it from disk using the identifier as the file name.
99  auto weightfile = MVA::Weightfile::loadFromFile(m_identifiers[i]);
100  init_mva(weightfile, i);
101  }
102  }
103 
// Reached only if m_weightfile_representations is empty, i.e. (presumably, since it is
// sized to m_identifiers in initialize()) when no identifiers were configured.
104  } else B2FATAL("No m_identifiers given. At least one is needed!");
105  }
106 
107  void MVAMultipleExpertsModule::init_mva(MVA::Weightfile& weightfile, unsigned int i)
108  {
109 
110  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
111  Variable::Manager& manager = Variable::Manager::Instance();
112 
113 
114  MVA::GeneralOptions general_options;
115  weightfile.getOptions(general_options);
116 
117  // Overwrite signal fraction from training
119  weightfile.addSignalFraction(m_signal_fraction_override);
120 
121  m_experts[i] = supported_interfaces[general_options.m_method]->getExpert();
122  m_experts[i]->load(weightfile);
123 
124 
125  m_individual_feature_variables[i] = manager.getVariables(general_options.m_variables);
126  if (m_individual_feature_variables[i].size() != general_options.m_variables.size()) {
127  B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
128  }
129 
130  for (auto& iVariable : m_individual_feature_variables[i]) {
131  if (m_feature_variables.find(iVariable) == m_feature_variables.end()) {
132  m_feature_variables.insert(std::pair<const Variable::Manager::Var*, float>(iVariable, 0));
133  }
134  }
135 
136  std::vector<float> dummy;
137  dummy.resize(m_individual_feature_variables[i].size(), 0);
138  m_datasets[i] = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
139 
140  }
141 
142  std::vector<float> MVAMultipleExpertsModule::analyse(Particle* particle)
143  {
144 
145  std::vector<float> targetValues;
146  targetValues.resize(m_identifiers.size());
147  for (auto const& iVariable : m_feature_variables) {
148  m_feature_variables[iVariable.first] = iVariable.first ->function(particle);
149  }
150 
151  for (unsigned int i = 0; i < m_identifiers.size(); ++i) {
152  for (unsigned int j = 0; j < m_individual_feature_variables[i].size(); ++j) {
154  }
155  targetValues[i] = m_experts[i]->apply(*m_datasets[i])[0];
156  }
157 
158  return targetValues;
159  }
160 
161 
163  {
// Per-event processing: stores each expert's output as extra info on every particle
// of the input lists, or once per event in EventExtraInfo when no lists were given.
// NOTE(review): the signature line is not part of this listing.
164  for (auto& listName : m_listNames) {
165  StoreObjPtr<ParticleList> list(listName);
166  // Calculate target Value for Particles
167  for (unsigned i = 0; i < list->getListSize(); ++i) {
168  Particle* particle = list->getParticle(i);
169  std::vector<float> targetValues = analyse(particle);
170 
// The j-th expert output is stored under the j-th extraInfo name. An existing particle
// value is overwritten (with a warning issued for each affected particle).
171  for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
172  if (particle->hasExtraInfo(m_extraInfoNames[j])) {
173  B2WARNING("Extra Info with given name is already set! Overwriting old value!");
174  particle->setExtraInfo(m_extraInfoNames[j], targetValues[j]);
175  } else {
176  particle->addExtraInfo(m_extraInfoNames[j], targetValues[j]);
177  }
178  }
179  }
180  }
// Event mode: evaluate the experts once with a nullptr Particle*, creating the
// EventExtraInfo object if it does not exist yet.
181  if (m_listNames.empty()) {
182  StoreObjPtr<EventExtraInfo> eventExtraInfo;
183  if (not eventExtraInfo.isValid())
184  eventExtraInfo.create();
185  std::vector<float> targetValues = analyse(nullptr);
// NOTE(review): unlike the particle branch above, an existing event-level value is
// kept rather than overwritten; confirm this asymmetry is intended.
186  for (unsigned int j = 0; j < m_identifiers.size(); ++j) {
187  if (eventExtraInfo->hasExtraInfo(m_extraInfoNames[j])) {
188  B2WARNING("Extra Info with given name is already set! I won't set it again!");
189  } else {
190  eventExtraInfo->addExtraInfo(m_extraInfoNames[j], targetValues[j]);
191  }
192  }
193  }
194  }
195 
197 } // Belle2 namespace
Belle2::MVAMultipleExpertsModule::init_mva
void init_mva(MVA::Weightfile &weightfile, unsigned int i)
Initialize mva expert, dataset and features. Called every time the weightfile in the database changes i...
Definition: MVAMultipleExpertsModule.cc:115
Belle2::MVAMultipleExpertsModule::m_datasets
std::vector< std::unique_ptr< MVA::SingleDataset > > m_datasets
Vector of pointers to the current input datasets.
Definition: MVAMultipleExpertsModule.h:109
Belle2::MVA::AbstractInterface::getSupportedInterfaces
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:55
Belle2::Module::setDescription
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:216
REG_MODULE
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:652
Belle2::Module::c_ParallelProcessingCertified
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:82
Belle2::MVAMultipleExpertsModule::m_experts
std::vector< std::unique_ptr< MVA::Expert > > m_experts
Vector of pointers to the current MVA Experts.
Definition: MVAMultipleExpertsModule.h:107
Belle2::MVAMultipleExpertsModule::m_signal_fraction_override
double m_signal_fraction_override
Signal Fraction which should be used.
Definition: MVAMultipleExpertsModule.h:94
Belle2::MVAExpertModule::m_listNames
std::vector< std::string > m_listNames
input particle list names
Definition: MVAExpertModule.h:88
Belle2::MVAMultipleExpertsModule::initialize
virtual void initialize() override
Initialize the module.
Definition: MVAMultipleExpertsModule.cc:59
Belle2::MVAMultipleExpertsModule::MVAMultipleExpertsModule
MVAMultipleExpertsModule()
Constructor.
Definition: MVAMultipleExpertsModule.cc:43
Belle2::Module::setPropertyFlags
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:210
Belle2::MVAMultipleExpertsModule::m_extraInfoNames
std::vector< std::string > m_extraInfoNames
Names under which the SignalProbability is stored in the extraInfo of the Particle object.
Definition: MVAMultipleExpertsModule.h:93
Belle2::MVAMultipleExpertsModule::analyse
std::vector< float > analyse(Particle *)
Calculates expert output for given Particle pointer.
Definition: MVAMultipleExpertsModule.cc:150
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::StoreObjPtr
Type-safe access to single objects in the data store.
Definition: ParticleList.h:33
Belle2::MVA::Weightfile::loadFromFile
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:215
Belle2::MVAMultipleExpertsModule::m_identifiers
std::vector< std::string > m_identifiers
weight-files
Definition: MVAMultipleExpertsModule.h:91
Belle2::MVA::AbstractInterface::initSupportedInterfaces
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedI...
Definition: Interface.cc:55
Belle2::Module::addParam
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:562
Belle2::MVAExpertModule::m_signal_fraction_override
double m_signal_fraction_override
Signal Fraction which should be used.
Definition: MVAExpertModule.h:91
Belle2::MVAMultipleExpertsModule::m_weightfile_representations
std::vector< std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > > m_weightfile_representations
Vector of database pointers to the Database representation of the weightfile.
Definition: MVAMultipleExpertsModule.h:105
Belle2::MVA::Weightfile::loadFromStream
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:260
Belle2::MVAMultipleExpertsModule::m_listNames
std::vector< std::string > m_listNames
input particle list names
Definition: MVAMultipleExpertsModule.h:90
Belle2::MVAMultipleExpertsModule::beginRun
virtual void beginRun() override
Called at the beginning of a new run.
Definition: MVAMultipleExpertsModule.cc:95
Belle2::MVAMultipleExpertsModule::event
virtual void event() override
Called for each event.
Definition: MVAMultipleExpertsModule.cc:170
Belle2::MVAMultipleExpertsModule::m_feature_variables
std::map< const Variable::Manager::Var *, float > m_feature_variables
Map containing the values of all needed feature variables.
Definition: MVAMultipleExpertsModule.h:102
Belle2::DataStore::c_Event
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition: DataStore.h:61
Belle2::MVAMultipleExpertsModule::m_individual_feature_variables
std::vector< std::vector< const Variable::Manager::Var * > > m_individual_feature_variables
Vector of pointers to the feature variables for each expert.
Definition: MVAMultipleExpertsModule.h:97
Belle2::Variable::Manager::Instance
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:27