Belle II Software development
MVAExpert.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8#include <tracking/trackingUtilities/mva/MVAExpert.h>
9
11#include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
12#include <mva/interface/Expert.h>
13#include <mva/interface/Weightfile.h>
14#include <framework/database/DBObjPtr.h>
15#include <boost/algorithm/string/predicate.hpp>
16
17namespace Belle2 {
23 namespace MVA {
24 class Expert;
25 class SingleDataset;
26 class Weightfile;
27 }
28
29 namespace TrackingUtilities {
32
33 public:
34 Impl(const std::string& identifier, std::vector<Named<Float_t*>> namedVariables);
35 void initialize();
36 void beginRun();
37 std::unique_ptr<MVA::Weightfile> getWeightFile();
38 double predict();
39 std::vector<float> predict(float* /* test_data */, int /* nFeature */, int /* nRows */);
40 std::vector<std::string> getVariableNames();
41 private:
43 std::vector<Named<Float_t*> > m_allNamedVariables;
44
46 std::vector<Named<Float_t*> > m_selectedNamedVariables;
47
49 std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile> > m_weightfileRepresentation;
50
52 std::unique_ptr<MVA::Expert> m_expert;
53
55 std::unique_ptr<MVA::Dataset> m_dataset;
56
59
61 std::string m_identifier;
62 };
63 }
65}
66
68#include <mva/interface/Interface.h>
69
70#include <framework/utilities/FileSystem.h>
71#include <framework/logging/Logger.h>
72
73#include <algorithm>
74
75using namespace Belle2;
76using namespace TrackingUtilities;
77
78MVAExpert::Impl::Impl(const std::string& identifier,
79 std::vector<Named<Float_t*> > namedVariables)
80 : m_allNamedVariables(std::move(namedVariables))
81 , m_identifier(identifier)
82{
83}
84
// Set up database access to the weight file and abort if no valid weight file is available.
// NOTE(review): this listing lacks the signature line (original line 85, presumably
// `void MVAExpert::Impl::initialize()`) as well as original lines 87 and 89; line 89 holds
// the start of the `if` condition continued below — confirm against the real source.
86{
88 using boost::algorithm::ends_with;
// A database object is only created when the identifier does not end with ".root" or ".xml"
// (remaining conditions of this `if` are on the missing line 89).
90 not(ends_with(m_identifier, ".root") or ends_with(m_identifier, ".xml"))) {
91 using DBWeightFileRepresentation = DBObjPtr<DatabaseRepresentationOfWeightfile>;
92 m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
93 }
// NOTE(review): this check presumably also fires when m_identifier names a .root/.xml file,
// since m_weightfileRepresentation is only created above; that would make the file-path
// branch of getWeightFile() unreachable. Confirm whether file-based weight files are supported.
94 if ((not m_weightfileRepresentation) or (not m_weightfileRepresentation->isValid())) {
95 B2FATAL("No weight file could be loaded in tracking/trackingUtilities/mva/MVAExpert.");
96 }
97}
98
// Load the weight file for the current run and set up the expert:
//  - for FastBDT (version 1 or 2) and Python weight files, select from m_allNamedVariables the
//    variables named by the weight file elements "variable0" ... "variable<N-1>", in that order;
//  - for any other method, only a warning is issued;
//  - then create the expert for the method stored in the weight file, load the weight file into it,
//    and create a single-event dataset with one zero-initialised input per selected variable.
// NOTE(review): this listing lacks the signature line (original line 99) and original lines 110,
// 115, 130 and 134 — confirm against the real source.
100{
101 std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
102 if (weightfile) {
103 if ((weightfile->getElement<std::string>("method") == "FastBDT" and
104 (weightfile->getElement<int>("FastBDT_version") == 1 or
105 weightfile->getElement<int>("FastBDT_version") == 2)) or
106 (weightfile->getElement<std::string>("method") == "Python")) {
107
108 int nExpectedVars = weightfile->getElement<int>("number_feature_variables");
109
// NOTE(review): original line 110 is missing here. If it does not clear m_selectedNamedVariables,
// selections accumulate across runs and the B2ASSERT below fails from the second run on.
111 for (int iVar = 0; iVar < nExpectedVars; ++iVar) {
112 std::string variableElementName = "variable" + std::to_string(iVar);
113 std::string expectedName = weightfile->getElement<std::string>(variableElementName);
// Look up the variable by name (the end iterator argument is on the missing line 115).
114 auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
116 [expectedName](const Named<Float_t*>& namedVariable) {
117 return namedVariable.getName() == expectedName;
118 });
119
// NOTE(review): execution presumably continues after B2ERROR (predict() returns right after one),
// so a missing variable leads to dereferencing end() below — undefined behaviour. Consider
// B2FATAL or skipping. The message also says "FastBDT" even for Python weight files.
120 if (itNamedVariable == m_allNamedVariables.end()) {
121 B2ERROR("Variable name " << iVar << " mismatch for FastBDT. " <<
122 "Could not find expected variable '" << expectedName << "'");
123 }
124 m_selectedNamedVariables.push_back(*itNamedVariable);
125 }
126 B2ASSERT("Number of variables mismatch", nExpectedVars == static_cast<int>(m_selectedNamedVariables.size()));
127 } else {
// Unknown method: no variable check is done here (original line 130 is missing from this listing).
128 B2WARNING("Unpacked new kind of classifier. Consider to extend the feature variable check. Identifier name: " << m_identifier
129 << "; method name: " << weightfile->getElement<std::string>("method"));
131 }
132
// The right-hand side of this initialisation is on the missing line 134
// (presumably MVA::AbstractInterface::getSupportedInterfaces()).
133 std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
135 weightfile->getOptions(m_generalOptions);
// NOTE(review): std::map::operator[] inserts a nullptr entry if m_method is not a supported
// interface; ->getExpert() would then dereference a null pointer.
136 m_expert = supportedInterfaces[m_generalOptions.m_method]->getExpert();
137 m_expert->load(*weightfile);
138
// The dataset starts with one zero input per selected variable; predict() overwrites them per event.
139 std::vector<float> dummy;
140 dummy.resize(m_selectedNamedVariables.size(), 0);
141 m_dataset = std::make_unique<MVA::SingleDataset>(m_generalOptions, std::move(dummy), 0);
142 } else {
143 B2ERROR("Could not find weight file for identifier " << m_identifier);
144 }
145}
146
// Load the weight file. One branch deserializes it from the m_data payload of the database
// object; the other resolves m_identifier with FileSystem::findFile and loads it from that path.
// NOTE(review): the `if` line selecting between the two branches (original line 149) is missing
// from this listing; presumably it tests m_weightfileRepresentation — confirm.
147std::unique_ptr<MVA::Weightfile> MVAExpert::Impl::getWeightFile()
148{
150 std::stringstream ss((*m_weightfileRepresentation)->m_data);
151 return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromStream(ss));
152 } else {
153 std::string weightFilePath = FileSystem::findFile(m_identifier);
154 return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromFile(weightFilePath));
155 }
156}
157
159{
160 if (not m_expert) {
161 B2ERROR("MVA Expert is not loaded! I will return 0");
162 return NAN;
163 }
164
165 // Transfer the extracted values to the data set were the expert can find them
166 for (unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
167 m_dataset->m_input[i] = *m_selectedNamedVariables[i];
168 }
169 return m_expert->apply(*m_dataset)[0];
170}
171
172std::vector<float> MVAExpert::Impl::predict(float* test_data, int nFeature, int nRows)
173{
174 std::vector<std::vector<float>> spectators;
175 std::vector<std::vector <float> > data;
176 data.resize(nRows);
177 for (int iRow = 0; iRow < nRows; iRow += 1) {
178 data[iRow].resize(nFeature);
179 for (int iFeature = 0; iFeature < nFeature; iFeature += 1) {
180 data[iRow][iFeature] = test_data[nFeature * iRow + iFeature];
181 }
182 }
183
184 MVA::MultiDataset dataSet(m_generalOptions, data, spectators);
185 return m_expert->apply(dataSet);
186}
187
188std::vector<std::string> MVAExpert::Impl::getVariableNames()
189{
190 std::vector<std::string> out(m_selectedNamedVariables.size());
191 for (size_t iName = 0; iName < m_selectedNamedVariables.size(); iName += 1) {
192 out[iName] = m_selectedNamedVariables[iName].getName();
193 }
194 return out;
195}
196
197
// Silence Doxygen, which otherwise complains "no matching class member found for" these definitions.
// There is probably a better way to do this, but none has been found yet.
202MVAExpert::MVAExpert(const std::string& identifier,
203 std::vector<Named<Float_t*> > namedVariables)
204 : m_impl(std::make_unique<MVAExpert::Impl>(identifier, std::move(namedVariables)))
206{
207}
208
209MVAExpert::~MVAExpert() = default;
210
212{
213 return m_impl->initialize();
214}
215
217{
218 return m_impl->beginRun();
219}
220
222{
223 return m_impl->predict();
224}
225
226std::vector<float> MVAExpert::predict(float* test_data, int nFeature, int nRows)
227{
228 return m_impl->predict(test_data, nFeature, nRows);
229}
230
231std::vector<std::string> MVAExpert::getVariableNames()
232{
233 return m_impl->getVariableNames();
234}
235
Class for accessing objects in the database.
Definition DBObjPtr.h:21
Database representation of a Weightfile object.
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupported...
Definition Interface.cc:46
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition Interface.h:53
Abstract base class of all Experts. Each MVA library has its own implementation of this class,...
Definition Expert.h:31
General options which are shared by all MVA trainings.
Definition Options.h:62
Wraps the data of a multiple event into a Dataset.
Definition Dataset.h:186
Wraps the data of a single event into a Dataset.
Definition Dataset.h:135
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Implementation of the class to interact with the MVA package.
Definition MVAExpert.cc:31
void initialize()
Signal the beginning of the event processing.
Definition MVAExpert.cc:85
void beginRun()
Called once before a new run begins.
Definition MVAExpert.cc:99
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
Definition MVAExpert.cc:49
std::unique_ptr< MVA::Weightfile > getWeightFile()
Get the weight file.
Definition MVAExpert.cc:147
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
Definition MVAExpert.cc:52
std::vector< Named< Float_t * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
Definition MVAExpert.cc:46
Impl(const std::string &identifier, std::vector< Named< Float_t * > > namedVariables)
constructor
Definition MVAExpert.cc:78
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
Definition MVAExpert.cc:55
std::vector< Named< Float_t * > > m_allNamedVariables
References to the all named values from the source variable set.
Definition MVAExpert.cc:43
MVA::GeneralOptions m_generalOptions
General options.
Definition MVAExpert.cc:58
double predict()
Get the MVA prediction.
Definition MVAExpert.cc:158
std::vector< std::string > getVariableNames()
Get selected variable names.
Definition MVAExpert.cc:188
std::string m_identifier
DB identifier of the expert or file name.
Definition MVAExpert.cc:61
void initialize()
Initialise the mva method.
Definition MVAExpert.cc:211
void beginRun()
Update the mva method to the new run.
Definition MVAExpert.cc:216
std::unique_ptr< Impl > m_impl
Pointer to implementation hiding the details.
Definition MVAExpert.h:59
~MVAExpert()
Destructor must be defined in cpp because of PImpl pointer.
MVAExpert(const std::string &identifier, std::vector< Named< Float_t * > > namedVariables)
Construct the Expert with the specified weight folder and the name of the training that was used in t...
double predict()
Evaluate the MVA method and return the MVAOutput.
Definition MVAExpert.cc:221
std::vector< std::string > getVariableNames()
Get selected variable names.
Definition MVAExpert.cc:231
A mixin class to attach a name to an object.
Definition Named.h:23
Abstract base class for different kinds of events.
STL namespace.