Belle II Software  release-06-00-14
MVAPrototypeModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 
10 #include <mva/modules/MVAExpert/MVAPrototypeModule.h>
11 #include <mva/interface/Interface.h>
12 
13 #include <boost/algorithm/string/predicate.hpp>
14 #include <memory>
15 
16 namespace Belle2 {
22  REG_MODULE(MVAPrototype)
23 
// NOTE(review): the signature line (listing line 24) is missing from this extract; judging by the
// setDescription/addParam calls this is presumably the module constructor body — TODO confirm.
// The body sets the module description, sets the c_ParallelProcessingCertified property flag and
// registers the two module parameters "identifier" and "signalFraction".
25  {
26  setDescription("Prototype of a module which uses the MVA package");
27 
28  // Usually it is safe to execute the MVA expert in parallel mode,
29  // but ultimately this depends on the backend you use.
30  // The default method FastBDT has no problem with parallel execution.
31  setPropertyFlags(c_ParallelProcessingCertified);
32 
33  // Your module probably has a parameter which defines the database identifier or filename of the mva weightfile.
34  // Of course you could also hard-code it in the source code if this is not configurable
35  addParam("identifier", m_identifier, "The database identifier or filename which is used to load the weights during the training.");
36 
37  // For classification it is often useful to be able to change the signalFraction.
38  // If the signal fraction in the training is different from the dataset you want to apply it to,
39  // you have to pass the correct signalFraction otherwise you cannot interpret the output of the classifier as a probability.
40  // On the other hand, if you don't require the output to be a probability you don't have to care about this.
// The last argument (-1.0) is passed as the parameter's default; per the description string it means
// "use the signalFraction of the training data".
41  addParam("signalFraction", m_signal_fraction_override,
42  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
43  }
44 
// NOTE(review): the signature line (listing line 45) is missing from this extract; presumably this
// is the body of initialize() — TODO confirm.
// Creates the database pointer for the weightfile when m_identifier is not a .root/.xml filename.
46  {
47  // If the identifier does not end on .root or .xml, we are dealing with a database identifier
48  // so we need to create a DBObjPtr, which will fetch the weightfile from the database
49  if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
50  m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(m_identifier);
51  }
52 
53  // The supported methods have to be initialized once (calling it more than once is safe)
// NOTE(review): listing line 54, the call this comment refers to, is missing from this extract;
// presumably it is the static initSupportedInterfaces() call — verify against the real source.
55 
56  }
57 
// NOTE(review): the signature line (listing line 58) is missing from this extract; presumably this
// is the body of beginRun() — TODO confirm.
// (Re)initializes the MVA expert via init_mva(), either from the database payload when it has
// changed, or from a file-based weightfile.
59  {
60 
61  // If the DBObjPtr is valid we are dealing with a weightfile from the database
// NOTE(review): listing line 62 is missing from this extract; it presumably opens the `if` on
// m_weightfile_representation that the `} else {` below closes — verify against the real source.
63  // We check if the weightfile changed and we have to update the expert
64  if (m_weightfile_representation->hasChanged()) {
65  // The actual weightfile is stored in the m_data field of the m_weightfile_representation
66  std::stringstream ss((*m_weightfile_representation)->m_data);
67  auto weightfile = MVA::Weightfile::loadFromStream(ss);
68  init_mva(weightfile);
69  }
70  // In case of a file-based weightfile we load it here
71  // in principle this could be done in initialize as well
72  } else {
// NOTE(review): listing line 73 is missing from this extract; it presumably declares `weightfile`
// by loading it from the file named by m_identifier — verify against the real source.
74  init_mva(weightfile);
75  }
76 
77  }
78 
// NOTE(review): the signature line (listing line 79) is missing from this extract; the body uses a
// `weightfile` object and matches the init_mva(weightfile) calls elsewhere in this listing.
// Reads the general options from the weightfile, creates and loads the matching expert into
// m_expert, and allocates m_dataset with one zero-initialized feature per variable.
80  {
81  // This function initializes the MVA::Expert using the provided weightfile
82 
83  // First we get the GeneralOptions from the weightfile
84  // and update the signal_fraction if required
85  MVA::GeneralOptions general_options;
86  weightfile.getOptions(general_options);
// NOTE(review): listing lines 87-88 are missing from this extract; presumably they apply
// m_signal_fraction_override to the weightfile when it is set — verify against the real source.
89 
90  // Secondly we load all supported interfaces, and fetch the correct MVA::Expert
91  // and load the weightfile into this expert
92  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
// NOTE(review): if supported_interfaces is a std::map (TODO confirm), operator[] default-inserts a
// null pointer when m_method is not a registered key, and ->getExpert() would then dereference
// null; a find() with an explicit error may be safer.
93  m_expert = supported_interfaces[general_options.m_method]->getExpert();
94  m_expert->load(weightfile);
95 
96  // Finally, we create an MVA::SingleDataset, in which we will save our features later
97  // to pass them to the expert.
98  // If you want to apply the expert to more than one sample, you can also use
99  // MVA::MultiDataset or any other Dataset defined by the mva package interface.
100  std::vector<float> dummy(general_options.m_variables.size(), 0);
101  m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
102 
103  }
104 
// NOTE(review): the signature line (listing line 105) is missing from this extract; presumably this
// is the body of event() — TODO confirm.
// Fills the dataset with placeholder feature values, applies the expert and logs the first result.
106  {
107  // Just to be safe we check again if the MVA::Expert is loaded
108  // It can happen that for example the database doesn't find the payload
109  // and the expert ends up uninitialized.
110  if (not m_expert) {
// NOTE(review): the message says "I will return 0", but the bare `return;` below returns no value.
111  B2ERROR("MVA Expert is not loaded! I will return 0");
112  return;
113  }
114 
115  // You have to fill the dataset with your data.
116  // The order must be the same as the order of the variables in general_options.m_variables
// The prototype writes the constant 1.0 into every feature slot as a stand-in for real data.
117  for (unsigned int i = 0; i < m_dataset->getNumberOfFeatures(); ++i) {
118  m_dataset->m_input[i] = 1.0;
119  }
120 
121  // All that is left to do is applying the expert to the dataset
122  // it will return an std::vector with the results, with one entry per sample.
123  // The MVA::SingleDataset only contains one entry, so we are interested only in the first entry here.
124  // The MVA::MultiDataset on the other hand would have more than one entry in the returned vector of apply.
125  float probability = m_expert->apply(*m_dataset)[0];
126  B2INFO("The probability is " << probability);
127  }
128 
130 } // Belle2 namespace
131 
This module can be used as a prototype for your own module which uses MVA weightfiles.
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
std::string m_identifier
database identifier or filename of the weightfile
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:66
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:205
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:94
Base class for Modules.
Definition: Module.h:72
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
virtual void initialize() override
Initialize the module.
virtual void event() override
Called for each event.
virtual void beginRun() override
Called at the beginning of a new run.
void init_mva(MVA::Weightfile &weightfile)
Initialize mva expert, dataset and features Called every time the weightfile in the database changes ...
Abstract base class for different kinds of events.