Belle II Software light-2406-ragdoll
Trivial.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/Trivial.h>
10
11#include <framework/logging/Logger.h>
12
13namespace Belle2 {
18 namespace MVA {
19
20 void TrivialOptions::load(const boost::property_tree::ptree& pt)
21 {
22 int version = pt.get<int>("Trivial_version");
23 if (version != 1) {
24 B2ERROR("Unknown weightfile version " << std::to_string(version));
25 throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
26 }
27 m_output = pt.get<double>("Trivial_output");
28
29 auto numberOfOutputs = pt.get<unsigned int>("Trivial_number_of_multiple_outputs", 0u);
30 m_multiple_output.resize(numberOfOutputs);
31 for (unsigned int i = 0; i < numberOfOutputs; ++i) {
32 m_multiple_output[i] = pt.get<double>(std::string("Trivial_multiple_output") + std::to_string(i));
33 }
34
35 m_passthrough = pt.get<bool>("Trivial_passthrough", false);
36 }
37
38 void TrivialOptions::save(boost::property_tree::ptree& pt) const
39 {
40 pt.put("Trivial_version", 1);
41 pt.put("Trivial_output", m_output);
42 pt.put("Trivial_number_of_multiple_outputs", m_multiple_output.size());
43 for (unsigned int i = 0; i < m_multiple_output.size(); ++i) {
44 pt.put(std::string("Trivial_multiple_output") + std::to_string(i), m_multiple_output[i]);
45 }
46 pt.put("Trivial_passthrough", m_passthrough);
47 }
48
49 po::options_description TrivialOptions::getDescription()
50 {
51 po::options_description description("Trivial options");
52 description.add_options()
53 ("output", po::value<double>(&m_output),
54 "Outputs this value for all predictions in binary classification (unless passthrough is enabled).");
55 description.add_options()
56 ("multiple_output", po::value<std::vector<double>>(&m_multiple_output)->multitoken(),
57 "Outputs these values for their respective classes in multiclass classification (unless passthrough is enabled).");
58 description.add_options()
59 ("passthrough", po::value<bool>(&m_passthrough),
60 "If enabled, the method returns the value of the input variable. For binary classification this option requires the presence of only one input variable. For multiclass classification we require either one input variable which is returned for all classes, or an input variable per class.");
61 return description;
62 }
63
65 const TrivialOptions& specific_options) : Teacher(general_options),
66 m_specific_options(specific_options) { }
67
69 {
70 Weightfile weightfile;
71 weightfile.addOptions(m_general_options);
73 weightfile.addSignalFraction(training_data.getSignalFraction());
74 return weightfile;
75 }
76
78 {
79 weightfile.getOptions(m_general_options);
81 }
82
83 std::vector<float> TrivialExpert::apply(Dataset& test_data) const
84 {
86 if (m_general_options.m_variables.size() != 1) {
87 B2ERROR("Trivial method in passthrough mode requires exactly 1 input variables. Found " << m_general_options.m_variables.size());
88 }
89 }
90 std::vector<float> probabilities(test_data.getNumberOfEvents());
91 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
92 test_data.loadEvent(iEvent);
94 probabilities[iEvent] = test_data.m_input[0];
95 } else {
96 probabilities[iEvent] = m_specific_options.m_output;
97 }
98 }
99 return probabilities;
100 }
101
102 std::vector<std::vector<float>> TrivialExpert::applyMulticlass(Dataset& test_data) const
103 {
105 B2ERROR("The number of classes declared in the general options do not match the number of outputs declared in the specific options for the Trivial expert");
106 }
107
110 B2ERROR("Trivial method in passthrough mode requires either exactly one input variable or one per class, matching the number of classes declared in the general options. Found "
112 }
113 }
114
115 std::vector<std::vector<float>> probabilities(test_data.getNumberOfEvents(), std::vector<float>(m_general_options.m_nClasses));
116 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
117 test_data.loadEvent(iEvent);
118 for (unsigned int iClass = 0; iClass < m_general_options.m_nClasses; ++iClass) {
120 if (m_general_options.m_variables.size() == 1) {
121 probabilities[iEvent][iClass] = test_data.m_input[0];
122 } else {
123 probabilities[iEvent][iClass] = test_data.m_input[iClass];
124 }
125 } else {
126 probabilities[iEvent][iClass] = m_specific_options.m_multiple_output.at(iClass);
127 }
128 }
129 }
130 return probabilities;
131 }
132 }
134}
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
GeneralOptions m_general_options
General options loaded from the weightfile.
Definition: Expert.h:70
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
unsigned int m_nClasses
Number of classes in a classification problem.
Definition: Options.h:89
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
TrivialOptions m_specific_options
Method specific options.
Definition: Trivial.h:108
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: Trivial.cc:83
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: Trivial.cc:77
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: Trivial.cc:102
Options for the Trivial MVA method.
Definition: Trivial.h:28
double m_output
Output of the trivial method.
Definition: Trivial.h:53
std::vector< double > m_multiple_output
Outputs of the trivial method for the respective classes in multiclass classification.
Definition: Trivial.h:54
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: Trivial.cc:49
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: Trivial.cc:20
bool m_passthrough
Flag for passthrough setting.
Definition: Trivial.h:55
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: Trivial.cc:38
TrivialTeacher(const GeneralOptions &general_options, const TrivialOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: Trivial.cc:64
TrivialOptions m_specific_options
Method specific options.
Definition: Trivial.h:79
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: Trivial.cc:68
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24