Belle II Software  release-05-01-25
MVAExpert.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Oliver Frost *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 #include <tracking/trackFindingCDC/mva/MVAExpert.h>
11 
13 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
14 #include <mva/interface/Weightfile.h>
15 #include <mva/interface/Expert.h>
16 #include <framework/database/DBObjPtr.h>
17 
18 #include <boost/algorithm/string/predicate.hpp>
19 
20 namespace Belle2 {
26  namespace MVA {
27  class Expert;
28  class SingleDataset;
29  class Weightfile;
30  }
31 
32  namespace TrackFindingCDC {
34  class MVAExpert::Impl {
35 
36  public:
37  Impl(const std::string& identifier, std::vector<Named<Float_t*>> namedVariables);
38  void initialize();
39  void beginRun();
40  std::unique_ptr<MVA::Weightfile> getWeightFile();
41  double predict();
43  private:
45  std::vector<Named<Float_t*> > m_allNamedVariables;
46 
48  std::vector<Named<Float_t*> > m_selectedNamedVariables;
49 
51  std::unique_ptr<DBObjPtr<DatabaseRepresentationOfWeightfile> > m_weightfileRepresentation;
52 
54  std::unique_ptr<MVA::Expert> m_expert;
55 
57  std::unique_ptr<MVA::Dataset> m_dataset;
58 
60  std::string m_identifier;
61  };
62  }
64 }
65 
67 #include <mva/interface/Interface.h>
68 
69 #include <framework/utilities/FileSystem.h>
70 #include <framework/logging/Logger.h>
71 
72 #include <algorithm>
73 
74 using namespace Belle2;
75 using namespace TrackFindingCDC;
76 
77 MVAExpert::Impl::Impl(const std::string& identifier,
78  std::vector<Named<Float_t*> > namedVariables)
79  : m_allNamedVariables(std::move(namedVariables))
80  , m_identifier(identifier)
81 {
82 }
83 
85 {
87  using boost::algorithm::ends_with;
88  if (not m_weightfileRepresentation and
89  not(ends_with(m_identifier, ".root") or ends_with(m_identifier, ".xml"))) {
90  using DBWeightFileRepresentation = DBObjPtr<DatabaseRepresentationOfWeightfile>;
91  m_weightfileRepresentation = std::make_unique<DBWeightFileRepresentation>(m_identifier);
92  }
93 }
94 
96 {
97  std::unique_ptr<MVA::Weightfile> weightfile = getWeightFile();
98  if (weightfile) {
99  if (weightfile->getElement<std::string>("method") == "FastBDT" and
100  (weightfile->getElement<int>("FastBDT_version") == 1 or
101  weightfile->getElement<int>("FastBDT_version") == 2)) {
102 
103  int nExpectedVars = weightfile->getElement<int>("number_feature_variables");
104 
105  m_selectedNamedVariables.clear();
106  for (int iVar = 0; iVar < nExpectedVars; ++iVar) {
107  std::string variableElementName = "variable" + std::to_string(iVar);
108  std::string expectedName = weightfile->getElement<std::string>(variableElementName);
109 
110  auto itNamedVariable = std::find_if(m_allNamedVariables.begin(),
111  m_allNamedVariables.end(),
112  [expectedName](const Named<Float_t*>& namedVariable) {
113  return namedVariable.getName() == expectedName;
114  });
115 
116  if (itNamedVariable == m_allNamedVariables.end()) {
117  B2ERROR("Variable name " << iVar << " mismatch for FastBDT. " <<
118  "Could not find expected variable '" << expectedName << "'");
119  }
120  m_selectedNamedVariables.push_back(*itNamedVariable);
121  }
122  B2ASSERT("Number of variables mismatch", nExpectedVars == static_cast<int>(m_selectedNamedVariables.size()));
123  } else {
124  B2WARNING("Unpacked new kind of classifier. Consider to extend the feature variable check. Identifier name: " << m_identifier
125  << "; method name: " << weightfile->getElement<std::string>("method"));
126  m_selectedNamedVariables = m_allNamedVariables;
127  }
128 
129  std::map<std::string, MVA::AbstractInterface*> supportedInterfaces =
131  MVA::GeneralOptions generalOptions;
132  weightfile->getOptions(generalOptions);
133  m_expert = supportedInterfaces[generalOptions.m_method]->getExpert();
134  m_expert->load(*weightfile);
135 
136  std::vector<float> dummy;
137  dummy.resize(m_selectedNamedVariables.size(), 0);
138  m_dataset = std::make_unique<MVA::SingleDataset>(generalOptions, std::move(dummy), 0);
139  } else {
140  B2ERROR("Could not find weight file for identifier " << m_identifier);
141  }
142 }
143 
144 std::unique_ptr<MVA::Weightfile> MVAExpert::Impl::getWeightFile()
145 {
146  if (m_weightfileRepresentation) {
147  std::stringstream ss((*m_weightfileRepresentation)->m_data);
148  return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromStream(ss));
149  } else {
150  std::string weightFilePath = FileSystem::findFile(m_identifier);
151  return std::make_unique<MVA::Weightfile>(MVA::Weightfile::loadFromFile(weightFilePath));
152  }
153 }
154 
156 {
157  if (not m_expert) {
158  B2ERROR("MVA Expert is not loaded! I will return 0");
159  return NAN;
160  }
161 
162  // Transfer the extracted values to the data set were the expert can find them
163  for (unsigned int i = 0; i < m_selectedNamedVariables.size(); ++i) {
164  m_dataset->m_input[i] = *m_selectedNamedVariables[i];
165  }
166  return m_expert->apply(*m_dataset)[0];
167 }
168 
170 MVAExpert::MVAExpert(const std::string& identifier,
171  std::vector<Named<Float_t*> > namedVariables)
172  : m_impl(std::make_unique<MVAExpert::Impl>(identifier, std::move(namedVariables)))
173 {
174 }
175 
176 MVAExpert::~MVAExpert() = default;
177 
179 {
180  return m_impl->initialize();
181 }
182 
184 {
185  return m_impl->beginRun();
186 }
187 
189 {
190  return m_impl->predict();
191 }
Belle2::TrackFindingCDC::Named< Float_t * >
Belle2::MVA::AbstractInterface::getSupportedInterfaces
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:55
Belle2::MVA::Weightfile::getOptions
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:76
Belle2::TrackFindingCDC::MVAExpert::Impl::m_selectedNamedVariables
std::vector< Named< Float_t * > > m_selectedNamedVariables
References to the selected named values from the source variable set.
Definition: MVAExpert.cc:56
Belle2::TrackFindingCDC::MVAExpert::Impl::beginRun
void beginRun()
Called once before a new run begins.
Definition: MVAExpert.cc:95
Belle2::TrackFindingCDC::MVAExpert::Impl::predict
double predict()
Get the MVA prediction.
Definition: MVAExpert.cc:155
Belle2::TrackFindingCDC::MVAExpert::Impl::initialize
void initialize()
Signal the beginning of the event processing.
Definition: MVAExpert.cc:84
Belle2::TrackFindingCDC::MVAExpert::Impl
Implementation of the class to interact with the MVA package.
Definition: MVAExpert.cc:42
Belle2::TrackFindingCDC::MVAExpert::Impl::m_allNamedVariables
std::vector< Named< Float_t * > > m_allNamedVariables
References to all named values from the source variable set.
Definition: MVAExpert.cc:53
Belle2::TrackFindingCDC::MVAExpert::~MVAExpert
~MVAExpert()
Destructor must be defined in cpp because of PImpl pointer.
Belle2::TrackFindingCDC::MVAExpert::Impl::Impl
Impl(const std::string &identifier, std::vector< Named< Float_t * >> namedVariables)
constructor
Definition: MVAExpert.cc:77
Belle2::TrackFindingCDC::MVAExpert
Class to interact with the MVA package.
Definition: MVAExpert.h:36
Belle2::TrackFindingCDC::MVAExpert::Impl::m_dataset
std::unique_ptr< MVA::Dataset > m_dataset
Pointer to the current dataset.
Definition: MVAExpert.cc:65
Belle2::DBObjPtr
Class for accessing objects in the database.
Definition: DBObjPtr.h:31
Belle2::Named
A mixin class to attach a name to an object. Based on class with same name in CDC package.
Definition: Named.h:31
Belle2::MVA::Weightfile::getElement
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:153
Belle2::MVA::GeneralOptions::m_method
std::string m_method
Name of the MVA method to use.
Definition: Options.h:84
Belle2::TrackFindingCDC::MVAExpert::MVAExpert
MVAExpert(const std::string &identifier, std::vector< Named< Float_t * >> namedVariables)
Construct the Expert with the specified weight folder and the name of the training that was used in t...
Belle2::DatabaseRepresentationOfWeightfile
Database representation of a Weightfile object.
Definition: DatabaseRepresentationOfWeightfile.h:30
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::TrackFindingCDC::MVAExpert::Impl::m_expert
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
Definition: MVAExpert.cc:62
Belle2::MVA::Weightfile::loadFromFile
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:215
Belle2::TrackFindingCDC::MVAExpert::predict
double predict()
Evaluate the MVA method and return the MVAOutput.
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::MVA::AbstractInterface::initSupportedInterfaces
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:55
Belle2::TrackFindingCDC::MVAExpert::m_impl
std::unique_ptr< Impl > m_impl
Pointer to implementation hiding the details.
Definition: MVAExpert.h:61
Belle2::TrackFindingCDC::MVAExpert::Impl::getWeightFile
std::unique_ptr< MVA::Weightfile > getWeightFile()
Get the weight file.
Definition: MVAExpert.cc:144
Belle2::MVA::Weightfile::loadFromStream
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:260
Belle2::TrackFindingCDC::MVAExpert::beginRun
void beginRun()
Update the mva method to the new run.
Belle2::TrackFindingCDC::MVAExpert::initialize
void initialize()
Initialise the mva method.
Belle2::FileSystem::findFile
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:147
Belle2::TrackFindingCDC::MVAExpert::Impl::m_identifier
std::string m_identifier
DB identifier of the expert or file name.
Definition: MVAExpert.cc:68
Belle2::TrackFindingCDC::MVAExpert::Impl::m_weightfileRepresentation
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfileRepresentation
Database pointer to the Database representation of the weightfile.
Definition: MVAExpert.cc:59