Belle II Software  release-08-01-10
test_TMVA.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/TMVA.h>
10 #include <mva/interface/Interface.h>
11 #include <framework/utilities/FileSystem.h>
12 #include <framework/utilities/TestHelpers.h>
13 
14 #include <gtest/gtest.h>
15 
16 using namespace Belle2;
17 
18 namespace {
19 
20  TEST(TMVATest, TMVAOptions)
21  {
22  MVA::TMVAOptions specific_options;
23 
24  //EXPECT_EQ(specific_options.method, "FastBDT");
25  //EXPECT_EQ(specific_options.type, "Plugins");
26  //EXPECT_EQ(specific_options.config, "!H:!V:CreateMVAPdfs:NTrees=400:Shrinkage=0.10:RandRatio=0.5:NCutLevel=8:NTreeLayers=3");
27  EXPECT_EQ(specific_options.m_method, "BDT");
28  EXPECT_EQ(specific_options.m_type, "BDT");
29  EXPECT_EQ(specific_options.m_config,
30  "!H:!V:CreateMVAPdfs:NTrees=400:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=1024:MaxDepth=3:IgnoreNegWeightsInTraining");
31  EXPECT_EQ(specific_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar");
32  EXPECT_EQ(specific_options.m_prepareOption, "SplitMode=random:!V");
33  EXPECT_EQ(specific_options.m_workingDirectory, "");
34  EXPECT_EQ(specific_options.m_prefix, "TMVA");
35 
36  specific_options.m_method = "Method";
37  specific_options.m_type = "Type";
38  specific_options.m_config = "Config";
39  specific_options.m_factoryOption = "FactoryOption";
40  specific_options.m_prepareOption = "PrepareOption";
41  specific_options.m_workingDirectory = "WorkingDirectory";
42  specific_options.m_prefix = "Prefix";
43 
44  boost::property_tree::ptree pt;
45  specific_options.save(pt);
46  EXPECT_EQ(pt.get<std::string>("TMVA_method"), "Method");
47  EXPECT_EQ(pt.get<std::string>("TMVA_type"), "Type");
48  EXPECT_EQ(pt.get<std::string>("TMVA_config"), "Config");
49  EXPECT_EQ(pt.get<std::string>("TMVA_factoryOption"), "FactoryOption");
50  EXPECT_EQ(pt.get<std::string>("TMVA_prepareOption"), "PrepareOption");
51  EXPECT_EQ(pt.get<std::string>("TMVA_workingDirectory"), "WorkingDirectory");
52  EXPECT_EQ(pt.get<std::string>("TMVA_prefix"), "Prefix");
53 
54  MVA::TMVAOptions specific_options2;
55  specific_options2.load(pt);
56 
57  EXPECT_EQ(specific_options2.m_method, "Method");
58  EXPECT_EQ(specific_options2.m_type, "Type");
59  EXPECT_EQ(specific_options2.m_config, "Config");
60  EXPECT_EQ(specific_options2.m_factoryOption, "FactoryOption");
61  EXPECT_EQ(specific_options2.m_prepareOption, "PrepareOption");
62  EXPECT_EQ(specific_options2.m_workingDirectory, "WorkingDirectory");
63  EXPECT_EQ(specific_options2.m_prefix, "Prefix");
64 
65  MVA::TMVAOptionsClassification specific_classification_options;
66  EXPECT_EQ(specific_classification_options.transform2probability, true);
67  EXPECT_EQ(specific_classification_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar:AnalysisType=Classification");
68 
69  specific_classification_options.transform2probability = false;
70  boost::property_tree::ptree pt_classification;
71  specific_classification_options.save(pt_classification);
72  EXPECT_EQ(pt_classification.get<bool>("TMVA_transform2probability"), false);
73 
74  MVA::TMVAOptionsClassification specific_classification_options2;
75  specific_classification_options2.load(pt_classification);
76  EXPECT_EQ(specific_classification_options.transform2probability, false);
77 
78  MVA::TMVAOptionsRegression specific_regression_options;
79  EXPECT_EQ(specific_regression_options.m_factoryOption, "!V:!Silent:Color:DrawProgressBar:AnalysisType=Regression");
80 
81  EXPECT_EQ(specific_classification_options.getMethod(), std::string("TMVAClassification"));
82  EXPECT_EQ(specific_regression_options.getMethod(), std::string("TMVARegression"));
83 
84  // Test if po::options_description is created without crashing
85  auto description = specific_options.getDescription();
86  EXPECT_EQ(description.options().size(), 6);
87 
88  auto description_reg = specific_regression_options.getDescription();
89  EXPECT_EQ(description_reg.options().size(), 6);
90 
91  auto description_cls = specific_classification_options.getDescription();
92  EXPECT_EQ(description_cls.options().size(), 7);
93 
94  // Check for B2ERROR and throw if version is wrong
95  // we try with version 100, surely we will never reach this!
96  pt.put("TMVA_version", 100);
97  try {
98  EXPECT_B2ERROR(specific_options2.load(pt));
99  } catch (...) {
100 
101  }
102  EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
103 
104  }
105 
106  class TestClassificationDataset : public MVA::Dataset {
107  public:
108  explicit TestClassificationDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
109  {
110  m_input = {0.0};
111  m_target = 0.0;
112  m_isSignal = false;
113  m_weight = 1.0;
114  }
115 
116  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
117  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
118  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
119  void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
120  float getSignalFraction() override { return 0.1; };
121  std::vector<float> getFeature(unsigned int) override { return m_data; }
122 
123  std::vector<float> m_data;
124 
125  };
126 
127 
128  TEST(TMVATest, TMVAClassificationInterface)
129  {
131  interface;
132 
133  MVA::GeneralOptions general_options;
134  general_options.m_variables = {"A"};
135  general_options.m_target_variable = "Target";
136  MVA::TMVAOptionsClassification specific_options;
137  specific_options.m_prepareOption = "SplitMode=block:!V";
138  specific_options.transform2probability = false;
139  specific_options.m_config =
140  "!H:!V:CreateMVAPdfs:NTrees=400:BoostType=Grad:Shrinkage=0.1:nCuts=10:MaxDepth=3:IgnoreNegWeightsInTraining:MinNodeSize=20";
141  TestClassificationDataset dataset({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
142  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
143  2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
144  2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
145  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
146  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
147  2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
148  2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0
149  });
150 
151  auto teacher = interface.getTeacher(general_options, specific_options);
152  auto weightfile = teacher->train(dataset);
153 
154  auto expert = interface.getExpert();
155  expert->load(weightfile);
156  auto probabilities = expert->apply(dataset);
157  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
158  for (unsigned int i = 0; i < 24; ++i) {
159  EXPECT_LE(probabilities[i], 0.1);
160  EXPECT_GE(probabilities[i], -0.1);
161  }
162  for (unsigned int i = 24; i < 48; i += 2) {
163  EXPECT_LE(probabilities[i], -0.8);
164  EXPECT_GE(probabilities[i + 1], 0.8);
165  }
166  for (unsigned int i = 48; i < 72; ++i) {
167  EXPECT_LE(probabilities[i], 0.1);
168  EXPECT_GE(probabilities[i], -0.1);
169  }
170  for (unsigned int i = 72; i < 96; i += 2) {
171  EXPECT_LE(probabilities[i], -0.8);
172  EXPECT_GE(probabilities[i + 1], 0.8);
173  }
174 
175 
176  }
177 
178 
179  class TestRegressionDataset : public MVA::Dataset {
180  public:
181  explicit TestRegressionDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
182  {
183  m_input = {0.0};
184  m_target = 0.0;
185  m_isSignal = false;
186  m_weight = 1.0;
187  }
188 
189  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
190  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
191  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
192  void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = static_cast<float>((static_cast<int>(iEvent % 48) - 24) / 4) / 24.0;};
193  float getSignalFraction() override { return 0.0; };
194  std::vector<float> getFeature(unsigned int) override { return m_data; }
195 
196  std::vector<float> m_data;
197 
198  };
199 
200  TEST(TMVATest, TMVARegressionInterface)
201  {
203 
204  MVA::GeneralOptions general_options;
205  general_options.m_variables = {"A"};
206  general_options.m_target_variable = "Target";
207  MVA::TMVAOptionsRegression specific_options;
208  specific_options.m_prepareOption = "SplitMode=block:!V";
209  specific_options.m_config = "!H:!V:NTrees=200::BoostType=Grad:Shrinkage=0.1:nCuts=24:MaxDepth=3";
210  //specific_options.config = "nCuts=120:NTrees=20:MaxDepth=4:BoostType=AdaBoostR2:SeparationType=RegressionVariance:MinNodeSize=10";
211  TestRegressionDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0,
212  4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0,
213  7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0,
214  10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0,
215  1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0,
216  4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0,
217  7.0, 7.0, 7.0, 7.0, 8.0, 8.0, 8.0, 8.0, 9.0, 9.0, 9.0, 9.0,
218  10.0, 10.0, 10.0, 10.0, 11.0, 11.0, 11.0, 11.0, 12.0, 12.0, 12.0, 12.0
219  });
220 
221  auto teacher = interface.getTeacher(general_options, specific_options);
222  auto weightfile = teacher->train(dataset);
223 
224  auto expert = interface.getExpert();
225  expert->load(weightfile);
226  auto values = expert->apply(dataset);
227  EXPECT_EQ(values.size(), dataset.getNumberOfEvents());
228  for (unsigned int i = 0; i < 96; i += 4) {
229  float r = static_cast<float>((static_cast<int>(i % 48) - 24) / 4) / 24.0;
230  EXPECT_LE(values[i], r + 0.05);
231  EXPECT_GE(values[i], r - 0.05);
232  EXPECT_LE(values[i + 1], r + 0.05);
233  EXPECT_GE(values[i + 1], r - 0.05);
234  EXPECT_LE(values[i + 2], r + 0.05);
235  EXPECT_GE(values[i + 2], r - 0.05);
236  EXPECT_LE(values[i + 3], r + 0.05);
237  EXPECT_GE(values[i + 3], r - 0.05);
238  }
239 
240  }
241 
242  TEST(TMVATest, WeightfilesAreReadCorrectly)
243  {
245  interface;
246 
247  MVA::GeneralOptions general_options;
248  general_options.m_variables = {"M", "p", "pt"};
249  MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
250  {1.873689, 1.881940, 1.843310},
251  {1.863657, 1.774831, 1.753773},
252  {1.858293, 1.605311, 0.631336},
253  {1.837129, 1.575739, 1.490166},
254  {1.811395, 1.524029, 0.565220}
255  },
256  {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
257 
258  auto expert = interface.getExpert();
259 
260  auto weightfile = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/TMVA.xml"));
261  expert->load(weightfile);
262  auto probabilities = expert->apply(dataset);
263  EXPECT_NEAR(probabilities[0], 0.098980136215686798, 0.0001);
264  EXPECT_NEAR(probabilities[1], 0.35516414046287537, 0.0001);
265  EXPECT_NEAR(probabilities[2], 0.066082566976547241, 0.0001);
266  EXPECT_NEAR(probabilities[3], 0.18826344609260559, 0.0001);
267  EXPECT_NEAR(probabilities[4], 0.10691597312688828, 0.0001);
268  EXPECT_NEAR(probabilities[5], 1.4245844629813542e-13, 0.0001);
269  }
270 
271 }
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:148
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:117
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:126
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
Options for the TMVA Classification MVA method.
Definition: TMVA.h:80
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:112
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:81
bool transform2probability
Transform output of method to a probability.
Definition: TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:75
Options for the TMVA Regression MVA method.
Definition: TMVA.h:166
virtual std::string getMethod() const override
Return method name.
Definition: TMVA.h:181
Options for the TMVA MVA method.
Definition: TMVA.h:34
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition: TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:66
std::string m_method
tmva method name
Definition: TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:55
std::string m_factoryOption
Factory options passed to tmva factory.
Definition: TMVA.h:71
std::string m_type
tmva method type
Definition: TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition: TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:27
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.