8 #include <tracking/trackFindingCDC/mva/MVAExpert.h>
11 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
12 #include <mva/interface/Weightfile.h>
13 #include <mva/interface/Expert.h>
14 #include <framework/database/DBObjPtr.h>
16 #include <boost/algorithm/string/predicate.hpp>
30 namespace TrackFindingCDC {
65 #include <mva/interface/Interface.h>
67 #include <framework/utilities/FileSystem.h>
68 #include <framework/logging/Logger.h>
73 using namespace TrackFindingCDC;
77 : m_allNamedVariables(std::move(namedVariables))
78 , m_identifier(identifier)
85 using boost::algorithm::ends_with;
86 if (not m_weightfileRepresentation and
87 not(ends_with(m_identifier,
".root") or ends_with(m_identifier,
".xml"))) {
89 m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
95 std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
97 if (weightfile->getElement<std::string>(
"method") ==
"FastBDT" and
98 (weightfile->getElement<
int>(
"FastBDT_version") == 1 or
99 weightfile->getElement<
int>(
"FastBDT_version") == 2)) {
101 int nExpectedVars = weightfile->getElement<
int>(
"number_feature_variables");
103 m_selectedNamedVariables.clear();
104 for (
int iVar = 0; iVar < nExpectedVars; ++iVar) {
105 std::string variableElementName =
"variable" + std::to_string(iVar);
106 std::string expectedName = weightfile->getElement<std::string>(variableElementName);
108 auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
109 m_allNamedVariables.end(),
111 return namedVariable.getName() == expectedName;
114 if (itNamedVariable == m_allNamedVariables.end()) {
115 B2ERROR(
"Variable name " << iVar <<
" mismatch for FastBDT. " <<
116 "Could not find expected variable '" << expectedName <<
"'");
118 m_selectedNamedVariables.push_back(*itNamedVariable);
120 B2ASSERT(
"Number of variables mismatch", nExpectedVars ==
static_cast<int>(m_selectedNamedVariables.size()));
122 B2WARNING(
"Unpacked new kind of classifier. Consider to extend the feature variable check. Identifier name: " << m_identifier
123 <<
"; method name: " << weightfile->getElement<std::string>(
"method"));
124 m_selectedNamedVariables = m_allNamedVariables;
127 std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
130 weightfile->getOptions(generalOptions);
131 m_expert = supportedInterfaces[generalOptions.
m_method]->getExpert();
132 m_expert->
load(*weightfile);
134 std::vector<float> dummy;
135 dummy.resize(m_selectedNamedVariables.size(), 0);
136 m_dataset = std::make_unique<MVA::SingleDataset>(generalOptions, std::move(dummy), 0);
138 B2ERROR(
"Could not find weight file for identifier " << m_identifier);
144 if (m_weightfileRepresentation) {
145 std::stringstream ss((*m_weightfileRepresentation)->m_data);
156 B2ERROR(
"MVA Expert is not loaded! I will return 0");
161 for (
unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
162 m_dataset->m_input[i] = *m_selectedNamedVariables[i];
164 return m_expert->apply(*m_dataset)[0];
178 return m_impl->initialize();
183 return m_impl->beginRun();
Class for accessing objects in the database.
Database representation of a Weightfile object.
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
General options which are shared by all MVA trainings.
std::string m_method
Name of the MVA method to use.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism (used by Weightfile) to load Options from a xml tree.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Implementation of the class to interact with the MVA package.
void initialize()
Signal the beginning of the event processing.
void beginRun()
Called once before a new run begins.
Impl(const std::string &identifier, std::vector< Named< Float_t * >> namedVariables)
constructor
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
std::unique_ptr< MVA::Weightfile > getWeightFile()
Get the weight file.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
std::vector< Named< Float_t * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
std::vector< Named< Float_t * > > m_allNamedVariables
References to all named values from the source variable set.
double predict()
Get the MVA prediction.
std::string m_identifier
DB identifier of the expert or file name.
Class to interact with the MVA package.
void initialize()
Initialise the mva method.
MVAExpert(const std::string &identifier, std::vector< Named< Float_t * >> namedVariables)
Construct the Expert with the specified weight folder and the name of the training that was used in t...
void beginRun()
Update the mva method to the new run.
std::unique_ptr< Impl > m_impl
Pointer to implementation hiding the details.
~MVAExpert()
Destructor must be defined in cpp because of PImpl pointer.
double predict()
Evaluate the MVA method and return the MVAOutput.
Filter based on a mva method.
Abstract base class for different kinds of events.