Belle II Software development
MVAPrototypeModule.cc
/**************************************************************************
 * basf2 (Belle II Analysis Software Framework)                           *
 * Author: The Belle II Collaboration                                     *
 *                                                                        *
 * See git log for contributors and copyright holders.                    *
 * This file is licensed under LGPL-3.0, see LICENSE.md.                  *
 **************************************************************************/

#include <mva/modules/MVAExpert/MVAPrototypeModule.h>
#include <mva/interface/Interface.h>

#include <memory>
#include <sstream>
#include <vector>

15using namespace Belle2;
16
17REG_MODULE(MVAPrototype);
18
20{
21 setDescription("Prototype of a module which uses the MVA package");
22
23 // Usually it is save to execute the MVA expert in Parallel mode,
24 // but ultimately this depends on the backend you use.
25 // The default method FastBDT has no problem with parallel execution.
27
28 // Your module probably has a parameter which defines the database identifier or filename of the mva weightfile.
29 // Of course you could also hard-code it in the source code if this is not configurable
30 addParam("identifier", m_identifier, "The database identifier or filename which is used to load the weights during the training.");
31
32 // For classification it is often useful to be able to change the signalFraction.
33 // If the signal fraction in the training is different from the dataset you want to apply it to,
34 // you have to pass the correct signalFraction otherwise you cannot interpret the output of the classifier as a probability.
35 // On the other hand, if you don't require the output to be a probability you don't have to care about this.
36 addParam("signalFraction", m_signal_fraction_override,
37 "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
38}
39
41{
42 // If the identifier does not end on .root or .xml, we are dealing with a database identifier
43 // so we need to create a DBObjPtr, which will fetch the weightfile from the database
44 if (not(m_identifier.ends_with(".root") or m_identifier.ends_with(".xml"))) {
45 m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(m_identifier);
46 }
47
48 // The supported methods have to be initialized once (calling it more than once is save)
50
51}
52
54{
55
56 // If the DBObjPtr is valid we are dealing with a weightfile from the database
58 // We check if the weightfile changed and we have to update the expert
59 if (m_weightfile_representation->hasChanged()) {
60 // The actual weightfile is stored in the m_data field of the m_weightfile_representation
61 std::stringstream ss((*m_weightfile_representation)->m_data);
62 auto weightfile = MVA::Weightfile::loadFromStream(ss);
63 init_mva(weightfile);
64 }
65 // In case of a file-based weightfile we load it here
66 // in principal this could be done in initialize as well
67 } else {
69 init_mva(weightfile);
70 }
71
72}
73
75{
76 // This function initializes the MVA::Expert using the provided weightfile
77
78 // First we get the GeneralOptions from the weightfile
79 // and update the signal_fraction if required
80 MVA::GeneralOptions general_options;
81 weightfile.getOptions(general_options);
83 weightfile.addSignalFraction(m_signal_fraction_override);
84
85 // Secondly we load all supported interfaces, and fetch the correct MVA::Expert
86 // and load the weightfile into this expert
87 auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
88 m_expert = supported_interfaces[general_options.m_method]->getExpert();
89 m_expert->load(weightfile);
90
91 // Finally, we create an MVA::SingleDataset, in which we will save our features later
92 // to pass them to the expert.
93 // If you want to apply the expert to more than one sample, you can also use
94 // MVA::MultiDataset or any other Dataset defined by the mva package interface.
95 std::vector<float> dummy(general_options.m_variables.size(), 0);
96 m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
97
98}
99
101{
102 // Just to be save we check again if the MVA::Expert is loaded
103 // It can happen that for example the database doesn't find the payload
104 // and the expert ends up uninitialized.
105 if (not m_expert) {
106 B2ERROR("MVA Expert is not loaded! I will return 0");
107 return;
108 }
109
110 // You have to fill the dataset with your data.
111 // The order must be the same as the order of the variables in general_options.m_variables
112 for (unsigned int i = 0; i < m_dataset->getNumberOfFeatures(); ++i) {
113 m_dataset->m_input[i] = 1.0;
114 }
115
116 // All what is left to do is applying the expert to the dataset
117 // it will return an std::vector with the results, with one entry per sample.
118 // The MVA::SingleDataset only contains one entry, so we are interested only in the first entry here.
119 // The MVA::MultiDataset on the other hand would have more than one entry in the returned vector of apply.
120 float probability = m_expert->apply(*m_dataset)[0];
121 B2INFO("The probability is " << probability);
122}
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
virtual void initialize() override
Initialize the module.
virtual void event() override
Called for each event.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
virtual void beginRun() override
Called at the beginning of a new run.
void init_mva(MVA::Weightfile &weightfile)
Initialize mva expert, dataset and features Called every time the weightfile in the database changes ...
std::string m_identifier
database identifier or filename of the weightfile
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupported...
Definition Interface.cc:46
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition Interface.h:53
General options which are shared by all MVA trainings.
Definition Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.