Belle II Software light-2406-ragdoll
MVAPrototypeModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9
#include <mva/modules/MVAExpert/MVAPrototypeModule.h>
#include <mva/interface/Interface.h>

#include <boost/algorithm/string/predicate.hpp>

#include <memory>
#include <sstream>

using namespace Belle2;
17
18REG_MODULE(MVAPrototype);
19
21{
22 setDescription("Prototype of a module which uses the MVA package");
23
24 // Usually it is save to execute the MVA expert in Parallel mode,
25 // but ultimately this depends on the backend you use.
26 // The default method FastBDT has no problem with parallel execution.
28
29 // Your module probably has a parameter which defines the database identifier or filename of the mva weightfile.
30 // Of course you could also hard-code it in the source code if this is not configurable
31 addParam("identifier", m_identifier, "The database identifier or filename which is used to load the weights during the training.");
32
33 // For classification it is often useful to be able to change the signalFraction.
34 // If the signal fraction in the training is different from the dataset you want to apply it to,
35 // you have to pass the correct signalFraction otherwise you cannot interpret the output of the classifier as a probability.
36 // On the other hand, if you don't require the output to be a probability you don't have to care about this.
37 addParam("signalFraction", m_signal_fraction_override,
38 "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
39}
40
42{
43 // If the identifier does not end on .root or .xml, we are dealing with a database identifier
44 // so we need to create a DBObjPtr, which will fetch the weightfile from the database
45 if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
46 m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(m_identifier);
47 }
48
49 // The supported methods have to be initialized once (calling it more than once is save)
51
52}
53
55{
56
57 // If the DBObjPtr is valid we are dealing with a weightfile from the database
59 // We check if the weightfile changed and we have to update the expert
60 if (m_weightfile_representation->hasChanged()) {
61 // The actual weightfile is stored in the m_data field of the m_weightfile_representation
62 std::stringstream ss((*m_weightfile_representation)->m_data);
63 auto weightfile = MVA::Weightfile::loadFromStream(ss);
64 init_mva(weightfile);
65 }
66 // In case of a file-based weightfile we load it here
67 // in principal this could be done in initialize as well
68 } else {
70 init_mva(weightfile);
71 }
72
73}
74
76{
77 // This function initializes the MVA::Expert using the provided weightfile
78
79 // First we get the GeneralOptions from the weightfile
80 // and update the signal_fraction if required
81 MVA::GeneralOptions general_options;
82 weightfile.getOptions(general_options);
85
86 // Secondly we load all supported interfaces, and fetch the correct MVA::Expert
87 // and load the weightfile into this expert
88 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
89 m_expert = supported_interfaces[general_options.m_method]->getExpert();
90 m_expert->load(weightfile);
91
92 // Finally, we create an MVA::SingleDataset, in which we will save our features later
93 // to pass them to the expert.
94 // If you want to apply the expert to more than one sample, you can also use
95 // MVA::MultiDataset or any other Dataset defined by the mva package interface.
96 std::vector<float> dummy(general_options.m_variables.size(), 0);
97 m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
98
99}
100
102{
103 // Just to be save we check again if the MVA::Expert is loaded
104 // It can happen that for example the database doesn't find the payload
105 // and the expert ends up uninitialized.
106 if (not m_expert) {
107 B2ERROR("MVA Expert is not loaded! I will return 0");
108 return;
109 }
110
111 // You have to fill the dataset with your data.
112 // The order must be the same as the order of the variables in general_options.m_variables
113 for (unsigned int i = 0; i < m_dataset->getNumberOfFeatures(); ++i) {
114 m_dataset->m_input[i] = 1.0;
115 }
116
117 // All what is left to do is applying the expert to the dataset
118 // it will return an std::vector with the results, with one entry per sample.
119 // The MVA::SingleDataset only contains one entry, so we are interested only in the first entry here.
120 // The MVA::MultiDataset on the other hand would have more than one entry in the returned vector of apply.
121 float probability = m_expert->apply(*m_dataset)[0];
122 B2INFO("The probability is " << probability);
123}
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
virtual void initialize() override
Initialize the module.
virtual void event() override
Called for each event.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
virtual void beginRun() override
Called at the beginning of a new run.
void init_mva(MVA::Weightfile &weightfile)
Initialize the MVA expert, dataset and features. Called every time the weightfile in the database changes ...
std::string m_identifier
database identifier or filename of the weightfile
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedI...
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24