Belle II Software development
MVAExpert.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8#include <tracking/trackFindingCDC/mva/MVAExpert.h>
9
11#include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
12#include <mva/interface/Expert.h>
13#include <mva/interface/Weightfile.h>
14#include <framework/database/DBObjPtr.h>
15#include <boost/algorithm/string/predicate.hpp>
16
17namespace Belle2 {
23 namespace MVA {
24 class Expert;
25 class SingleDataset;
26 class Weightfile;
27 }
28
29 namespace TrackFindingCDC {
32
33 public:
34 Impl(const std::string& identifier, std::vector<Named<Float_t*>> namedVariables);
35 void initialize();
36 void beginRun();
37 std::unique_ptr<MVA::Weightfile> getWeightFile();
38 double predict();
39 std::vector<float> predict(float* /* test_data */, int /* nFeature */, int /* nRows */);
40 std::vector<std::string> getVariableNames();
41 private:
43 std::vector<Named<Float_t*> > m_allNamedVariables;
44
46 std::vector<Named<Float_t*> > m_selectedNamedVariables;
47
49 std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile> > m_weightfileRepresentation;
50
52 std::unique_ptr<MVA::Expert> m_expert;
53
55 std::unique_ptr<MVA::Dataset> m_dataset;
56
59
61 std::string m_identifier;
62 };
63 }
65}
66
68#include <mva/interface/Interface.h>
69
70#include <framework/utilities/FileSystem.h>
71#include <framework/logging/Logger.h>
72
73#include <algorithm>
74
75using namespace Belle2;
76using namespace TrackFindingCDC;
77
78MVAExpert::Impl::Impl(const std::string& identifier,
79 std::vector<Named<Float_t*> > namedVariables)
80 : m_allNamedVariables(std::move(namedVariables))
81 , m_identifier(identifier)
82{
83}
84
86{
88 using boost::algorithm::ends_with;
89 if (not m_weightfileRepresentation and
90 not(ends_with(m_identifier, ".root") or ends_with(m_identifier, ".xml"))) {
91 using DBWeightFileRepresentation = DBObjPtr<DatabaseRepresentationOfWeightfile>;
92 m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
93 }
94}
95
97{
98 std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
99 if (weightfile) {
100 if ((weightfile->getElement<std::string>("method") == "FastBDT" and
101 (weightfile->getElement<int>("FastBDT_version") == 1 or
102 weightfile->getElement<int>("FastBDT_version") == 2)) or
103 (weightfile->getElement<std::string>("method") == "Python")) {
104
105 int nExpectedVars = weightfile->getElement<int>("number_feature_variables");
106
107 m_selectedNamedVariables.clear();
108 for (int iVar = 0; iVar < nExpectedVars; ++iVar) {
109 std::string variableElementName = "variable" + std::to_string(iVar);
110 std::string expectedName = weightfile->getElement<std::string>(variableElementName);
111 auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
112 m_allNamedVariables.end(),
113 [expectedName](const Named<Float_t*>& namedVariable) {
114 return namedVariable.getName() == expectedName;
115 });
116
117 if (itNamedVariable == m_allNamedVariables.end()) {
118 B2ERROR("Variable name " << iVar << " mismatch for FastBDT. " <<
119 "Could not find expected variable '" << expectedName << "'");
120 }
121 m_selectedNamedVariables.push_back(*itNamedVariable);
122 }
123 B2ASSERT("Number of variables mismatch", nExpectedVars == static_cast<int>(m_selectedNamedVariables.size()));
124 } else {
125 B2WARNING("Unpacked new kind of classifier. Consider to extend the feature variable check. Identifier name: " << m_identifier
126 << "; method name: " << weightfile->getElement<std::string>("method"));
127 m_selectedNamedVariables = m_allNamedVariables;
128 }
129
130 std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
132 weightfile->getOptions(m_generalOptions);
133 m_expert = supportedInterfaces[m_generalOptions.m_method]->getExpert();
134 m_expert->load(*weightfile);
135
136 std::vector<float> dummy;
137 dummy.resize(m_selectedNamedVariables.size(), 0);
138 m_dataset = std::make_unique<MVA::SingleDataset>(m_generalOptions, std::move(dummy), 0);
139 } else {
140 B2ERROR("Could not find weight file for identifier " << m_identifier);
141 }
142}
143
144std::unique_ptr<MVA::Weightfile> MVAExpert::Impl::getWeightFile()
145{
146 if (m_weightfileRepresentation) {
147 std::stringstream ss((*m_weightfileRepresentation)->m_data);
148 return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromStream(ss));
149 } else {
150 std::string weightFilePath = FileSystem::findFile(m_identifier);
151 return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromFile(weightFilePath));
152 }
153}
154
156{
157 if (not m_expert) {
158 B2ERROR("MVA Expert is not loaded! I will return 0");
159 return NAN;
160 }
161
162 // Transfer the extracted values to the data set were the expert can find them
163 for (unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
164 m_dataset->m_input[i] = *m_selectedNamedVariables[i];
165 }
166 return m_expert->apply(*m_dataset)[0];
167}
168
169std::vector<float> MVAExpert::Impl::predict(float* test_data, int nFeature, int nRows)
170{
171 std::vector<std::vector<float>> spectators;
172 std::vector<std::vector <float> > data;
173 data.resize(nRows);
174 for (int iRow = 0; iRow < nRows; iRow += 1) {
175 data[iRow].resize(nFeature);
176 for (int iFeature = 0; iFeature < nFeature; iFeature += 1) {
177 data[iRow][iFeature] = test_data[nFeature * iRow + iFeature];
178 }
179 }
180
181 MVA::MultiDataset dataSet(m_generalOptions, data, spectators);
182 return m_expert->apply(dataSet);
183}
184
185std::vector<std::string> MVAExpert::Impl::getVariableNames()
186{
187 std::vector<std::string> out(m_selectedNamedVariables.size());
188 for (size_t iName = 0; iName < m_selectedNamedVariables.size(); iName += 1) {
189 out[iName] = m_selectedNamedVariables[iName].getName();
190 }
191 return out;
192}
193
194
// Silence Doxygen, which complains that "no matching class member found for" the definition below.
// There is probably a better way to do this, but the author does not know of one.
199MVAExpert::MVAExpert(const std::string& identifier,
200 std::vector<Named<Float_t*> > namedVariables)
201 : m_impl(std::make_unique<MVAExpert::Impl>(identifier, std::move(namedVariables)))
203{
204}
205
206MVAExpert::~MVAExpert() = default;
207
209{
210 return m_impl->initialize();
211}
212
214{
215 return m_impl->beginRun();
216}
217
219{
220 return m_impl->predict();
221}
222
223std::vector<float> MVAExpert::predict(float* test_data, int nFeature, int nRows)
224{
225 return m_impl->predict(test_data, nFeature, nRows);
226}
227
228std::vector<std::string> MVAExpert::getVariableNames()
229{
230 return m_impl->getVariableNames();
231}
232
Class for accessing objects in the database.
Definition: DBObjPtr.h:21
Database representation of a Weightfile object.
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:151
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedI...
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
Wraps the data of a multiple event into a Dataset.
Definition: Dataset.h:186
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
Implementation of the class to interact with the MVA package.
Definition: MVAExpert.cc:31
void initialize()
Signal the beginning of the event processing.
Definition: MVAExpert.cc:85
void beginRun()
Called once before a new run begins.
Definition: MVAExpert.cc:96
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
Definition: MVAExpert.cc:49
std::unique_ptr< MVA::Weightfile > getWeightFile()
Get the weight file.
Definition: MVAExpert.cc:144
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
Definition: MVAExpert.cc:52
std::vector< Named< Float_t * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
Definition: MVAExpert.cc:46
Impl(const std::string &identifier, std::vector< Named< Float_t * > > namedVariables)
constructor
Definition: MVAExpert.cc:78
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
Definition: MVAExpert.cc:55
std::vector< Named< Float_t * > > m_allNamedVariables
References to all named values from the source variable set.
Definition: MVAExpert.cc:43
MVA::GeneralOptions m_generalOptions
General options.
Definition: MVAExpert.cc:58
double predict()
Get the MVA prediction.
Definition: MVAExpert.cc:155
std::vector< std::string > getVariableNames()
Get the names of the selected variables.
Definition: MVAExpert.cc:185
std::string m_identifier
DB identifier of the expert or file name.
Definition: MVAExpert.cc:61
Class to interact with the MVA package.
Definition: MVAExpert.h:26
void initialize()
Initialise the mva method.
Definition: MVAExpert.cc:208
void beginRun()
Update the mva method to the new run.
Definition: MVAExpert.cc:213
std::unique_ptr< Impl > m_impl
Pointer to implementation hiding the details.
Definition: MVAExpert.h:59
~MVAExpert()
Destructor must be defined in cpp because of PImpl pointer.
MVAExpert(const std::string &identifier, std::vector< Named< Float_t * > > namedVariables)
Construct the Expert with the specified weight folder and the name of the training that was used in t...
double predict()
Evaluate the MVA method and return the MVAOutput.
Definition: MVAExpert.cc:218
std::vector< std::string > getVariableNames()
Get selected variable names.
Definition: MVAExpert.cc:228
Filter based on a mva method.
Definition: MVAFilter.dcl.h:36
A mixin class to attach a name to an object.
Definition: Named.h:23
Abstract base class for different kinds of events.
STL namespace.