Belle II Software  release-06-00-14
MVAExpertModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 
10 #include <mva/modules/MVAExpert/MVAExpertModule.h>
11 
12 #include <analysis/dataobjects/Particle.h>
13 #include <analysis/dataobjects/ParticleList.h>
14 #include <analysis/dataobjects/ParticleExtraInfoMap.h>
15 #include <analysis/dataobjects/EventExtraInfo.h>
16 
17 #include <mva/interface/Interface.h>
18 
19 #include <boost/algorithm/string/predicate.hpp>
20 #include <memory>
21 
22 #include <framework/logging/Logger.h>
23 
24 
25 namespace Belle2 {
// Register this module with the framework under the name "MVAExpert"
// (REG_MODULE takes the class name without its 'Module' suffix).
31  REG_MODULE(MVAExpert)
32 
34  {
// Constructor body (its signature line is not part of this listing):
// sets the module description, sets the c_ParallelProcessingCertified
// property flag and declares the steering parameters.
// NOTE(review): the description literal below contains the typo "calcuated".
35  setDescription("Adds an ExtraInfo to the Particle objects in given ParticleLists which is calcuated by an expert defined by a weightfile.");
36  setPropertyFlags(c_ParallelProcessingCertified);
37 
// listNames defaults to an empty list. An empty list makes event() evaluate
// the expert once per event (with a nullptr Particle) instead of once per particle.
38  std::vector<std::string> empty;
39  addParam("listNames", m_listNames,
40  "Particles from these ParticleLists are used as input. If no name is given the expert is applied to every event once, and one can only use variables which accept nullptr as Particle*",
41  empty);
// extraInfoName and identifier are registered without a default value
// (presumably this makes them mandatory -- confirm with Module::addParam).
42  addParam("extraInfoName", m_extraInfoName,
43  "Name under which the output of the expert is stored in the ExtraInfo of the Particle object.");
// NOTE(review): the identifier is used to load the weights for *applying* the
// expert (see initialize()/beginRun()); "during the training" in the help text is misleading.
44  addParam("identifier", m_identifier, "The database identifier which is used to load the weights during the training.");
// Default -1.0: per the help text, keep the signal fraction stored with the training data.
45  addParam("signalFraction", m_signal_fraction_override,
46  "signalFraction to calculate probability (if -1 the signalFraction of the training data is used)", -1.0);
47  }
48 
50  {
// Module initialization (presumably MVAExpertModule::initialize(); its
// signature line, and source lines 58, 61 and 69, are missing from this listing).
51  // All specified ParticleLists are required to exist
52  for (auto& name : m_listNames) {
53  StoreObjPtr<ParticleList> list(name);
54  list.isRequired();
55  }
56 
// Register the store object that receives the expert output. Each branch
// declares its own `extraInfo` on a line missing from this listing --
// presumably StoreObjPtr<EventExtraInfo> for per-event mode (no lists) and
// StoreObjPtr<ParticleExtraInfoMap> for per-particle mode; confirm.
57  if (m_listNames.empty()) {
59  extraInfo.registerInDataStore();
60  } else {
62  extraInfo.registerInDataStore();
63  }
64 
// Identifiers ending in ".root" or ".xml" are treated as local weightfiles.
// Anything else is looked up in the conditions database through a DBObjPtr.
// For local files m_weightfile_representation stays null, which beginRun()
// presumably checks to choose between the two loading paths.
65  if (not(boost::ends_with(m_identifier, ".root") or boost::ends_with(m_identifier, ".xml"))) {
66  m_weightfile_representation = std::make_unique<DBObjPtr<DatabaseRepresentationOfWeightfile>>(
67  MVA::makeSaveForDatabase(m_identifier));
68  }
70 
71  }
72 
74  {
// Per-run (re)loading of the expert (presumably MVAExpertModule::beginRun()).
// NOTE(review): the signature line and source lines 76 and 83 are missing from
// this listing, so the braces shown here do not balance. Line 76 presumably
// tests m_weightfile_representation (set in initialize() only for database identifiers).
75 
// Database path: rebuild the expert only when the payload has changed. The
// payload's m_data is deserialized into a Weightfile via a string stream.
77  if (m_weightfile_representation->hasChanged()) {
78  std::stringstream ss((*m_weightfile_representation)->m_data);
79  auto weightfile = MVA::Weightfile::loadFromStream(ss);
80  init_mva(weightfile);
81  }
82  } else {
// Local-file path: `weightfile` is declared on missing line 83 (presumably
// MVA::Weightfile::loadFromFile(m_identifier) -- confirm). Unlike the database
// path, this branch re-initializes the expert unconditionally at every run.
84  init_mva(weightfile);
85  }
86 
87  }
88 
90  {
// Builds the expert, the feature-variable list and the single-row dataset from
// `weightfile` (presumably MVAExpertModule::init_mva(MVA::Weightfile&); the
// signature line and source lines 97-98 and 103 are missing from this listing).
91 
// getSupportedInterfaces() returns a std::map<std::string, AbstractInterface*>
// by value, so this is a local copy.
92  auto supported_interfaces = MVA::AbstractInterface::getSupportedInterfaces();
93  MVA::GeneralOptions general_options;
94  weightfile.getOptions(general_options);
95 
96  // Overwrite signal fraction from training
// The statements applying the override (source lines 97-98) are missing from
// this listing. They presumably call weightfile.addSignalFraction(m_signal_fraction_override)
// when an override is requested -- confirm against the full source.
99 
// NOTE(review): operator[] default-inserts a null pointer when m_method is not a
// supported interface, and ->getExpert() then dereferences it. Consider
// find() plus a B2FATAL naming the unknown method.
100  m_expert = supported_interfaces[general_options.m_method]->getExpert();
101  m_expert->load(weightfile);
102 
// `manager` is declared on missing line 103 (presumably
// Variable::Manager::Instance() -- confirm). Any mismatch between requested and
// resolved feature variables is fatal.
104  m_feature_variables = manager.getVariables(general_options.m_variables);
105  if (m_feature_variables.size() != general_options.m_variables.size()) {
106  B2FATAL("One or more feature variables could not be loaded via the Variable::Manager. Check the names!");
107  }
108 
// Zero-filled placeholder input with one entry per feature. analyse() overwrites
// m_dataset->m_input element by element before every evaluation.
109  std::vector<float> dummy;
110  dummy.resize(m_feature_variables.size(), 0);
111  m_dataset = std::make_unique<MVA::SingleDataset>(general_options, dummy, 0);
112 
113  }
114 
116  {
// Evaluates the expert for one Particle (presumably
// float MVAExpertModule::analyse(Particle* particle); signature line missing).
// `particle` may be nullptr: event() passes nullptr in per-event mode.
// Logs an error and returns 0 if no expert has been loaded yet.
117  if (not m_expert) {
118  B2ERROR("MVA Expert is not loaded! I will return 0");
119  return 0.0;
120  }
// Fill the single-row dataset in the order of m_feature_variables, then return
// the first element of the expert's output.
121  for (unsigned int i = 0; i < m_feature_variables.size(); ++i) {
122  m_dataset->m_input[i] = m_feature_variables[i]->function(particle);
123  }
124  return m_expert->apply(*m_dataset)[0];
125  }
126 
127 
129  {
// Per-event processing (presumably MVAExpertModule::event(); signature line missing).
// Per-particle mode: every Particle in every configured list gets the expert
// output stored as extra info m_extraInfoName.
130  for (auto& listName : m_listNames) {
131  StoreObjPtr<ParticleList> list(listName);
132  // Calculate target Value for Particles
133  for (unsigned i = 0; i < list->getListSize(); ++i) {
134  Particle* particle = list->getParticle(i);
135  float targetValue = analyse(particle);
// An existing entry with an identical value is left alone silently.
// A differing value is overwritten with a warning.
136  if (particle->hasExtraInfo(m_extraInfoName)) {
137  if (particle->getExtraInfo(m_extraInfoName) != targetValue) {
138  B2WARNING("Extra Info with given name is already set! Overwriting old value!");
139  particle->setExtraInfo(m_extraInfoName, targetValue);
140  }
141  } else {
142  particle->addExtraInfo(m_extraInfoName, targetValue);
143  }
144  }
145  }
// Per-event mode (no lists configured): evaluate once with a nullptr Particle and
// store the result in EventExtraInfo, creating that object if it does not exist yet.
// Unlike per-particle mode, an existing entry is kept and the expert is not evaluated.
146  if (m_listNames.empty()) {
147  StoreObjPtr<EventExtraInfo> eventExtraInfo;
148  if (not eventExtraInfo.isValid())
149  eventExtraInfo.create();
150  if (eventExtraInfo->hasExtraInfo(m_extraInfoName)) {
151  B2WARNING("Extra Info with given name is already set! I won't set it again!");
152  } else {
153  float targetValue = analyse(nullptr);
154  eventExtraInfo->addExtraInfo(m_extraInfoName, targetValue);
155  }
156  }
157  }
158 
160 } // Belle2 namespace
@ c_Event
Different object in each event; all objects/arrays are invalidated after the event() function has been called.
Definition: DataStore.h:59
This module adds an ExtraInfo to the Particle objects in a given ParticleList.
std::unique_ptr< MVA::SingleDataset > m_dataset
Pointer to the current dataset.
std::vector< const Variable::Manager::Var * > m_feature_variables
Pointers to the feature variables.
std::unique_ptr< MVA::Expert > m_expert
Pointer to the current MVA Expert.
double m_signal_fraction_override
Signal Fraction which should be used.
std::unique_ptr< DBObjPtr< DatabaseRepresentationOfWeightfile > > m_weightfile_representation
Database pointer to the Database representation of the weightfile.
std::vector< std::string > m_listNames
input particle list names
std::string m_extraInfoName
Name under which the SignalProbability is stored in the extraInfo of the Particle object.
std::string m_identifier
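Weightfile identifier: a conditions-database identifier, or the path of a local weightfile ending in .root or .xml.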
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:250
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:66
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:205
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:94
Base class for Modules.
Definition: Module.h:72
Class to store reconstructed particles.
Definition: Particle.h:74
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
Type-safe access to single objects in the data store.
Definition: StoreObjPtr.h:95
Global list of available variables.
Definition: Manager.h:98
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
virtual void initialize() override
Initialize the module.
virtual void event() override
Called for each event.
virtual void beginRun() override
Called at the beginning of a new run.
float analyse(Particle *)
Calculates expert output for given Particle pointer.
void init_mva(MVA::Weightfile &weightfile)
Initialize the MVA expert, dataset and features. Called every time the weightfile in the database changes ...
Abstract base class for different kinds of events.