Belle II Software  release-06-01-15
MVAExpert.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 #include <tracking/trackFindingCDC/mva/MVAExpert.h>
9 
11 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
12 #include <mva/interface/Weightfile.h>
13 #include <mva/interface/Expert.h>
14 #include <framework/database/DBObjPtr.h>
15 
16 #include <boost/algorithm/string/predicate.hpp>
17 
18 namespace Belle2 {
24  namespace MVA {
25  class Expert;
26  class SingleDataset;
27  class Weightfile;
28  }
29 
30  namespace TrackFindingCDC {
33 
34  public:
35  Impl(const std::string& identifier, std::vector<Named<Float_t*>> namedVariables);
36  void initialize();
37  void beginRun();
38  std::unique_ptr<MVA::Weightfile> getWeightFile();
39  double predict();
41  private:
43  std::vector<Named<Float_t*> > m_allNamedVariables;
44 
46  std::vector<Named<Float_t*> > m_selectedNamedVariables;
47 
49  std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile> > m_weightfileRepresentation;
50 
52  std::unique_ptr<MVA::Expert> m_expert;
53 
55  std::unique_ptr<MVA::Dataset> m_dataset;
56 
58  std::string m_identifier;
59  };
60  }
62 }
63 
65 #include <mva/interface/Interface.h>
66 
67 #include <framework/utilities/FileSystem.h>
68 #include <framework/logging/Logger.h>
69 
70 #include <algorithm>
71 
72 using namespace Belle2;
73 using namespace TrackFindingCDC;
74 
75 MVAExpert::Impl::Impl(const std::string& identifier,
76  std::vector<Named<Float_t*> > namedVariables)
77  : m_allNamedVariables(std::move(namedVariables))
78  , m_identifier(identifier)
79 {
80 }
81 
83 {
85  using boost::algorithm::ends_with;
86  if (not m_weightfileRepresentation and
87  not(ends_with(m_identifier, ".root") or ends_with(m_identifier, ".xml"))) {
88  using DBWeightFileRepresentation = DBObjPtr<DatabaseRepresentationOfWeightfile>;
89  m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
90  }
91 }
92 
94 {
95  std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
96  if (weightfile) {
97  if (weightfile->getElement<std::string>("method") == "FastBDT" and
98  (weightfile->getElement<int>("FastBDT_version") == 1 or
99  weightfile->getElement<int>("FastBDT_version") == 2)) {
100 
101  int nExpectedVars = weightfile->getElement<int>("number_feature_variables");
102 
103  m_selectedNamedVariables.clear();
104  for (int iVar = 0; iVar < nExpectedVars; ++iVar) {
105  std::string variableElementName = "variable" + std::to_string(iVar);
106  std::string expectedName = weightfile->getElement<std::string>(variableElementName);
107 
108  auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
109  m_allNamedVariables.end(),
110  [expectedName](const Named<Float_t*>& namedVariable) {
111  return namedVariable.getName() == expectedName;
112  });
113 
114  if (itNamedVariable == m_allNamedVariables.end()) {
115  B2ERROR("Variable name " << iVar << " mismatch for FastBDT. " <<
116  "Could not find expected variable '" << expectedName << "'");
117  }
118  m_selectedNamedVariables.push_back(*itNamedVariable);
119  }
120  B2ASSERT("Number of variables mismatch", nExpectedVars == static_cast<int>(m_selectedNamedVariables.size()));
121  } else {
122  B2WARNING("Unpacked new kind of classifier. Consider to extend the feature variable check. Identifier name: " << m_identifier
123  << "; method name: " << weightfile->getElement<std::string>("method"));
124  m_selectedNamedVariables = m_allNamedVariables;
125  }
126 
127  std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
129  MVA::GeneralOptions generalOptions;
130  weightfile->getOptions(generalOptions);
131  m_expert = supportedInterfaces[generalOptions.m_method]->getExpert();
132  m_expert->load(*weightfile);
133 
134  std::vector<float> dummy;
135  dummy.resize(m_selectedNamedVariables.size(), 0);
136  m_dataset = std::make_unique<MVA::SingleDataset>(generalOptions, std::move(dummy), 0);
137  } else {
138  B2ERROR("Could not find weight file for identifier " << m_identifier);
139  }
140 }
141 
142 std::unique_ptr<MVA::Weightfile> MVAExpert::Impl::getWeightFile()
143 {
144  if (m_weightfileRepresentation) {
145  std::stringstream ss((*m_weightfileRepresentation)->m_data);
146  return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromStream(ss));
147  } else {
148  std::string weightFilePath = FileSystem::findFile(m_identifier);
149  return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromFile(weightFilePath));
150  }
151 }
152 
154 {
155  if (not m_expert) {
156  B2ERROR("MVA Expert is not loaded! I will return 0");
157  return NAN;
158  }
159 
160  // Transfer the extracted values to the data set were the expert can find them
161  for (unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
162  m_dataset->m_input[i] = *m_selectedNamedVariables[i];
163  }
164  return m_expert->apply(*m_dataset)[0];
165 }
166 
168 MVAExpert::MVAExpert(const std::string& identifier,
169  std::vector<Named<Float_t*> > namedVariables)
170  : m_impl(std::make_unique<MVAExpert::Impl>(identifier, std::move(namedVariables)))
171 {
172 }
173 
174 MVAExpert::~MVAExpert() = default;
175 
177 {
178  return m_impl->initialize();
179 }
180 
182 {
183  return m_impl->beginRun();
184 }
185 
187 {
188  return m_impl->predict();
189 }
Class for accessing objects in the database.
Definition: DBObjPtr.h:21
Database representation of a Weightfile object.
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:145
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::string m_method
Name of the MVA method to use.
Definition: Options.h:82
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism (used by Weightfile) to load Options from a xml tree.
Definition: Options.cc:42
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:205
Implementation of the class to interact with the MVA package.
Definition: MVAExpert.cc:32
void initialize()
Signal the beginning of the event processing.
Definition: MVAExpert.cc:82
void beginRun()
Called once before a new run begins.
Definition: MVAExpert.cc:93
Impl(const std::string &identifier, std::vector< Named< Float_t * >> namedVariables)
constructor
Definition: MVAExpert.cc:75
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
Definition: MVAExpert.cc:49
std::unique_ptr< MVA::Weightfile > getWeightFile()
Get the weight file.
Definition: MVAExpert.cc:142
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
Definition: MVAExpert.cc:52
std::vector< Named< Float_t * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
Definition: MVAExpert.cc:46
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
Definition: MVAExpert.cc:55
std::vector< Named< Float_t * > > m_allNamedVariables
References to all named values from the source variable set.
Definition: MVAExpert.cc:43
double predict()
Get the MVA prediction.
Definition: MVAExpert.cc:153
std::string m_identifier
DB identifier of the expert or file name.
Definition: MVAExpert.cc:58
Class to interact with the MVA package.
Definition: MVAExpert.h:26
void initialize()
Initialise the mva method.
Definition: MVAExpert.cc:176
MVAExpert(const std::string &identifier, std::vector< Named< Float_t * >> namedVariables)
Construct the Expert with the specified weight folder and the name of the training that was used in t...
void beginRun()
Update the mva method to the new run.
Definition: MVAExpert.cc:181
std::unique_ptr< Impl > m_impl
Pointer to implementation hiding the details.
Definition: MVAExpert.h:51
~MVAExpert()
Destructor must be defined in cpp because of PImpl pointer.
double predict()
Evaluate the MVA method and return the MVAOutput.
Definition: MVAExpert.cc:186
Filter based on a mva method.
Definition: MVAFilter.dcl.h:33
Abstract base class for different kinds of events.