Belle II Software development
Trivial.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/Trivial.h>
10
11#include <framework/logging/Logger.h>
12
13namespace Belle2 {
18 namespace MVA {
19
20 void TrivialOptions::load(const boost::property_tree::ptree& pt)
21 {
22 int version = pt.get<int>("Trivial_version");
23 if (version != 1) {
24 B2ERROR("Unknown weightfile version " << std::to_string(version));
25 throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
26 }
27 m_output = pt.get<double>("Trivial_output");
28
29 auto numberOfOutputs = pt.get<unsigned int>("Trivial_number_of_multiple_outputs", 0u);
30 m_multiple_output.resize(numberOfOutputs);
31 for (unsigned int i = 0; i < numberOfOutputs; ++i) {
32 m_multiple_output[i] = pt.get<double>(std::string("Trivial_multiple_output") + std::to_string(i));
33 }
34
35 m_passthrough = pt.get<bool>("Trivial_passthrough", false);
36 }
37
38 void TrivialOptions::save(boost::property_tree::ptree& pt) const
39 {
40 pt.put("Trivial_version", 1);
41 pt.put("Trivial_output", m_output);
42 pt.put("Trivial_number_of_multiple_outputs", m_multiple_output.size());
43 for (unsigned int i = 0; i < m_multiple_output.size(); ++i) {
44 pt.put(std::string("Trivial_multiple_output") + std::to_string(i), m_multiple_output[i]);
45 }
46 pt.put("Trivial_passthrough", m_passthrough);
47 }
48
49 po::options_description TrivialOptions::getDescription()
50 {
51 po::options_description description("Trivial options");
52 description.add_options()
53 ("output", po::value<double>(&m_output),
54 "Outputs this value for all predictions in binary classification (unless passthrough is enabled).");
55 description.add_options()
56 ("multiple_output", po::value<std::vector<double>>(&m_multiple_output)->multitoken(),
57 "Outputs these values for their respective classes in multiclass classification (unless passthrough is enabled).");
58 description.add_options()
59 ("passthrough", po::value<bool>(&m_passthrough),
60 "If enabled, the method returns the value of the input variable. For binary classification this option requires the presence of only one input variable. For multiclass classification we require either one input variable which is returned for all classes, or an input variable per class.");
61 return description;
62 }
63
65 const TrivialOptions& specific_options) : Teacher(general_options),
66 m_specific_options(specific_options) { }
67
69 {
70 Weightfile weightfile;
71 weightfile.addOptions(m_general_options);
73 weightfile.addSignalFraction(training_data.getSignalFraction());
74 return weightfile;
75 }
76
78 {
79 weightfile.getOptions(m_general_options);
81 }
82
83 std::vector<float> TrivialExpert::apply(Dataset& test_data) const
84 {
85 if (m_specific_options.m_passthrough) {
86 if (m_general_options.m_variables.size() != 1) {
87 B2ERROR("Trivial method in passthrough mode requires exactly 1 input variables. Found " << m_general_options.m_variables.size());
88 }
89 }
90 std::vector<float> probabilities(test_data.getNumberOfEvents());
91 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
92 test_data.loadEvent(iEvent);
93 if (m_specific_options.m_passthrough) {
94 probabilities[iEvent] = test_data.m_input[0];
95 } else {
96 probabilities[iEvent] = m_specific_options.m_output;
97 }
98 }
99 return probabilities;
100 }
101
102 std::vector<std::vector<float>> TrivialExpert::applyMulticlass(Dataset& test_data) const
103 {
104 if ((m_general_options.m_nClasses != m_specific_options.m_multiple_output.size()) and (not m_specific_options.m_passthrough)) {
105 B2ERROR("The number of classes declared in the general options do not match the number of outputs declared in the specific options for the Trivial expert");
106 }
107
108 if (m_specific_options.m_passthrough) {
109 if ((m_general_options.m_variables.size() != 1) and (m_general_options.m_variables.size() != m_general_options.m_nClasses)) {
110 B2ERROR("Trivial method in passthrough mode requires either exactly one input variable or one per class, matching the number of classes declared in the general options. Found "
111 << m_general_options.m_variables.size());
112 }
113 }
114
115 std::vector<std::vector<float>> probabilities(test_data.getNumberOfEvents(), std::vector<float>(m_general_options.m_nClasses));
116 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
117 test_data.loadEvent(iEvent);
118 for (unsigned int iClass = 0; iClass < m_general_options.m_nClasses; ++iClass) {
119 if (m_specific_options.m_passthrough) {
120 if (m_general_options.m_variables.size() == 1) {
121 probabilities[iEvent][iClass] = test_data.m_input[0];
122 } else {
123 probabilities[iEvent][iClass] = test_data.m_input[iClass];
124 }
125 } else {
126 probabilities[iEvent][iClass] = m_specific_options.m_multiple_output.at(iClass);
127 }
128 }
129 }
130 return probabilities;
131 }
132 }
134}
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition Dataset.h:33
GeneralOptions m_general_options
General options loaded from the weightfile.
Definition Expert.h:70
General options which are shared by all MVA trainings.
Definition Options.h:62
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition Teacher.h:49
Teacher(const GeneralOptions &general_options)
Constructs a new teacher using the GeneralOptions for this training.
Definition Teacher.cc:18
TrivialOptions m_specific_options
Method specific options.
Definition Trivial.h:108
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition Trivial.cc:83
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition Trivial.cc:77
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition Trivial.cc:102
Options for the Trivial MVA method.
Definition Trivial.h:28
double m_output
Output of the trivial method.
Definition Trivial.h:53
std::vector< double > m_multiple_output
Output of the trivial method.
Definition Trivial.h:54
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition Trivial.cc:49
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition Trivial.cc:20
bool m_passthrough
Flag for passthrough setting.
Definition Trivial.h:55
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition Trivial.cc:38
TrivialTeacher(const GeneralOptions &general_options, const TrivialOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition Trivial.cc:64
TrivialOptions m_specific_options
Method specific options.
Definition Trivial.h:79
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition Trivial.cc:68
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition Weightfile.cc:67
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition Weightfile.cc:95
Abstract base class for different kinds of events.