Belle II Software  release-05-02-19
MVAPrototypeModule.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 
12 #include <mva/modules/MVAExpert/MVAPrototypeModule.h>
13 #include <mva/interface/Interface.h>
14 
15 #include <boost/algorithm/string/predicate.hpp>
16 #include <memory>
17 
18 namespace Belle2 {
// Register the module with the framework; REG_MODULE takes the module name without the 'Module' suffix.
24  REG_MODULE(MVAPrototype)
25 
// Constructor body: sets the module description and declares the two module parameters
// "identifier" (bound to m_identifier) and "signalFraction" (bound to m_signal_fraction_override, default -1.0).
// NOTE(review): listing line 26, which should carry the constructor signature, is absent from this
// extract — confirm against the original file.
27  {
28  setDescription("Prototype of a module which uses the MVA package");
29 
30  // Usually it is safe to execute the MVA expert in parallel mode,
31  // but ultimately this depends on the backend you use.
32  // The default method FastBDT has no problem with parallel execution.
// NOTE(review): listing line 33 is absent from this extract; presumably it is the statement that sets
// the parallel-processing property flag the comment above refers to — confirm.
34 
35  // Your module probably has a parameter which defines the database identifier or filename of the mva weightfile.
36  // Of course you could also hard-code it in the source code if this is not configurable
// NOTE(review): the description string says "during the training", although the rest of this file only
// loads the weights to apply an expert — the wording may be a copy-paste leftover; confirm.
37  addParam("identifier", m_identifier, "The database identifier or filename which is used to load the weights during the training.");
38 
39  // For classification it is often useful to be able to change the signalFraction.
40  // If the signal fraction in the training is different from the dataset you want to apply it to,
41  // you have to pass the correct signalFraction otherwise you cannot interpret the output of the classifier as a probability.
42  // On the other hand, if you don't require the output to be a probability you don't have to care about this.
43  addParam("signalFraction", m_signal_fraction_override,
44  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
45  }
46 
// Decides how the weightfile will be obtained: for identifiers that do not end in ".root" or ".xml"
// a DBObjPtr is created in m_weightfile_representation so the weightfile is fetched from the database.
// NOTE(review): listing line 47, which should carry the function signature, is absent from this
// extract; presumably this is MVAPrototypeModule::initialize() — confirm against the original file.
48  {
49  // If the identifier does not end on .root or .xml, we are dealing with a database identifier
50  // so we need to create a DBObjPtr, which will fetch the weightfile from the database
51  if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
52  m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(m_identifier);
53  }
54 
55  // The supported methods have to be initialized once (calling it more than once is safe)
// NOTE(review): listing line 56 is absent from this extract; presumably it is the call that
// initializes the supported MVA interfaces, as announced by the comment above — confirm.
57 
58  }
59 
// Loads the weightfile and hands it to init_mva(): in the database case the weightfile is deserialized
// from the m_data field of the database representation whenever hasChanged() reports a change;
// the else branch handles the file-based case.
// NOTE(review): listing lines 60, 64 and 75 are absent from this extract. They should carry the function
// signature (presumably MVAPrototypeModule::beginRun()), the condition that opens the block closed by
// "} else {", and the statement that defines `weightfile` in the else branch — confirm against the
// original file; as shown here the braces do not balance.
61  {
62 
63  // If the DBObjPtr is valid we are dealing with a weightfile from the database
65  // We check if the weightfile changed and we have to update the expert
66  if (m_weightfile_representation->hasChanged()) {
67  // The actual weightfile is stored in the m_data field of the m_weightfile_representation
68  std::stringstream ss((*m_weightfile_representation)->m_data);
69  auto weightfile = MVA::Weightfile::loadFromStream(ss);
70  init_mva(weightfile);
71  }
72  // In case of a file-based weightfile we load it here
73  // in principle this could be done in initialize as well
74  } else {
76  init_mva(weightfile);
77  }
78 
79  }
80 
81  void MVAPrototypeModule::init_mva(MVA::Weightfile& weightfile)
82  {
83  // This function initializes the MVA::Expert using the provided weightfile
84 
85  // First we get the GeneralOptions from the weightfile
86  // and update the signal_fraction if required
87  MVA::GeneralOptions general_options;
88  weightfile.getOptions(general_options);
90  weightfile.addSignalFraction(m_signal_fraction_override);
91 
92  // Secondly we load all supported interfaces, and fetch the correct MVA::Expert
93  // and load the weightfile into this expert
94  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
95  m_expert = supported_interfaces[general_options.m_method]->getExpert();
96  m_expert->load(weightfile);
97 
98  // Finally, we create an MVA::SingleDataset, in which we will save our features later
99  // to pass them to the expert.
100  // If you want to apply the expert to more than one sample, you can also use
101  // MVA::MultiDataset or any other Dataset defined by the mva package interface.
102  std::vector<float> dummy(general_options.m_variables.size(), 0);
103  m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
104 
105  }
106 
// Per-event work: bails out with an error if no expert is loaded, otherwise fills every feature of
// m_dataset with the placeholder value 1.0, applies the expert and logs the first entry of the result.
// NOTE(review): listing line 107, which should carry the function signature, is absent from this
// extract; presumably this is MVAPrototypeModule::event() — confirm against the original file.
108  {
109  // Just to be safe we check again if the MVA::Expert is loaded
110  // It can happen that for example the database doesn't find the payload
111  // and the expert ends up uninitialized.
112  if (not m_expert) {
// NOTE(review): the message says "return 0", but the return statement below returns no value.
113  B2ERROR("MVA Expert is not loaded! I will return 0");
114  return;
115  }
116 
117  // You have to fill the dataset with your data.
118  // The order must be the same as the order of the variables in general_options.m_variables
119  for (unsigned int i = 0; i < m_dataset->getNumberOfFeatures(); ++i) {
120  m_dataset->m_input[i] = 1.0;
121  }
122 
123  // All that is left to do is applying the expert to the dataset
124  // it will return an std::vector with the results, with one entry per sample.
125  // The MVA::SingleDataset only contains one entry, so we are interested only in the first entry here.
126  // The MVA::MultiDataset on the other hand would have more than one entry in the returned vector of apply.
127  float probability = m_expert->apply(*m_dataset)[0];
128  B2INFO("The probability is " << probability);
129  }
130 
132 } // Belle2 namespace
133 
Belle2::MVAPrototypeModule::m_expert
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
Definition: MVAPrototypeModule.h:74
Belle2::MVA::AbstractInterface::getSupportedInterfaces
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:55
Belle2::Module::setDescription
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:216
REG_MODULE
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:652
Belle2::Module::c_ParallelProcessingCertified
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:82
Belle2::MVAMultipleExpertsModule::m_signal_fraction_override
double m_signal_fraction_override
Signal Fraction which should be used.
Definition: MVAMultipleExpertsModule.h:94
Belle2::MVAPrototypeModule::initialize
virtual void initialize() override
Initialize the module.
Definition: MVAPrototypeModule.cc:55
Belle2::Module::setPropertyFlags
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:210
Belle2::MVAPrototypeModule::init_mva
void init_mva(MVA::Weightfile &weightfile)
Initialize mva expert, dataset and features. Called every time the weightfile in the database changes i...
Definition: MVAPrototypeModule.cc:89
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::Weightfile::loadFromFile
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:215
Belle2::MVAPrototypeModule::m_signal_fraction_override
double m_signal_fraction_override
Signal Fraction which should be used.
Definition: MVAPrototypeModule.h:70
Belle2::MVAPrototypeModule::m_weightfile_representation
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
Definition: MVAPrototypeModule.h:73
Belle2::MVAPrototypeModule::m_identifier
std::string m_identifier
database identifier or filename of the weightfile
Definition: MVAPrototypeModule.h:69
Belle2::MVA::AbstractInterface::initSupportedInterfaces
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:55
Belle2::Module::addParam
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:562
Belle2::MVA::Weightfile::loadFromStream
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:260
Belle2::MVAPrototypeModule::beginRun
virtual void beginRun() override
Called at the beginning of a new run.
Definition: MVAPrototypeModule.cc:68
Belle2::MVAPrototypeModule::event
virtual void event() override
Called for each event.
Definition: MVAPrototypeModule.cc:115
Belle2::MVAPrototypeModule::MVAPrototypeModule
MVAPrototypeModule()
Constructor.
Definition: MVAPrototypeModule.cc:34
Belle2::MVAPrototypeModule::m_dataset
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
Definition: MVAPrototypeModule.h:75