Belle II Software development
MVAExpert::Impl Class Reference

Implementation of the class to interact with the MVA package. More...

Public Member Functions

 Impl (const std::string &identifier, std::vector< Named< Float_t * > > namedVariables)
 constructor
 
void initialize ()
 Signal the beginning of the event processing.
 
void beginRun ()
 Called once before a new run begins.
 
std::unique_ptr< MVA::Weightfile > getWeightFile ()
 Get the weight file.
 
double predict ()
 Get the MVA prediction.
 
std::vector< float > predict (float *, int, int)
 Get predictions for several inputs.
 
std::vector< std::string > getVariableNames ()
 Get selected variable names.
 

Private Attributes

std::vector< Named< Float_t * > > m_allNamedVariables
 References to all the named values from the source variable set.
 
std::vector< Named< Float_t * > > m_selectedNamedVariables
 References to the selected named values from the source variable set.
 
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
 Database pointer to the Database representation of the weightfile.
 
std::unique_ptr< MVA::Expert > m_expert
 Pointer to the current MVA Expert.
 
std::unique_ptr< MVA::Dataset > m_dataset
 Pointer to the current dataset.
 
MVA::GeneralOptions m_generalOptions
 General options.
 
std::string m_identifier
 DB identifier of the expert or file name.
 

Detailed Description

Implementation of the class to interact with the MVA package.

Definition at line 30 of file MVAExpert.cc.

Constructor & Destructor Documentation

◆ Impl()

Impl ( const std::string & identifier,
std::vector< Named< Float_t * > > namedVariables )

constructor

Definition at line 77 of file MVAExpert.cc.

79 : m_allNamedVariables(std::move(namedVariables))
80 , m_identifier(identifier)
81{
82}
std::vector< Named< Float_t * > > m_allNamedVariables
References to all the named values from the source variable set.
Definition MVAExpert.cc:42
std::string m_identifier
DB identifier of the expert or file name.
Definition MVAExpert.cc:60

Member Function Documentation

◆ beginRun()

void beginRun ( )

Called once before a new run begins.

Definition at line 97 of file MVAExpert.cc.

98{
99 std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
100 if (weightfile) {
101 if ((weightfile->getElement<std::string>("method") == "FastBDT" and
102 (weightfile->getElement<int>("FastBDT_version") == 1 or
103 weightfile->getElement<int>("FastBDT_version") == 2)) or
104 (weightfile->getElement<std::string>("method") == "Python")) {
105
106 int nExpectedVars = weightfile->getElement<int>("number_feature_variables");
107
109 for (int iVar = 0; iVar < nExpectedVars; ++iVar) {
110 std::string variableElementName = "variable" + std::to_string(iVar);
111 std::string expectedName = weightfile->getElement<std::string>(variableElementName);
112 auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
114 [expectedName](const Named<Float_t*>& namedVariable) {
115 return namedVariable.getName() == expectedName;
116 });
117
118 if (itNamedVariable == m_allNamedVariables.end()) {
119 B2ERROR("Variable name " << iVar << " mismatch for FastBDT. " <<
120 "Could not find expected variable '" << expectedName << "'");
121 }
122 m_selectedNamedVariables.push_back(*itNamedVariable);
123 }
124 B2ASSERT("Number of variables mismatch", nExpectedVars == static_cast<int>(m_selectedNamedVariables.size()));
125 } else {
126 B2WARNING("Unpacked new kind of classifier. Consider to extend the feature variable check. Identifier name: " << m_identifier
127 << "; method name: " << weightfile->getElement<std::string>("method"));
129 }
130
131 std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
133 weightfile->getOptions(m_generalOptions);
134 m_expert = supportedInterfaces[m_generalOptions.m_method]->getExpert();
135 m_expert->load(*weightfile);
136
137 std::vector<float> dummy;
138 dummy.resize(m_selectedNamedVariables.size(), 0);
139 m_dataset = std::make_unique<MVA::SingleDataset>(m_generalOptions, std::move(dummy), 0);
140 } else {
141 B2ERROR("Could not find weight file for identifier " << m_identifier);
142 }
143}
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition Interface.h:53
std::unique_ptr< MVA::Weightfile > getWeightFile()
Get the weight file.
Definition MVAExpert.cc:145
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
Definition MVAExpert.cc:51
std::vector< Named< Float_t * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
Definition MVAExpert.cc:45
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
Definition MVAExpert.cc:54
MVA::GeneralOptions m_generalOptions
General options.
Definition MVAExpert.cc:57

◆ getVariableNames()

std::vector< std::string > getVariableNames ( )

Get selected variable names.

Definition at line 186 of file MVAExpert.cc.

187{
188 std::vector<std::string> out(m_selectedNamedVariables.size());
189 for (size_t iName = 0; iName < m_selectedNamedVariables.size(); iName += 1) {
190 out[iName] = m_selectedNamedVariables[iName].getName();
191 }
192 return out;
193}

◆ getWeightFile()

std::unique_ptr< MVA::Weightfile > getWeightFile ( )

Get the weight file.

Definition at line 145 of file MVAExpert.cc.

146{
148 std::stringstream ss((*m_weightfileRepresentation)->m_data);
149 return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromStream(ss));
150 } else {
151 std::string weightFilePath = FileSystem::findFile(m_identifier);
152 return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromFile(weightFilePath));
153 }
154}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
Definition MVAExpert.cc:48

◆ initialize()

void initialize ( )

Signal the beginning of the event processing.

Definition at line 84 of file MVAExpert.cc.

85{
88 not(m_identifier.ends_with(".root") or m_identifier.ends_with(".xml"))) {
89 using DBWeightFileRepresentation = DBObjPtr<DatabaseRepresentationOfWeightfile>;
90 m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
91 }
92 if ((not m_weightfileRepresentation) or (not m_weightfileRepresentation->isValid())) {
93 B2FATAL("No weight file could be loaded in tracking/trackingUtilities/mva/MVAExpert.");
94 }
95}
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupported...
Definition Interface.cc:46

◆ predict() [1/2]

double predict ( )

Get the MVA prediction.

Definition at line 156 of file MVAExpert.cc.

157{
158 if (not m_expert) {
159 B2ERROR("MVA Expert is not loaded! I will return 0");
160 return NAN;
161 }
162
163 // Transfer the extracted values to the data set were the expert can find them
164 for (unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
165 m_dataset->m_input[i] = *m_selectedNamedVariables[i];
166 }
167 return m_expert->apply(*m_dataset)[0];
168}

◆ predict() [2/2]

std::vector< float > predict ( float * test_data,
int nFeature,
int nRows )

Get predictions for several inputs.

Definition at line 170 of file MVAExpert.cc.

171{
172 std::vector<std::vector<float>> spectators;
173 std::vector<std::vector <float> > data;
174 data.resize(nRows);
175 for (int iRow = 0; iRow < nRows; iRow += 1) {
176 data[iRow].resize(nFeature);
177 for (int iFeature = 0; iFeature < nFeature; iFeature += 1) {
178 data[iRow][iFeature] = test_data[nFeature * iRow + iFeature];
179 }
180 }
181
182 MVA::MultiDataset dataSet(m_generalOptions, data, spectators);
183 return m_expert->apply(dataSet);
184}

Member Data Documentation

◆ m_allNamedVariables

std::vector<Named<Float_t*> > m_allNamedVariables
private

References to all the named values from the source variable set.

Definition at line 42 of file MVAExpert.cc.

◆ m_dataset

std::unique_ptr<MVA::Dataset> m_dataset
private

Pointer to the current dataset.

Definition at line 54 of file MVAExpert.cc.

◆ m_expert

std::unique_ptr<MVA::Expert> m_expert
private

Pointer to the current MVA Expert.

Definition at line 51 of file MVAExpert.cc.

◆ m_generalOptions

MVA::GeneralOptions m_generalOptions
private

General options.

Definition at line 57 of file MVAExpert.cc.

◆ m_identifier

std::string m_identifier
private

DB identifier of the expert or file name.

Definition at line 60 of file MVAExpert.cc.

◆ m_selectedNamedVariables

std::vector<Named<Float_t*> > m_selectedNamedVariables
private

References to the selected named values from the source variable set.

Definition at line 45 of file MVAExpert.cc.

◆ m_weightfileRepresentation

std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile> > m_weightfileRepresentation
private

Database pointer to the Database representation of the weightfile.

Definition at line 48 of file MVAExpert.cc.


The documentation for this class was generated from the following file: