Belle II Software light-2406-ragdoll
test_Options.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/interface/Options.h>
10
11#include <boost/property_tree/ptree.hpp>
12#include <gtest/gtest.h>
13
14using namespace Belle2;
15
16namespace {
17
18 TEST(OptionsTest, GeneralOptions)
19 {
20
21 MVA::GeneralOptions general_options;
22
23 EXPECT_EQ(general_options.m_method, "");
24 EXPECT_EQ(general_options.m_identifier, "");
25 EXPECT_EQ(general_options.m_datafiles.size(), 0);
26 EXPECT_EQ(general_options.m_treename, "ntuple");
27 EXPECT_EQ(general_options.m_variables.size(), 0);
28 EXPECT_EQ(general_options.m_spectators.size(), 0);
29 EXPECT_EQ(general_options.m_signal_class, 1);
30 EXPECT_EQ(general_options.m_nClasses, 2);
31 EXPECT_EQ(general_options.m_target_variable, "isSignal");
32 EXPECT_EQ(general_options.m_weight_variable, "__weight__");
33 EXPECT_EQ(general_options.m_max_events, 0u);
34
35 general_options.m_method = "Method";
36 general_options.m_identifier = "Weightfile";
37 general_options.m_datafiles = {"Datafile"};
38 general_options.m_treename = "Tree";
39 general_options.m_variables = {"v", "a", "r", "s"};
40 general_options.m_spectators = {"x", "M"};
41 general_options.m_signal_class = 2;
42 general_options.m_nClasses = 4;
43 general_options.m_max_events = 100;
44 general_options.m_target_variable = "Target";
45 general_options.m_weight_variable = "Weight";
46
47 boost::property_tree::ptree pt;
48 general_options.save(pt);
49 EXPECT_EQ(pt.get<std::string>("method"), "Method");
50 EXPECT_EQ(pt.get<std::string>("weightfile"), "Weightfile");
51 EXPECT_EQ(pt.get<unsigned int>("number_data_files"), 1);
52 EXPECT_EQ(pt.get<std::string>("datafile0"), "Datafile");
53 EXPECT_EQ(pt.get<std::string>("treename"), "Tree");
54 EXPECT_EQ(pt.get<std::string>("target_variable"), "Target");
55 EXPECT_EQ(pt.get<std::string>("weight_variable"), "Weight");
56 EXPECT_EQ(pt.get<int>("signal_class"), 2);
57 EXPECT_EQ(pt.get<unsigned int>("nClasses"), 4);
58 EXPECT_EQ(pt.get<unsigned int>("max_events"), 100u);
59 EXPECT_EQ(pt.get<unsigned int>("number_feature_variables"), 4);
60 EXPECT_EQ(pt.get<std::string>("variable0"), "v");
61 EXPECT_EQ(pt.get<std::string>("variable1"), "a");
62 EXPECT_EQ(pt.get<std::string>("variable2"), "r");
63 EXPECT_EQ(pt.get<std::string>("variable3"), "s");
64 EXPECT_EQ(pt.get<unsigned int>("number_spectator_variables"), 2);
65 EXPECT_EQ(pt.get<std::string>("spectator0"), "x");
66 EXPECT_EQ(pt.get<std::string>("spectator1"), "M");
67
68 MVA::GeneralOptions general_options2;
69 general_options2.load(pt);
70
71 EXPECT_EQ(general_options2.m_method, "Method");
72 EXPECT_EQ(general_options2.m_identifier, "Weightfile");
73 EXPECT_EQ(general_options2.m_datafiles.size(), 1);
74 EXPECT_EQ(general_options2.m_datafiles[0], "Datafile");
75 EXPECT_EQ(general_options2.m_treename, "Tree");
76 EXPECT_EQ(general_options2.m_variables.size(), 4);
77 EXPECT_EQ(general_options2.m_variables[0], "v");
78 EXPECT_EQ(general_options2.m_variables[1], "a");
79 EXPECT_EQ(general_options2.m_variables[2], "r");
80 EXPECT_EQ(general_options2.m_variables[3], "s");
81 EXPECT_EQ(general_options2.m_spectators.size(), 2);
82 EXPECT_EQ(general_options2.m_spectators[0], "x");
83 EXPECT_EQ(general_options2.m_spectators[1], "M");
84 EXPECT_EQ(general_options2.m_signal_class, 2);
85 EXPECT_EQ(general_options2.m_nClasses, 4);
86 EXPECT_EQ(general_options2.m_max_events, 100u);
87 EXPECT_EQ(general_options2.m_target_variable, "Target");
88 EXPECT_EQ(general_options2.m_weight_variable, "Weight");
89
90 // Test if po::options_description is created without crashing
91 auto description = general_options.getDescription();
92 EXPECT_EQ(description.options().size(), 12);
93 }
94
95 TEST(OptionsTest, MetaOptions)
96 {
97 MVA::MetaOptions meta_options;
98 EXPECT_EQ(meta_options.m_use_splot, false);
99 EXPECT_EQ(meta_options.m_splot_variable, "M");
100 EXPECT_EQ(meta_options.m_splot_mc_files.size(), 0);
101 EXPECT_EQ(meta_options.m_splot_combined, false);
102 EXPECT_EQ(meta_options.m_splot_boosted, false);
103 EXPECT_EQ(meta_options.m_use_sideband_subtraction, false);
104 EXPECT_EQ(meta_options.m_sideband_variable, "");
105 EXPECT_EQ(meta_options.m_sideband_mc_files.size(), 0u);
106 EXPECT_EQ(meta_options.m_use_reweighting, false);
107 EXPECT_EQ(meta_options.m_reweighting_identifier, "");
108 EXPECT_EQ(meta_options.m_reweighting_variable, "");
109 EXPECT_EQ(meta_options.m_reweighting_data_files.size(), 0u);
110 EXPECT_EQ(meta_options.m_reweighting_mc_files.size(), 0u);
111
112 meta_options.m_use_reweighting = true;
113 meta_options.m_reweighting_identifier = "test";
114 meta_options.m_reweighting_variable = "A";
115 meta_options.m_reweighting_mc_files = {"reweighting_mc.root"};
116 meta_options.m_reweighting_data_files = {"reweighting_data.root"};
117 meta_options.m_use_sideband_subtraction = true;
118 meta_options.m_sideband_variable = "B";
119 meta_options.m_sideband_mc_files = {"sideband_mc.root"};
120 meta_options.m_use_splot = true;
121 meta_options.m_splot_variable = "Q";
122 meta_options.m_splot_mc_files = {"splot_mc.root"};
123 meta_options.m_splot_combined = true;
124 meta_options.m_splot_boosted = true;
125
126 boost::property_tree::ptree pt;
127 meta_options.save(pt);
128 EXPECT_EQ(pt.get<bool>("use_splot"), true);
129 EXPECT_EQ(pt.get<bool>("splot_combined"), true);
130 EXPECT_EQ(pt.get<bool>("splot_boosted"), true);
131 EXPECT_EQ(pt.get<unsigned int>("splot_number_of_mc_files"), 1);
132 EXPECT_EQ(pt.get<std::string>("splot_mc_file0"), "splot_mc.root");
133 EXPECT_EQ(pt.get<std::string>("splot_variable"), "Q");
134 EXPECT_EQ(pt.get<bool>("use_sideband_subtraction"), true);
135 EXPECT_EQ(pt.get<std::string>("sideband_variable"), "B");
136 EXPECT_EQ(pt.get<bool>("use_reweighting"), true);
137 EXPECT_EQ(pt.get<std::string>("reweighting_identifier"), "test");
138 EXPECT_EQ(pt.get<std::string>("reweighting_variable"), "A");
139 EXPECT_EQ(pt.get<unsigned int>("reweighting_number_of_mc_files"), 1);
140 EXPECT_EQ(pt.get<std::string>("reweighting_mc_file0"), "reweighting_mc.root");
141 EXPECT_EQ(pt.get<unsigned int>("reweighting_number_of_data_files"), 1);
142 EXPECT_EQ(pt.get<std::string>("reweighting_data_file0"), "reweighting_data.root");
143 EXPECT_EQ(pt.get<unsigned int>("sideband_number_of_mc_files"), 1);
144 EXPECT_EQ(pt.get<std::string>("sideband_mc_file0"), "sideband_mc.root");
145
146 MVA::MetaOptions meta_options2;
147 meta_options2.load(pt);
148
149 EXPECT_EQ(meta_options2.m_use_splot, true);
150 EXPECT_EQ(meta_options2.m_splot_variable, "Q");
151 EXPECT_EQ(meta_options2.m_splot_mc_files.size(), 1);
152 EXPECT_EQ(meta_options2.m_splot_mc_files[0], "splot_mc.root");
153 EXPECT_EQ(meta_options2.m_splot_combined, true);
154 EXPECT_EQ(meta_options2.m_splot_boosted, true);
155 EXPECT_EQ(meta_options2.m_use_sideband_subtraction, true);
156 EXPECT_EQ(meta_options2.m_sideband_variable, "B");
157 EXPECT_EQ(meta_options2.m_sideband_mc_files.size(), 1);
158 EXPECT_EQ(meta_options2.m_sideband_mc_files[0], "sideband_mc.root");
159 EXPECT_EQ(meta_options2.m_use_reweighting, true);
160 EXPECT_EQ(meta_options2.m_reweighting_identifier, "test");
161 EXPECT_EQ(meta_options2.m_reweighting_variable, "A");
162 EXPECT_EQ(meta_options2.m_reweighting_mc_files.size(), 1);
163 EXPECT_EQ(meta_options2.m_reweighting_mc_files[0], "reweighting_mc.root");
164 EXPECT_EQ(meta_options2.m_reweighting_data_files.size(), 1);
165 EXPECT_EQ(meta_options2.m_reweighting_data_files[0], "reweighting_data.root");
166
167 // Test if po::options_description is created without crashing
168 auto description = meta_options.getDescription();
169 EXPECT_EQ(description.options().size(), 13);
170
171 }
172
173}
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
Definition: Options.h:84
int m_signal_class
Signal class which is used as signal in a classification problem.
Definition: Options.h:88
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:91
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_method
Name of the MVA method to use.
Definition: Options.h:82
unsigned int m_max_events
Maximum number of events to process, 0 means all.
Definition: Options.h:92
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism (used by Weightfile) to load Options from an XML tree.
Definition: Options.cc:44
std::string m_treename
Name of the TTree inside the datafile containing the training data.
Definition: Options.h:85
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:90
unsigned int m_nClasses
Number of classes in a classification problem.
Definition: Options.h:89
std::string m_identifier
Identifier containing the finished training.
Definition: Options.h:83
Meta Options which modify the underlying training by doing sPlot, Multiclass and HyperparameterSearch...
Definition: Options.h:111
std::string m_reweighting_variable
Variable defining for which events the reweighting should be used (1) or not used (0).
Definition: Options.h:144
bool m_use_reweighting
Use a pretraining of data against MC and weight the MC afterwards.
Definition: Options.h:142
bool m_use_splot
Use splot training.
Definition: Options.h:131
std::string m_splot_variable
Discriminating variable.
Definition: Options.h:132
std::vector< std::string > m_reweighting_mc_files
MC files for the pretraining.
Definition: Options.h:147
std::vector< std::string > m_reweighting_data_files
Data files for the pretraining.
Definition: Options.h:146
bool m_splot_combined
Combine sPlot training with PDF classifier for discriminating variable.
Definition: Options.h:134
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism (used by Weightfile) to load Options from an XML tree.
Definition: Options.cc:128
std::string m_reweighting_identifier
Identifier used to save the reweighting expert.
Definition: Options.h:143
bool m_splot_boosted
Use boosted sPlot training (aPlot).
Definition: Options.h:135
std::vector< std::string > m_splot_mc_files
Monte Carlo files used for the distribution of the discriminating variable.
Definition: Options.h:133
std::string m_sideband_variable
Variable defining the signal region (1), background region (2), negative signal region (3) or unused (o...
Definition: Options.h:139
std::vector< std::string > m_sideband_mc_files
Used to estimate the number of events in the different regions.
Definition: Options.h:138
bool m_use_sideband_subtraction
Use sideband subtraction.
Definition: Options.h:137
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24