Belle II Software  release-05-02-19
MVAExpertModule.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 
12 #include <mva/modules/MVAExpert/MVAExpertModule.h>
13 
14 #include <analysis/dataobjects/Particle.h>
15 #include <analysis/dataobjects/ParticleList.h>
16 #include <analysis/dataobjects/ParticleExtraInfoMap.h>
17 #include <analysis/dataobjects/EventExtraInfo.h>
18 
19 #include <mva/interface/Interface.h>
20 
21 #include <boost/algorithm/string/predicate.hpp>
22 #include <memory>
23 
24 #include <framework/logging/Logger.h>
25 
26 
27 namespace Belle2 {
33  REG_MODULE(MVAExpert)
34 
36  {
37  setDescription("Adds an ExtraInfo to the Particle objects in given ParticleLists which is calcuated by an expert defined by a weightfile.");
39 
40  std::vector<std::string> empty;
41  addParam("listNames", m_listNames,
42  "Particles from these ParticleLists are used as input. If no name is given the expert is applied to every event once, and one can only use variables which accept nullptr as Particle*",
43  empty);
44  addParam("extraInfoName", m_extraInfoName,
45  "Name under which the output of the expert is stored in the ExtraInfo of the Particle object.");
46  addParam("identifier", m_identifier, "The database identifier which is used to load the weights during the training.");
47  addParam("signalFraction", m_signal_fraction_override,
48  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
49  }
50 
52  {
53  // All specified ParticleLists are required to exist
54  for (auto& name : m_listNames) {
55  StoreObjPtr<ParticleList> list(name);
56  list.isRequired();
57  }
58 
59  if (m_listNames.empty()) {
61  extraInfo.registerInDataStore();
62  } else {
64  extraInfo.registerInDataStore();
65  }
66 
67  if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
68  m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
69  MVA::makeSaveForDatabase(m_identifier));
70  }
72 
73  }
74 
76  {
77 
79  if (m_weightfile_representation->hasChanged()) {
80  std::stringstream ss((*m_weightfile_representation)->m_data);
81  auto weightfile = MVA::Weightfile::loadFromStream(ss);
82  init_mva(weightfile);
83  }
84  } else {
86  init_mva(weightfile);
87  }
88 
89  }
90 
91  void MVAExpertModule::init_mva(MVA::Weightfile& weightfile)
92  {
93 
94  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
95  MVA::GeneralOptions general_options;
96  weightfile.getOptions(general_options);
97 
98  // Overwrite signal fraction from training
100  weightfile.addSignalFraction(m_signal_fraction_override);
101 
102  m_expert = supported_interfaces[general_options.m_method]->getExpert();
103  m_expert->load(weightfile);
104 
106  m_feature_variables = manager.getVariables(general_options.m_variables);
107  if (m_feature_variables.size() != general_options.m_variables.size()) {
108  B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
109  }
110 
111  std::vector<float> dummy;
112  dummy.resize(m_feature_variables.size(), 0);
113  m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
114 
115  }
116 
117  float MVAExpertModule::analyse(Particle* particle)
118  {
119  if (not m_expert) {
120  B2ERROR("MVA Expert is not loaded! I will return 0");
121  return 0.0;
122  }
123  for (unsigned int i = 0; i < m_feature_variables.size(); ++i) {
124  m_dataset->m_input[i] = m_feature_variables[i]->function(particle);
125  }
126  return m_expert->apply(*m_dataset)[0];
127  }
128 
129 
131  {
132  for (auto& listName : m_listNames) {
133  StoreObjPtr<ParticleList> list(listName);
134  // Calculate target Value for Particles
135  for (unsigned i = 0; i < list->getListSize(); ++i) {
136  Particle* particle = list->getParticle(i);
137  float targetValue = analyse(particle);
138  if (particle->hasExtraInfo(m_extraInfoName)) {
139  B2WARNING("Extra Info with given name is already set! Overwriting old value!");
140  particle->setExtraInfo(m_extraInfoName, targetValue);
141  } else {
142  particle->addExtraInfo(m_extraInfoName, targetValue);
143  }
144  }
145  }
146  if (m_listNames.empty()) {
147  StoreObjPtr<EventExtraInfo> eventExtraInfo;
148  if (not eventExtraInfo.isValid())
149  eventExtraInfo.create();
150  if (eventExtraInfo->hasExtraInfo(m_extraInfoName)) {
151  B2WARNING("Extra Info with given name is already set! I won't set it again!");
152  } else {
153  float targetValue = analyse(nullptr);
154  eventExtraInfo->addExtraInfo(m_extraInfoName, targetValue);
155  }
156  }
157  }
158 
160 } // Belle2 namespace
Belle2::MVAExpertModule::m_identifier
std::string m_identifier
weight-file
Definition: MVAExpertModule.h:89
Belle2::MVAExpertModule::m_extraInfoName
std::string m_extraInfoName
Name under which the SignalProbability is stored in the extraInfo of the Particle object.
Definition: MVAExpertModule.h:90
Belle2::MVA::AbstractInterface::getSupportedInterfaces
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:55
Belle2::Module::setDescription
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:216
REG_MODULE
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:652
Belle2::Module::c_ParallelProcessingCertified
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:82
Belle2::MVAExpertModule::m_expert
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
Definition: MVAExpertModule.h:97
Belle2::MVAExpertModule::MVAExpertModule
MVAExpertModule()
Constructor.
Definition: MVAExpertModule.cc:43
Belle2::MVAExpertModule::m_listNames
std::vector< std::string > m_listNames
input particle list names
Definition: MVAExpertModule.h:88
Belle2::MVAExpertModule::initialize
virtual void initialize() override
Initialize the module.
Definition: MVAExpertModule.cc:59
Belle2::Module::setPropertyFlags
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:210
Belle2::MVAExpertModule::init_mva
void init_mva(MVA::Weightfile &weightfile)
Initialize mva expert, dataset and features. Called every time the weightfile in the database changes i...
Definition: MVAExpertModule.cc:99
Belle2::MVAExpertModule::m_weightfile_representation
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
Definition: MVAExpertModule.h:96
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::StoreObjPtr
Type-safe access to single objects in the data store.
Definition: ParticleList.h:33
Belle2::MVA::Weightfile::loadFromFile
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:215
Belle2::MVAPrototypeModule::m_signal_fraction_override
double m_signal_fraction_override
Signal Fraction which should be used.
Definition: MVAPrototypeModule.h:70
Belle2::MVAPrototypeModule::m_identifier
std::string m_identifier
database identifier or filename of the weightfile
Definition: MVAPrototypeModule.h:69
Belle2::MVA::AbstractInterface::initSupportedInterfaces
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedI...
Definition: Interface.cc:55
Belle2::Module::addParam
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:562
Belle2::MVAExpertModule::m_signal_fraction_override
double m_signal_fraction_override
Signal Fraction which should be used.
Definition: MVAExpertModule.h:91
Belle2::MVA::Weightfile::loadFromStream
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:260
Belle2::MVAExpertModule::analyse
float analyse(Particle *)
Calculates expert output for given Particle pointer.
Definition: MVAExpertModule.cc:125
Belle2::MVAExpertModule::m_dataset
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
Definition: MVAExpertModule.h:98
Belle2::MVAExpertModule::beginRun
virtual void beginRun() override
Called at the beginning of a new run.
Definition: MVAExpertModule.cc:83
Belle2::MVAExpertModule::event
virtual void event() override
Called for each event.
Definition: MVAExpertModule.cc:138
Belle2::MVAExpertModule::m_feature_variables
std::vector< const Variable::Manager::Var * > m_feature_variables
Pointers to the feature variables.
Definition: MVAExpertModule.h:93
Belle2::DataStore::c_Event
@ c_Event
Different object in each event, all objects/arrays are invalidated after event() function has been ca...
Definition: DataStore.h:61
Belle2::Variable::Manager
Global list of available variables.
Definition: Manager.h:108
Belle2::Variable::Manager::Instance
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:27