10 #include <tracking/trackFindingCDC/mva/MVAExpert.h>
13 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
14 #include <mva/interface/Weightfile.h>
15 #include <mva/interface/Expert.h>
16 #include <framework/database/DBObjPtr.h>
18 #include <boost/algorithm/string/predicate.hpp>
32 namespace TrackFindingCDC {
// Nested implementation class of MVAExpert; only part of its declaration is visible in this view.
34 class MVAExpert::Impl {
// Owning handle to the MVA expert. Elsewhere in this file it is assigned from
// getExpert(), loaded from the weight file and evaluated through apply().
54 std::unique_ptr<MVA::Expert>
m_expert;
67 #include <mva/interface/Interface.h>
69 #include <framework/utilities/FileSystem.h>
70 #include <framework/logging/Logger.h>
75 using namespace TrackFindingCDC;
// Member initializers: the passed variable list is moved into m_allNamedVariables,
// the identifier is stored as given.
// NOTE(review): the constructor signature is not visible here — confirm that
// namedVariables is taken by value so the move is meaningful.
79 : m_allNamedVariables(std::move(namedVariables))
80 , m_identifier(identifier)
// Bring boost's suffix test into scope for the check below.
87 using boost::algorithm::ends_with;
// If no weight-file representation has been set yet and the identifier does not
// end in ".root" or ".xml", create a DBWeightFileRepresentation from the identifier.
// NOTE(review): presumably such identifiers are database keys rather than file
// paths — confirm against the handling of ".root"/".xml" identifiers elsewhere.
88 if (not m_weightfileRepresentation and
89 not(ends_with(m_identifier,
".root") or ends_with(m_identifier,
".xml"))) {
91 m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
// Obtain the weight file; the statements below read its elements and load the expert from it.
// NOTE(review): the error at the end of this fragment ("Could not find weight file ...")
// suggests a null check on weightfile whose lines are not visible here — confirm.
97 std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
// The feature-variable name check is only carried out for the "FastBDT" method
// with FastBDT_version 1 or 2.
99 if (weightfile->
getElement<std::string>(
"method") ==
"FastBDT" and
100 (weightfile->
getElement<
int>(
"FastBDT_version") == 1 or
101 weightfile->
getElement<
int>(
"FastBDT_version") == 2)) {
// Number of feature variables recorded in the weight file.
103 int nExpectedVars = weightfile->
getElement<
int>(
"number_feature_variables");
// Rebuild the selection from scratch, in the order stored in the weight file.
105 m_selectedNamedVariables.clear();
106 for (
int iVar = 0; iVar < nExpectedVars; ++iVar) {
// The weight file stores the variable names under the keys "variable0", "variable1", ...
107 std::string variableElementName =
"variable" + std::to_string(iVar);
108 std::string expectedName = weightfile->
getElement<std::string>(variableElementName);
// Search all available named variables for one whose name equals the expected name.
110 auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
111 m_allNamedVariables.end(),
113 return namedVariable.getName() == expectedName;
// NOTE(review): if the name is not found, an error is logged, but the push_back
// below still dereferences itNamedVariable. Unless B2ERROR aborts or elided lines
// leave the loop, that is a dereference of the end iterator — verify.
116 if (itNamedVariable == m_allNamedVariables.end()) {
117 B2ERROR(
"Variable name " << iVar <<
" mismatch for FastBDT. " <<
118 "Could not find expected variable '" << expectedName <<
"'");
120 m_selectedNamedVariables.push_back(*itNamedVariable);
// Sanity check: exactly one selected variable per expected feature variable.
122 B2ASSERT(
"Number of variables mismatch", nExpectedVars ==
static_cast<int>(m_selectedNamedVariables.size()));
// Any other method/version: no name check is performed; warn and use all variables as they are.
124 B2WARNING(
"Unpacked new kind of classifier. Consider to extend the feature variable check. Identifier name: " << m_identifier
125 <<
"; method name: " << weightfile->
getElement<std::string>(
"method"));
126 m_selectedNamedVariables = m_allNamedVariables;
// Map from method name to the interface that can create an expert for it
// (the initializer of this declaration is not visible in this view).
129 std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
// NOTE(review): std::map::operator[] default-inserts a null pointer when the method
// name is not present, which would then be dereferenced by ->getExpert() — confirm
// that every method name occurring in weight files is registered.
133 m_expert = supportedInterfaces[generalOptions.
m_method]->getExpert();
134 m_expert->load(*weightfile);
// Create the dataset with one zero-initialised float per selected variable.
136 std::vector<float> dummy;
137 dummy.resize(m_selectedNamedVariables.size(), 0);
138 m_dataset = std::make_unique<MVA::SingleDataset>(generalOptions, std::move(dummy), 0);
// Reported when no weight file could be obtained for the identifier.
140 B2ERROR(
"Could not find weight file for identifier " << m_identifier);
// When a weight-file representation is set, wrap its m_data payload in a string stream.
// NOTE(review): how the stream is turned into an MVA::Weightfile is not visible here.
146 if (m_weightfileRepresentation) {
147 std::stringstream ss((*m_weightfileRepresentation)->m_data);
// Error path for a missing expert; the guarding condition and the return statement
// announced by the message are not visible in this view.
158 B2ERROR(
"MVA Expert is not loaded! I will return 0");
// Write the value obtained by dereferencing each selected variable into the
// matching input slot of the dataset.
163 for (
unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
164 m_dataset->m_input[i] = *m_selectedNamedVariables[i];
// Apply the expert to the dataset and return the first element of its output.
166 return m_expert->apply(*m_dataset)[0];
// Forward the call to the implementation object.
180 return m_impl->initialize();
// Forward the call to the implementation object.
185 return m_impl->beginRun();