8#include <tracking/trackFindingCDC/mva/MVAExpert.h>
11#include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
12#include <mva/interface/Expert.h>
13#include <mva/interface/Weightfile.h>
14#include <framework/database/DBObjPtr.h>
15#include <boost/algorithm/string/predicate.hpp>
29 namespace TrackFindingCDC {
39 std::vector<float>
predict(
float* ,
int ,
int );
68#include <mva/interface/Interface.h>
70#include <framework/utilities/FileSystem.h>
71#include <framework/logging/Logger.h>
76using namespace TrackFindingCDC;
80 : m_allNamedVariables(
std::move(namedVariables))
81 , m_identifier(identifier)
88 using boost::algorithm::ends_with;
89 if (not m_weightfileRepresentation and
90 not(ends_with(m_identifier,
".root") or ends_with(m_identifier,
".xml"))) {
92 m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
98 std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
100 if ((weightfile->getElement<std::string>(
"method") ==
"FastBDT" and
101 (weightfile->getElement<
int>(
"FastBDT_version") == 1 or
102 weightfile->getElement<
int>(
"FastBDT_version") == 2)) or
103 (weightfile->getElement<std::string>(
"method") ==
"Python")) {
105 int nExpectedVars = weightfile->getElement<
int>(
"number_feature_variables");
107 m_selectedNamedVariables.clear();
108 for (
int iVar = 0; iVar < nExpectedVars; ++iVar) {
109 std::string variableElementName =
"variable" + std::to_string(iVar);
110 std::string expectedName = weightfile->getElement<std::string>(variableElementName);
111 auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
112 m_allNamedVariables.end(),
114 return namedVariable.getName() == expectedName;
117 if (itNamedVariable == m_allNamedVariables.end()) {
118 B2ERROR(
"Variable name " << iVar <<
" mismatch for FastBDT. " <<
119 "Could not find expected variable '" << expectedName <<
"'");
121 m_selectedNamedVariables.push_back(*itNamedVariable);
123 B2ASSERT(
"Number of variables mismatch", nExpectedVars ==
static_cast<int>(m_selectedNamedVariables.size()));
125 B2WARNING(
"Unpacked new kind of classifier. Consider to extend the feature variable check. Identifier name: " << m_identifier
126 <<
"; method name: " << weightfile->getElement<std::string>(
"method"));
127 m_selectedNamedVariables = m_allNamedVariables;
130 std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
132 weightfile->getOptions(m_generalOptions);
133 m_expert = supportedInterfaces[m_generalOptions.m_method]->getExpert();
134 m_expert->load(*weightfile);
136 std::vector<float> dummy;
137 dummy.resize(m_selectedNamedVariables.size(), 0);
138 m_dataset = std::make_unique<MVA::SingleDataset>(m_generalOptions, std::move(dummy), 0);
140 B2ERROR(
"Could not find weight file for identifier " << m_identifier);
146 if (m_weightfileRepresentation) {
147 std::stringstream ss((*m_weightfileRepresentation)->m_data);
158 B2ERROR(
"MVA Expert is not loaded! I will return 0");
163 for (
unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
164 m_dataset->m_input[i] = *m_selectedNamedVariables[i];
166 return m_expert->apply(*m_dataset)[0];
171 std::vector<std::vector<float>> spectators;
172 std::vector<std::vector <float> > data;
174 for (
int iRow = 0; iRow < nRows; iRow += 1) {
175 data[iRow].resize(nFeature);
176 for (
int iFeature = 0; iFeature < nFeature; iFeature += 1) {
177 data[iRow][iFeature] = test_data[nFeature * iRow + iFeature];
182 return m_expert->apply(dataSet);
187 std::vector<std::string> out(m_selectedNamedVariables.size());
188 for (
size_t iName = 0; iName < m_selectedNamedVariables.size(); iName += 1) {
189 out[iName] = m_selectedNamedVariables[iName].getName();
210 return m_impl->initialize();
215 return m_impl->beginRun();
225 return m_impl->predict(test_data, nFeature, nRows);
230 return m_impl->getVariableNames();
Class for accessing objects in the database.
Database representation of a Weightfile object.
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedI...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
General options which are shared by all MVA trainings.
Wraps the data of multiple events into a Dataset.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Implementation of the class to interact with the MVA package.
void initialize()
Signal the beginning of the event processing.
void beginRun()
Called once before a new run begins.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
std::unique_ptr< MVA::Weightfile > getWeightFile()
Get the weight file.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
std::vector< Named< Float_t * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
Impl(const std::string &identifier, std::vector< Named< Float_t * > > namedVariables)
constructor
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
std::vector< Named< Float_t * > > m_allNamedVariables
References to all named values from the source variable set.
MVA::GeneralOptions m_generalOptions
General options.
double predict()
Get the MVA prediction.
std::vector< std::string > getVariableNames()
Get the names of the selected variables.
std::string m_identifier
DB identifier of the expert or file name.
Class to interact with the MVA package.
void initialize()
Initialise the mva method.
void beginRun()
Update the mva method to the new run.
std::unique_ptr< Impl > m_impl
Pointer to implementation hiding the details.
~MVAExpert()
Destructor must be defined in cpp because of PImpl pointer.
MVAExpert(const std::string &identifier, std::vector< Named< Float_t * > > namedVariables)
Construct the Expert with the specified weight folder and the name of the training that was used in t...
double predict()
Evaluate the MVA method and return the MVAOutput.
std::vector< std::string > getVariableNames()
Get selected variable names.
Filter based on a mva method.
A mixin class to attach a name to an object.
Abstract base class for different kinds of events.