Belle II Software light-2406-ragdoll
test_TMVA.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/TMVA.h>
10#include <mva/interface/Interface.h>
11#include <framework/utilities/FileSystem.h>
12#include <framework/utilities/TestHelpers.h>
13
14#include <gtest/gtest.h>
15
16using namespace Belle2;
17
18namespace {
19
20 TEST(TMVATest, TMVAOptions)
21 {
22 MVA::TMVAOptions specific_options;
23
24 //EXPECT_EQ(specific_options.method, "FastBDT");
25 //EXPECT_EQ(specific_options.type, "Plugins");
26 //EXPECT_EQ(specific_options.config, "!H:!V:CreateMVAPdfs:NTrees=400:Shrinkage=0.10:RandRatio=0.5:NCutLevel=8:NTreeLayers=3");
27 EXPECT_EQ(specific_options.m_method, "BDT");
28 EXPECT_EQ(specific_options.m_type, "BDT");
29 EXPECT_EQ(specific_options.m_config,
30 "!H:!V:CreateMVAPdfs:NTrees=400:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=1024:MaxDepth=3:IgnoreNegWeightsInTraining");
31 EXPECT_EQ(specific_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar");
32 EXPECT_EQ(specific_options.m_prepareOption, "SplitMode=random:!V");
33 EXPECT_EQ(specific_options.m_workingDirectory, "");
34 EXPECT_EQ(specific_options.m_prefix, "TMVA");
35
36 specific_options.m_method = "Method";
37 specific_options.m_type = "Type";
38 specific_options.m_config = "Config";
39 specific_options.m_factoryOption = "FactoryOption";
40 specific_options.m_prepareOption = "PrepareOption";
41 specific_options.m_workingDirectory = "WorkingDirectory";
42 specific_options.m_prefix = "Prefix";
43
44 boost::property_tree::ptree pt;
45 specific_options.save(pt);
46 EXPECT_EQ(pt.get<std::string>("TMVA_method"), "Method");
47 EXPECT_EQ(pt.get<std::string>("TMVA_type"), "Type");
48 EXPECT_EQ(pt.get<std::string>("TMVA_config"), "Config");
49 EXPECT_EQ(pt.get<std::string>("TMVA_factoryOption"), "FactoryOption");
50 EXPECT_EQ(pt.get<std::string>("TMVA_prepareOption"), "PrepareOption");
51 EXPECT_EQ(pt.get<std::string>("TMVA_workingDirectory"), "WorkingDirectory");
52 EXPECT_EQ(pt.get<std::string>("TMVA_prefix"), "Prefix");
53
54 MVA::TMVAOptions specific_options2;
55 specific_options2.load(pt);
56
57 EXPECT_EQ(specific_options2.m_method, "Method");
58 EXPECT_EQ(specific_options2.m_type, "Type");
59 EXPECT_EQ(specific_options2.m_config, "Config");
60 EXPECT_EQ(specific_options2.m_factoryOption, "FactoryOption");
61 EXPECT_EQ(specific_options2.m_prepareOption, "PrepareOption");
62 EXPECT_EQ(specific_options2.m_workingDirectory, "WorkingDirectory");
63 EXPECT_EQ(specific_options2.m_prefix, "Prefix");
64
65 MVA::TMVAOptionsClassification specific_classification_options;
66 EXPECT_EQ(specific_classification_options.transform2probability, true);
67 EXPECT_EQ(specific_classification_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar:AnalysisType=Classification");
68
69 specific_classification_options.transform2probability = false;
70 boost::property_tree::ptree pt_classification;
71 specific_classification_options.save(pt_classification);
72 EXPECT_EQ(pt_classification.get<bool>("TMVA_transform2probability"), false);
73
74 MVA::TMVAOptionsClassification specific_classification_options2;
75 specific_classification_options2.load(pt_classification);
76 EXPECT_EQ(specific_classification_options.transform2probability, false);
77
78 MVA::TMVAOptionsRegression specific_regression_options;
79 EXPECT_EQ(specific_regression_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar:AnalysisType=Regression");
80
81 EXPECT_EQ(specific_classification_options.getMethod(), std::string("TMVAClassification"));
82 EXPECT_EQ(specific_regression_options.getMethod(), std::string("TMVARegression"));
83
84 // Test if po::options_description is created without crashing
85 auto description = specific_options.getDescription();
86 EXPECT_EQ(description.options().size(), 6);
87
88 auto description_reg = specific_regression_options.getDescription();
89 EXPECT_EQ(description_reg.options().size(), 6);
90
91 auto description_cls = specific_classification_options.getDescription();
92 EXPECT_EQ(description_cls.options().size(), 7);
93
94 // Check for B2ERROR and throw if version is wrong
95 // we try with version 100, surely we will never reach this!
96 pt.put("TMVA_version", 100);
97 try {
98 EXPECT_B2ERROR(specific_options2.load(pt));
99 } catch (...) {
100
101 }
102 EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
103
104 }
105
106 class TestClassificationDataset : public MVA::Dataset {
107 public:
108 explicit TestClassificationDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
109 {
110 m_input = {0.0};
111 m_target = 0.0;
112 m_isSignal = false;
113 m_weight = 1.0;
114 }
115
116 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
117 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
118 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
119 void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
120 float getSignalFraction() override { return 0.1; };
121 std::vector<float> getFeature(unsigned int) override { return m_data; }
122
123 std::vector<float> m_data;
124
125 };
126
127
128 TEST(TMVATest, TMVAClassificationInterface)
129 {
131 interface;
132
133 MVA::GeneralOptions general_options;
134 general_options.m_variables = {"A"};
135 general_options.m_target_variable = "Target";
136 MVA::TMVAOptionsClassification specific_options;
137 specific_options.m_prepareOption = "SplitMode=block:!V";
138 specific_options.transform2probability = false;
139 specific_options.m_config =
140 "!H:!V:CreateMVAPdfs:NTrees=400:BoostType=Grad:Shrinkage=0.1:nCuts=10:MaxDepth=3:IgnoreNegWeightsInTraining:MinNodeSize=20";
141 TestClassificationDataset dataset({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
142 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
143 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
144 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
145 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
146 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
147 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
148 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0
149 });
150
151 auto teacher = interface.getTeacher(general_options, specific_options);
152 auto weightfile = teacher->train(dataset);
153
154 auto expert = interface.getExpert();
155 expert->load(weightfile);
156 auto probabilities = expert->apply(dataset);
157 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
158 for (unsigned int i = 0; i < 24; ++i) {
159 EXPECT_LE(probabilities[i], 0.1);
160 EXPECT_GE(probabilities[i], -0.1);
161 }
162 for (unsigned int i = 24; i < 48; i += 2) {
163 EXPECT_LE(probabilities[i], -0.8);
164 EXPECT_GE(probabilities[i + 1], 0.8);
165 }
166 for (unsigned int i = 48; i < 72; ++i) {
167 EXPECT_LE(probabilities[i], 0.1);
168 EXPECT_GE(probabilities[i], -0.1);
169 }
170 for (unsigned int i = 72; i < 96; i += 2) {
171 EXPECT_LE(probabilities[i], -0.8);
172 EXPECT_GE(probabilities[i + 1], 0.8);
173 }
174
175
176 }
177
178
179 class TestRegressionDataset : public MVA::Dataset {
180 public:
181 explicit TestRegressionDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
182 {
183 m_input = {0.0};
184 m_target = 0.0;
185 m_isSignal = false;
186 m_weight = 1.0;
187 }
188
189 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
190 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
191 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
192 void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = static_cast<float>((static_cast<int>(iEvent % 48) - 24) / 4) / 24.0;};
193 float getSignalFraction() override { return 0.0; };
194 std::vector<float> getFeature(unsigned int) override { return m_data; }
195
196 std::vector<float> m_data;
197
198 };
199
200 TEST(TMVATest, TMVARegressionInterface)
201 {
203
204 MVA::GeneralOptions general_options;
205 general_options.m_variables = {"A"};
206 general_options.m_target_variable = "Target";
207 MVA::TMVAOptionsRegression specific_options;
208 specific_options.m_prepareOption = "SplitMode=block:!V";
209 specific_options.m_config = "!H:!V:NTrees=200::BoostType=Grad:Shrinkage=0.1:nCuts=24:MaxDepth=3";
210 //specific_options.config = "nCuts=120:NTrees=20:MaxDepth=4:BoostType=AdaBoostR2:SeparationType=RegressionVariance:MinNodeSize=10";
211 TestRegressionDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0,
212 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0,
213 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0,
214 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0,
215 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0,
216 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0,
217 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0,
218 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0
219 });
220
221 auto teacher = interface.getTeacher(general_options, specific_options);
222 auto weightfile = teacher->train(dataset);
223
224 auto expert = interface.getExpert();
225 expert->load(weightfile);
226 auto values = expert->apply(dataset);
227 EXPECT_EQ(values.size(), dataset.getNumberOfEvents());
228 for (unsigned int i = 0; i < 96; i += 4) {
229 float r = static_cast<float>((static_cast<int>(i % 48) - 24) / 4) / 24.0;
230 EXPECT_LE(values[i], r + 0.05);
231 EXPECT_GE(values[i], r - 0.05);
232 EXPECT_LE(values[i + 1], r + 0.05);
233 EXPECT_GE(values[i + 1], r - 0.05);
234 EXPECT_LE(values[i + 2], r + 0.05);
235 EXPECT_GE(values[i + 2], r - 0.05);
236 EXPECT_LE(values[i + 3], r + 0.05);
237 EXPECT_GE(values[i + 3], r - 0.05);
238 }
239
240 }
241
242 TEST(TMVATest, WeightfilesAreReadCorrectly)
243 {
245 interface;
246
247 MVA::GeneralOptions general_options;
248 general_options.m_variables = {"M", "p", "pt"};
249 MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
250 {1.873689, 1.881940, 1.843310},
251 {1.863657, 1.774831, 1.753773},
252 {1.858293, 1.605311, 0.631336},
253 {1.837129, 1.575739, 1.490166},
254 {1.811395, 1.524029, 0.565220}
255 },
256 {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
257
258 auto expert = interface.getExpert();
259
260 auto weightfile = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/TMVA.xml"));
261 expert->load(weightfile);
262 auto probabilities = expert->apply(dataset);
263 EXPECT_NEAR(probabilities[0], 0.098980136215686798, 0.0001);
264 EXPECT_NEAR(probabilities[1], 0.35516414046287537, 0.0001);
265 EXPECT_NEAR(probabilities[2], 0.066082566976547241, 0.0001);
266 EXPECT_NEAR(probabilities[3], 0.18826344609260559, 0.0001);
267 EXPECT_NEAR(probabilities[4], 0.10691597312688828, 0.0001);
268 EXPECT_NEAR(probabilities[5], 1.4245844629813542e-13, 0.0001);
269 }
270
271}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:151
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:125
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:116
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
Options for the TMVA Classification MVA method.
Definition: TMVA.h:80
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:112
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:81
bool transform2probability
Transform output of method to a probability.
Definition: TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: TMVA.cc:75
Options for the TMVA Regression MVA method.
Definition: TMVA.h:166
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:181
Options for the TMVA MVA method.
Definition: TMVA.h:34
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition: TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:66
std::string m_method
TMVA method name
Definition: TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:55
std::string m_factoryOption
Factory options passed to the TMVA factory.
Definition: TMVA.h:71
std::string m_type
TMVA method type
Definition: TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition: TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: TMVA.cc:27
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24