Belle II Software  release-08-01-10
MVAPrototypeModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 
10 #include <mva/modules/MVAExpert/MVAPrototypeModule.h>
11 #include <mva/interface/Interface.h>
12 
13 #include <boost/algorithm/string/predicate.hpp>
14 #include <memory>
15 
16 using namespace Belle2;
17 
18 REG_MODULE(MVAPrototype);
19 
21 {
// NOTE(review): the signature line (listing line 20) is missing from this extract. The body below
// sets the module description and registers the module parameters, so this is presumably the
// module constructor -- confirm against the upstream file.
22  setDescription("Prototype of a module which uses the MVA package");
23 
24  // Usually it is safe to execute the MVA expert in parallel mode,
25  // but ultimately this depends on the backend you use.
26  // The default method FastBDT has no problem with parallel execution.
// NOTE(review): listing line 27 is missing from this extract; presumably it is the statement that
// sets the parallel-processing property flag discussed in the comment above -- confirm upstream.
28 
29  // Your module probably has a parameter which defines the database identifier or filename of the mva weightfile.
30  // Of course you could also hard-code it in the source code if this does not need to be configurable.
31  addParam("identifier", m_identifier, "The database identifier or filename which is used to load the weights during the training.");
32 
33  // For classification it is often useful to be able to change the signalFraction.
34  // If the signal fraction in the training is different from the dataset you want to apply it to,
35  // you have to pass the correct signalFraction, otherwise you cannot interpret the output of the classifier as a probability.
36  // On the other hand, if you don't require the output to be a probability you don't have to care about this.
// The default of -1.0 registered below means "no override" according to the parameter description.
37  addParam("signalFraction", m_signal_fraction_override,
38  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
39 }
40 
42 {
// NOTE(review): the signature line (listing line 41) is missing from this extract; presumably this
// is the body of initialize() -- confirm against the upstream file.
43  // If the identifier does not end in .root or .xml, we are dealing with a database identifier,
44  // so we need to create a DBObjPtr, which will fetch the weightfile from the database.
45  if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
46  m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(m_identifier);
47  }
48 
49  // The supported methods have to be initialized once (calling it more than once is safe).
// NOTE(review): listing line 50 is missing from this extract; presumably it is the call that
// initializes the supported MVA interfaces mentioned in the comment above -- confirm upstream.
51 
52 }
53 
55 {
// NOTE(review): the signature line (listing line 54) is missing from this extract; presumably this
// is the body of beginRun() -- confirm against the upstream file.
56 
57  // If the DBObjPtr is valid we are dealing with a weightfile from the database.
// NOTE(review): listing line 58 is missing from this extract; presumably it is the `if` that tests
// m_weightfile_representation and opens the branch closed by "} else {" below -- confirm upstream.
59  // We check if the weightfile changed, in which case we have to update the expert.
60  if (m_weightfile_representation->hasChanged()) {
61  // The actual weightfile is stored in the m_data field of the m_weightfile_representation.
62  std::stringstream ss((*m_weightfile_representation)->m_data);
63  auto weightfile = MVA::Weightfile::loadFromStream(ss);
64  init_mva(weightfile);
65  }
66  // In case of a file-based weightfile we load it here;
67  // in principle this could be done in initialize as well.
68  } else {
// NOTE(review): listing line 69 is missing from this extract; presumably it declares `weightfile`
// by loading it from the file named by m_identifier -- confirm upstream.
70  init_mva(weightfile);
71  }
72 
73 }
74 
76 {
// NOTE(review): the signature line (listing line 75) is missing from this extract; the body uses a
// variable named `weightfile`, so this is presumably init_mva(MVA::Weightfile&) -- confirm upstream.
77  // This function initializes the MVA::Expert using the provided weightfile.
78 
79  // First we get the GeneralOptions from the weightfile
80  // and update the signal_fraction if required.
81  MVA::GeneralOptions general_options;
82  weightfile.getOptions(general_options);
// NOTE(review): listing lines 83-84 are missing from this extract; presumably they apply the
// signal-fraction override mentioned in the comment above -- confirm upstream.
85 
86  // Secondly we load all supported interfaces, fetch the MVA::Expert matching the
87  // method named in the general options, and load the weightfile into this expert.
88  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
89  m_expert = supported_interfaces[general_options.m_method]->getExpert();
90  m_expert->load(weightfile);
91 
92  // Finally, we create an MVA::SingleDataset, in which we will save our features later
93  // to pass them to the expert.
94  // If you want to apply the expert to more than one sample, you can also use
95  // MVA::MultiDataset or any other Dataset defined by the mva package interface.
// The dummy vector has one zero entry per variable listed in general_options.m_variables.
96  std::vector<float> dummy(general_options.m_variables.size(), 0);
97  m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
98 
99 }
100 
102 {
// NOTE(review): the signature line (listing line 101) is missing from this extract; presumably this
// is the body of event() -- confirm against the upstream file.
103  // Just to be safe we check again if the MVA::Expert is loaded.
104  // It can happen that for example the database doesn't find the payload
105  // and the expert ends up uninitialized.
106  if (not m_expert) {
107  B2ERROR("MVA Expert is not loaded! I will return 0");
108  return;
109  }
110 
111  // You have to fill the dataset with your data.
112  // The order must be the same as the order of the variables in general_options.m_variables.
// Here every feature is simply set to the placeholder value 1.0.
113  for (unsigned int i = 0; i < m_dataset->getNumberOfFeatures(); ++i) {
114  m_dataset->m_input[i] = 1.0;
115  }
116 
117  // All that is left to do is applying the expert to the dataset;
118  // it will return an std::vector with the results, with one entry per sample.
119  // The MVA::SingleDataset only contains one entry, so we are interested only in the first entry here.
120  // The MVA::MultiDataset on the other hand would have more than one entry in the returned vector of apply.
121  float probability = m_expert->apply(*m_dataset)[0];
122  B2INFO("The probability is " << probability);
123 }
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
virtual void initialize() override
Initialize the module.
virtual void event() override
Called for each event.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
virtual void beginRun() override
Called at the beginning of a new run.
void init_mva(MVA::Weightfile &weightfile)
Initialize mva expert, dataset and features. Called every time the weightfile in the database changes ...
std::string m_identifier
database identifier or filename of the weightfile
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
REG_MODULE(arichBtest)
Register the Module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
Abstract base class for different kinds of events.