Belle II Software development
test_TMVA.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/TMVA.h>
10#include <mva/interface/Interface.h>
11#include <framework/utilities/FileSystem.h>
12#include <framework/utilities/TestHelpers.h>
13
14#include <gtest/gtest.h>
15
16using namespace Belle2;
17
18namespace {
19
20 TEST(TMVATest, TMVAOptions)
21 {
22 MVA::TMVAOptions specific_options;
23
24 //EXPECT_EQ(specific_options.method, "FastBDT");
25 //EXPECT_EQ(specific_options.type, "Plugins");
26 //EXPECT_EQ(specific_options.config, "!H:!V:CreateMVAPdfs:NTrees=400:Shrinkage=0.10:RandRatio=0.5:NCutLevel=8:NTreeLayers=3");
27 EXPECT_EQ(specific_options.m_method, "BDT");
28 EXPECT_EQ(specific_options.m_type, "BDT");
29 EXPECT_EQ(specific_options.m_config,
30 "!H:!V:CreateMVAPdfs:NTrees=400:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=1024:MaxDepth=3:IgnoreNegWeightsInTraining");
31 EXPECT_EQ(specific_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar");
32 EXPECT_EQ(specific_options.m_prepareOption, "SplitMode=random:!V");
33 EXPECT_EQ(specific_options.m_workingDirectory, "");
34 EXPECT_EQ(specific_options.m_prefix, "TMVA");
35
36 specific_options.m_method = "Method";
37 specific_options.m_type = "Type";
38 specific_options.m_config = "Config";
39 specific_options.m_factoryOption = "FactoryOption";
40 specific_options.m_prepareOption = "PrepareOption";
41 specific_options.m_workingDirectory = "WorkingDirectory";
42 specific_options.m_prefix = "Prefix";
43
44 boost::property_tree::ptree pt;
45 specific_options.save(pt);
46 EXPECT_EQ(pt.get<std::string>("TMVA_method"), "Method");
47 EXPECT_EQ(pt.get<std::string>("TMVA_type"), "Type");
48 EXPECT_EQ(pt.get<std::string>("TMVA_config"), "Config");
49 EXPECT_EQ(pt.get<std::string>("TMVA_factoryOption"), "FactoryOption");
50 EXPECT_EQ(pt.get<std::string>("TMVA_prepareOption"), "PrepareOption");
51 EXPECT_EQ(pt.get<std::string>("TMVA_workingDirectory"), "WorkingDirectory");
52 EXPECT_EQ(pt.get<std::string>("TMVA_prefix"), "Prefix");
53
54 MVA::TMVAOptions specific_options2;
55 specific_options2.load(pt);
56
57 EXPECT_EQ(specific_options2.m_method, "Method");
58 EXPECT_EQ(specific_options2.m_type, "Type");
59 EXPECT_EQ(specific_options2.m_config, "Config");
60 EXPECT_EQ(specific_options2.m_factoryOption, "FactoryOption");
61 EXPECT_EQ(specific_options2.m_prepareOption, "PrepareOption");
62 EXPECT_EQ(specific_options2.m_workingDirectory, "WorkingDirectory");
63 EXPECT_EQ(specific_options2.m_prefix, "Prefix");
64
65 MVA::TMVAOptionsClassification specific_classification_options;
66 EXPECT_EQ(specific_classification_options.transform2probability, true);
67 EXPECT_EQ(specific_classification_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar:AnalysisType=Classification");
68
69 specific_classification_options.transform2probability = false;
70 boost::property_tree::ptree pt_classification;
71 specific_classification_options.save(pt_classification);
72 EXPECT_EQ(pt_classification.get<bool>("TMVA_transform2probability"), false);
73
74 MVA::TMVAOptionsClassification specific_classification_options2;
75 specific_classification_options2.load(pt_classification);
76 EXPECT_EQ(specific_classification_options.transform2probability, false);
77
78 MVA::TMVAOptionsRegression specific_regression_options;
79 EXPECT_EQ(specific_regression_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar:AnalysisType=Regression");
80
81 EXPECT_EQ(specific_classification_options.getMethod(), std::string("TMVAClassification"));
82 EXPECT_EQ(specific_regression_options.getMethod(), std::string("TMVARegression"));
83
84 // Test if po::options_description is created without crashing
85 auto description = specific_options.getDescription();
86 EXPECT_EQ(description.options().size(), 6);
87
88 auto description_reg = specific_regression_options.getDescription();
89 EXPECT_EQ(description_reg.options().size(), 6);
90
91 auto description_cls = specific_classification_options.getDescription();
92 EXPECT_EQ(description_cls.options().size(), 7);
93
94 // Check for B2ERROR and throw if version is wrong
95 // we try with version 100, surely we will never reach this!
96 pt.put("TMVA_version", 100);
97 try {
98 EXPECT_B2ERROR(specific_options2.load(pt));
99 } catch (...) {
100
101 }
102 EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
103
104 }
105
106 class TestClassificationDataset : public MVA::Dataset {
107 public:
108 explicit TestClassificationDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
109 {
110 m_input = {0.0};
111 m_target = 0.0;
112 m_isSignal = false;
113 m_weight = 1.0;
114 }
115
116 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
117 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
118 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
119 void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
120 float getSignalFraction() override { return 0.1; };
121 std::vector<float> getFeature(unsigned int) override { return m_data; }
122
123 std::vector<float> m_data;
124
125 };
126
127
128 TEST(TMVATest, TMVAClassificationInterface)
129 {
131 interface;
132
133 MVA::GeneralOptions general_options;
134 general_options.m_variables = {"A"};
135 general_options.m_target_variable = "Target";
136 MVA::TMVAOptionsClassification specific_options;
137 specific_options.m_prepareOption = "SplitMode=block:!V";
138 specific_options.transform2probability = false;
139 specific_options.m_config =
140 "!H:!V:NTrees=400:BoostType=Grad:Shrinkage=0.1:nCuts=10:MaxDepth=3:IgnoreNegWeightsInTraining:MinNodeSize=20";
141 TestClassificationDataset dataset({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
142 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
143 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
144 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
145 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
146 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
147 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
148 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0
149 });
150
151 auto teacher = interface.getTeacher(general_options, specific_options);
152 auto weightfile = teacher->train(dataset);
153
154 auto expert = interface.getExpert();
155 expert->load(weightfile);
156 auto probabilities = expert->apply(dataset);
157 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
158 for (unsigned int i = 0; i < 24; ++i) {
159 EXPECT_LE(probabilities[i], 0.1);
160 EXPECT_GE(probabilities[i], -0.1);
161 }
162 for (unsigned int i = 24; i < 48; i += 2) {
163 EXPECT_LE(probabilities[i], -0.8);
164 EXPECT_GE(probabilities[i + 1], 0.8);
165 }
166 for (unsigned int i = 48; i < 72; ++i) {
167 EXPECT_LE(probabilities[i], 0.1);
168 EXPECT_GE(probabilities[i], -0.1);
169 }
170 for (unsigned int i = 72; i < 96; i += 2) {
171 EXPECT_LE(probabilities[i], -0.8);
172 EXPECT_GE(probabilities[i + 1], 0.8);
173 }
174 }
175
176
177 class TestRegressionDataset : public MVA::Dataset {
178 public:
179 explicit TestRegressionDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
180 {
181 m_input = {0.0};
182 m_target = 0.0;
183 m_isSignal = false;
184 m_weight = 1.0;
185 }
186
187 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
188 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
189 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
190 void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = static_cast<float>((static_cast<int>(iEvent % 48) - 24) / 4) / 24.0;};
191 float getSignalFraction() override { return 0.0; };
192 std::vector<float> getFeature(unsigned int) override { return m_data; }
193
194 std::vector<float> m_data;
195
196 };
197
198 TEST(TMVATest, TMVARegressionInterface)
199 {
201
202 MVA::GeneralOptions general_options;
203 general_options.m_variables = {"A"};
204 general_options.m_target_variable = "Target";
205 MVA::TMVAOptionsRegression specific_options;
206 specific_options.m_prepareOption = "SplitMode=block:!V";
207 specific_options.m_config = "!H:!V:NTrees=200::BoostType=Grad:Shrinkage=0.1:nCuts=24:MaxDepth=3";
208 //specific_options.config = "nCuts=120:NTrees=20:MaxDepth=4:BoostType=AdaBoostR2:SeparationType=RegressionVariance:MinNodeSize=10";
209 TestRegressionDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0,
210 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0,
211 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0,
212 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0,
213 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0,
214 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0,
215 7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0,
216 10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0
217 });
218
219 auto teacher = interface.getTeacher(general_options, specific_options);
220 auto weightfile = teacher->train(dataset);
221
222 auto expert = interface.getExpert();
223 expert->load(weightfile);
224 auto values = expert->apply(dataset);
225 EXPECT_EQ(values.size(), dataset.getNumberOfEvents());
226 for (unsigned int i = 0; i < 96; i += 4) {
227 float r = static_cast<float>((static_cast<int>(i % 48) - 24) / 4) / 24.0;
228 EXPECT_LE(values[i], r + 0.05);
229 EXPECT_GE(values[i], r - 0.05);
230 EXPECT_LE(values[i + 1], r + 0.05);
231 EXPECT_GE(values[i + 1], r - 0.05);
232 EXPECT_LE(values[i + 2], r + 0.05);
233 EXPECT_GE(values[i + 2], r - 0.05);
234 EXPECT_LE(values[i + 3], r + 0.05);
235 EXPECT_GE(values[i + 3], r - 0.05);
236 }
237
238 }
239
240 TEST(TMVATest, WeightfilesAreReadCorrectly)
241 {
243 interface;
244
245 MVA::GeneralOptions general_options;
246 general_options.m_variables = {"M", "p", "pt"};
247 MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
248 {1.873689, 1.881940, 1.843310},
249 {1.863657, 1.774831, 1.753773},
250 {1.858293, 1.605311, 0.631336},
251 {1.837129, 1.575739, 1.490166},
252 {1.811395, 1.524029, 0.565220}
253 },
254 {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
255
256 auto expert = interface.getExpert();
257
258 auto weightfile = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/TMVA.xml"));
259 expert->load(weightfile);
260 auto probabilities = expert->apply(dataset);
261 EXPECT_NEAR(probabilities[0], 0.098980136215686798, 0.0001);
262 EXPECT_NEAR(probabilities[1], 0.35516414046287537, 0.0001);
263 EXPECT_NEAR(probabilities[2], 0.066082566976547241, 0.0001);
264 EXPECT_NEAR(probabilities[3], 0.18826344609260559, 0.0001);
265 EXPECT_NEAR(probabilities[4], 0.10691597312688828, 0.0001);
266 EXPECT_NEAR(probabilities[5], 1.4245844629813542e-13, 0.0001);
267 }
268
269}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition Dataset.h:33
General options which are shared by all MVA trainings.
Definition Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition Interface.h:125
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition Interface.h:116
Wraps the data of a multiple event into a Dataset.
Definition Dataset.h:186
Options for the TMVA Classification MVA method.
Definition TMVA.h:80
virtual std::string getMethod() const override
Return method name.
Definition TMVA.h:112
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition TMVA.cc:81
bool transform2probability
Transform output of method to a probability.
Definition TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition TMVA.cc:75
Options for the TMVA Regression MVA method.
Definition TMVA.h:166
virtual std::string getMethod() const override
Return method name.
Definition TMVA.h:181
Options for the TMVA MVA method.
Definition TMVA.h:34
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition TMVA.h:66
std::string m_method
tmva method name
Definition TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition TMVA.cc:55
std::string m_factoryOption
Factory options passed to tmva factory.
Definition TMVA.h:71
std::string m_type
tmva method type
Definition TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition TMVA.cc:27
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Abstract base class for different kinds of events.