Belle II Software  release-05-01-25
test_RegressionFastBDT.cc
1 /* BASF2 (Belle Analysis Framework 2) *
2  * Copyright(C) 2016 - Belle II Collaboration *
3  * *
4  * Author: The Belle II Collaboration *
5  * Contributors: Thomas Keck *
6  * *
7  * This software is provided "as is" without any warranty. *
8  **************************************************************************/
9 
10 #include <mva/methods/RegressionFastBDT.h>
11 #include <mva/interface/Interface.h>
12 #include <mva/interface/Dataset.h>
13 
14 #include <gtest/gtest.h>
15 
16 using namespace Belle2;
17 
18 namespace {
19  class TestDataset : public MVA::Dataset {
20  public:
21  explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
22  {
23  m_input = {0.0};
24  m_target = 0.0;
25  m_isSignal = false;
26  m_weight = 1.0;
27  }
28 
29  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
30  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
31  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
32  void loadEvent(unsigned int iEvent) override
33  {
34  m_input[0] = m_data[iEvent]; m_target = 1.0 * iEvent / 10; m_isSignal = m_target == 1;
35  };
36  float getSignalFraction() override { return 0.1; };
37  std::vector<float> getFeature(unsigned int) override { return m_data; }
38 
39  std::vector<float> m_data;
40  };
41 
42 
43  TEST(RegressionFastBDTTest, RegressionFastBDTInterface)
44  {
46 
47  MVA::GeneralOptions general_options;
48  general_options.m_variables = {"A"};
49  MVA::RegressionFastBDTOptions specific_options;
50  specific_options.setMaximalBinNumber(3);
51  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0});
52 
53  auto teacher = interface.getTeacher(general_options, specific_options);
54  auto weightfile = teacher->train(dataset);
55 
56  auto expert = interface.getExpert();
57  expert->load(weightfile);
58  auto probabilities = expert->apply(dataset);
59  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
60  EXPECT_EQ(probabilities.size(), 10);
61 
62  for (unsigned int i = 0; i < 5; ++i) {
63  EXPECT_LE(probabilities[i], 0.8);
64  }
65  for (unsigned int i = 5; i < 10; ++i) {
66  EXPECT_GE(probabilities[i], 0.2);
67  }
68  }
69 }
Belle2::MVA::Dataset
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:34
Belle2::MVA::Interface::getTeacher
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:119
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::RegressionFastBDTOptions
Explicit template specification for FastBDTs for regression options.
Definition: RegressionFastBDT.h:32
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::TEST
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Definition: utilityFunctions.cc:18
Belle2::MVA::Interface::getExpert
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:128
Belle2::MVA::Interface
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:101