 |
Belle II Software
release-05-01-25
|
10 #include <tracking/trackFindingVXD/mva/MVAExpert.h>
13 #include <mva/interface/Interface.h>
15 #include <framework/utilities/FileSystem.h>
16 #include <framework/logging/Logger.h>
18 #include <boost/algorithm/string/predicate.hpp>
24 : m_allNamedVariables(std::move(namedVariables))
25 , m_identifier(identifier)
32 using boost::algorithm::ends_with;
42 std::unique_ptr<MVA::Weightfile> weightfile =
getWeightFile();
45 if (weightfile->
getElement<std::string>(
"method") ==
"FastBDT" and
46 (weightfile->
getElement<
int>(
"FastBDT_version") == 1 or
47 weightfile->
getElement<
int>(
"FastBDT_version") == 2)) {
49 int nExpectedVars = weightfile->
getElement<
int>(
"number_feature_variables");
52 for (
int iVar = 0; iVar < nExpectedVars; ++iVar) {
53 std::string variableElementName =
"variable" + std::to_string(iVar);
54 std::string expectedName = weightfile->
getElement<std::string>(variableElementName);
59 return namedVariable.getName() == expectedName;
63 B2ERROR(
"Variable name " << iVar <<
" mismatch for FastBDT. " <<
64 "Could not find expected variable '" << expectedName <<
"'");
70 B2WARNING(
"Unpacked new kind of classifier. Consider to extend the feature variable check.");
74 std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
81 std::vector<float> dummy;
83 m_dataset = std::make_unique<MVA::SingleDataset>(generalOptions, std::move(dummy), 0);
85 B2ERROR(
"Could not find weight file for identifier " <<
m_identifier);
92 std::stringstream ss((*m_weightfileRepresentation)->m_data);
103 B2ERROR(
"MVA Expert is not loaded! I will return 0");
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
MVAExpert(const std::string &identifier, std::vector< Named< float * >> namedVariables)
Construct the Expert with the specified weight folder and the name of the training that was used in t...
std::unique_ptr< MVA::Weightfile > getWeightFile()
Resolves the source of the weight file and unpacks it.
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
std::vector< Named< float * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
void initialize()
Initialise the mva method.
void beginRun()
Update the mva method to the new run.
std::vector< Named< float * > > m_allNamedVariables
References to the named values from the source variable set.
Class for accessing objects in the database.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
std::string m_method
Name of the MVA method to use.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
Abstract base class for different kinds of events.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
std::string m_identifier
DB identifier of the expert or file name.
General options which are shared by all MVA trainings.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
float predict()
Evaluate the MVA method and return the MVAOutput.