Belle II Software  release-08-01-10
test_RegressionFastBDT.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/RegressionFastBDT.h>
10 #include <mva/interface/Interface.h>
11 #include <mva/interface/Dataset.h>
12 
13 #include <gtest/gtest.h>
14 
15 using namespace Belle2;
16 
17 namespace {
18  class TestDataset : public MVA::Dataset {
19  public:
20  explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
21  {
22  m_input = {0.0};
23  m_target = 0.0;
24  m_isSignal = false;
25  m_weight = 1.0;
26  }
27 
28  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
29  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
30  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
31  void loadEvent(unsigned int iEvent) override
32  {
33  m_input[0] = m_data[iEvent]; m_target = 1.0 * iEvent / 10; m_isSignal = m_target == 1;
34  };
35  float getSignalFraction() override { return 0.1; };
36  std::vector<float> getFeature(unsigned int) override { return m_data; }
37 
38  std::vector<float> m_data;
39  };
40 
41 
42  TEST(RegressionFastBDTTest, RegressionFastBDTInterface)
43  {
45 
46  MVA::GeneralOptions general_options;
47  general_options.m_variables = {"A"};
48  MVA::RegressionFastBDTOptions specific_options;
49  specific_options.setMaximalBinNumber(3);
50  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0});
51 
52  auto teacher = interface.getTeacher(general_options, specific_options);
53  auto weightfile = teacher->train(dataset);
54 
55  auto expert = interface.getExpert();
56  expert->load(weightfile);
57  auto probabilities = expert->apply(dataset);
58  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
59  EXPECT_EQ(probabilities.size(), 10);
60 
61  for (unsigned int i = 0; i < 5; ++i) {
62  EXPECT_LE(probabilities[i], 0.8);
63  }
64  for (unsigned int i = 5; i < 10; ++i) {
65  EXPECT_GE(probabilities[i], 0.2);
66  }
67  }
68 }
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:117
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:126
Explicit template specification for FastBDTs for regression options.
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.