Belle II Software development
MVAExpert.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8#include <tracking/trackFindingVXD/mva/MVAExpert.h>
9
10
11#include <mva/interface/Interface.h>
12
13#include <framework/utilities/FileSystem.h>
14#include <framework/logging/Logger.h>
15
16#include <boost/algorithm/string/predicate.hpp>
17
18using namespace Belle2;
19
20MVAExpert::MVAExpert(const std::string& identifier,
21 std::vector<Named<float*>> namedVariables)
22 : m_allNamedVariables(std::move(namedVariables))
23 , m_identifier(identifier)
24{
25}
26
// MVAExpert::initialize() -- Doxygen summary: "Initialise the mva method."
// NOTE(review): this extraction dropped the signature (original line 27) and
// original lines 29 and 31, so part of the body is not visible here.
//
// Visible effect: when m_identifier ends neither in ".root" nor in ".xml"
// (combined with whatever condition stood on the dropped line 31), the
// identifier is treated as a database key and a
// DBObjPtr<DatabaseRepresentationOfWeightfile> for it is stored in
// m_weightfileRepresentation. Identifiers carrying one of those file
// extensions are presumably left to be resolved as a file path later.
28{
30 using boost::algorithm::ends_with;
// NOTE(review): the first part of this `if` condition was on the dropped
// line 31; only the closing file-extension test survives below.
32 not(ends_with(m_identifier, ".root") or ends_with(m_identifier, ".xml"))) {
33 using DBWeightFileRepresentation = DBObjPtr<DatabaseRepresentationOfWeightfile>;
34 m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
35 }
36}
37
// MVAExpert::beginRun() -- Doxygen summary: "Update the mva method to the
// new run."
// NOTE(review): the signature (original line 38) and original lines 49, 55,
// 69 and 73 are missing from this extraction.
//
// Loads the weight file via getWeightFile(). For FastBDT weight files with
// FastBDT_version 1 or 2, it matches every feature name stored in the file
// ("variable0", "variable1", ...) against m_allNamedVariables. The matches are
// appended, in weight-file order, to m_selectedNamedVariables. It then creates
// the expert for the method named in the weight file's general options and
// loads it. Finally it builds a zero-filled dataset with one input slot per
// selected variable. If no weight file was obtained, an error is logged.
39{
40 std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
41 if (weightfile) {
42 // FastBDT_version refers to the weightfile version, only FastBDT_VERSION_MAJOR >= 5 can handle FastBDT_version==2
43 if (weightfile->getElement<std::string>("method") == "FastBDT" and
44 (weightfile->getElement<int>("FastBDT_version") == 1 or
45 weightfile->getElement<int>("FastBDT_version") == 2)) {
46
47 int nExpectedVars = weightfile->getElement<int>("number_feature_variables");
48
// NOTE(review): original line 49 is missing here. Unless it resets
// m_selectedNamedVariables, a second call of beginRun() would append the
// variables again and trip the size B2ASSERT below -- confirm.
50 for (int iVar = 0; iVar < nExpectedVars; ++iVar) {
51 std::string variableElementName = "variable" + std::to_string(iVar);
52 std::string expectedName = weightfile->getElement<std::string>(variableElementName);
53
// Linear search by name. Original line 55 (presumably the end() bound of the
// search range) is missing from this extraction.
// NOTE(review): the lambda takes Named<Float_t*> while the container holds
// Named<float*>; presumably Float_t is ROOT's typedef for float.
54 auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
56 [expectedName](const Named<Float_t*>& namedVariable) {
57 return namedVariable.getName() == expectedName;
58 });
59
60 if (itNamedVariable == m_allNamedVariables.end()) {
61 B2ERROR("Variable name " << iVar << " mismatch for FastBDT. " <<
62 "Could not find expected variable '" << expectedName << "'");
63 }
// NOTE(review): if the name was not found, execution reaches this line with
// itNamedVariable == end(). B2ERROR is assumed to only log, not abort --
// confirm. Dereferencing end() is undefined behaviour; a `continue` or a
// fatal error in the branch above would avoid it.
64 m_selectedNamedVariables.push_back(*itNamedVariable);
65 }
66 B2ASSERT("Number of variables mismatch", nExpectedVars == static_cast<int>(m_selectedNamedVariables.size()));
67 } else {
68 B2WARNING("Unpacked new kind of classifier. Consider to extend the feature variable check.");
// NOTE(review): original line 69 is missing; how m_selectedNamedVariables is
// filled for non-FastBDT methods cannot be seen here.
70 }
71
// NOTE(review): the initializer on the dropped line 73 is presumably
// MVA::AbstractInterface::getSupportedInterfaces().
72 std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
74 MVA::GeneralOptions generalOptions;
75 weightfile->getOptions(generalOptions);
// NOTE(review): std::map::operator[] inserts a null pointer for a method name
// that is not registered. The ->getExpert() call would then dereference it.
// A find() with an explicit error would be safer.
76 m_expert = supportedInterfaces[generalOptions.m_method]->getExpert();
77 m_expert->load(*weightfile);
78
// Placeholder inputs: one zero slot per selected variable. predict() overwrites
// m_dataset->m_input with the current variable values before each evaluation.
79 std::vector<float> dummy;
80 dummy.resize(m_selectedNamedVariables.size(), 0);
81 m_dataset = std::make_unique<MVA::SingleDataset>(generalOptions, std::move(dummy), 0);
82 } else {
83 B2ERROR("Could not find weight file for identifier " << m_identifier);
84 }
85}
86
// Resolves the source of the weight file and unpacks it (Doxygen summary).
// NOTE(review): original line 89 -- the opening `if` of the two branches
// below -- is missing from this extraction. Since the first branch
// dereferences m_weightfileRepresentation, it presumably tests that pointer.
//
// Database branch: copies the payload's m_data into a string stream and
// deserializes it with MVA::Weightfile::loadFromStream().
// File branch: locates m_identifier with FileSystem::findFile() and loads it
// with MVA::Weightfile::loadFromFile().
// Both visible paths return a freshly allocated Weightfile via
// std::make_unique, never nullptr. How a missing or unreadable file is
// reported is up to findFile()/loadFromFile(), which are not visible here.
// NOTE(review): std::stringstream is used, but <sstream> is not among the
// includes visible in this file -- presumably it is included transitively.
87std::unique_ptr<MVA::Weightfile> MVAExpert::getWeightFile()
88{
90 std::stringstream ss((*m_weightfileRepresentation)->m_data);
91 return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromStream(ss));
92 } else {
93 std::string weightFilePath = FileSystem::findFile(m_identifier);
94 return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromFile(weightFilePath));
95 }
96}
97
// Evaluate the MVA method and return the MVAOutput (Doxygen summary).
// NOTE(review): the signature (original line 98, `float predict()` per the
// Doxygen index) is missing from this extraction.
//
// Copies the current value behind every selected variable pointer into the
// dataset, applies the loaded expert and returns the first entry of its
// result. If no expert has been loaded (m_expert is only set in beginRun()),
// it logs an error and returns 0.
99{
100 if (not m_expert) {
101 B2ERROR("MVA Expert is not loaded! I will return 0");
102 return 0;
103 }
104
105 // Transfer the extracted values to the data set where the expert can find them
106 for (unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
107 m_dataset->m_input[i] = *(m_selectedNamedVariables[i].getValue());
108 }
// NOTE(review): m_dataset is built as an MVA::SingleDataset, so element [0] of
// the result is presumably the output for that one event -- confirm against
// MVA::Expert::apply().
109 return m_expert->apply(*m_dataset)[0];
110}
Class for accessing objects in the database.
Definition: DBObjPtr.h:21
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:151
void initialize()
Initialise the MVA method.
Definition: MVAExpert.cc:27
void beginRun()
Update the MVA method for the new run.
Definition: MVAExpert.cc:38
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
Definition: MVAExpert.h:63
MVAExpert(const std::string &identifier, std::vector< Named< float * > > namedVariables)
Construct the Expert with the specified weight folder and the name of the training that was used in t...
Definition: MVAExpert.cc:20
std::unique_ptr< MVA::Weightfile > getWeightFile()
Resolves the source of the weight file and unpacks it.
Definition: MVAExpert.cc:87
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
Definition: MVAExpert.h:66
std::vector< Named< float * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
Definition: MVAExpert.h:60
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
Definition: MVAExpert.h:69
float predict()
Evaluate the MVA method and return the MVAOutput.
Definition: MVAExpert.cc:98
std::vector< Named< float * > > m_allNamedVariables
References to the named values from the source variable set.
Definition: MVAExpert.h:57
std::string m_identifier
DB identifier of the expert or file name.
Definition: MVAExpert.h:72
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::string m_method
Name of the MVA method to use.
Definition: Options.h:82
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
A mixin class to attach a name to an object. Based on class with same name in CDC package.
Definition: Named.h:21
Abstract base class for different kinds of events.
STL namespace.