Belle II Software light-2406-ragdoll
test_RegressionFastBDT.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/RegressionFastBDT.h>
10#include <mva/interface/Interface.h>
11#include <mva/interface/Dataset.h>
12
13#include <gtest/gtest.h>
14
15using namespace Belle2;
16
17namespace {
18 class TestDataset : public MVA::Dataset {
19 public:
20 explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
21 {
22 m_input = {0.0};
23 m_target = 0.0;
24 m_isSignal = false;
25 m_weight = 1.0;
26 }
27
28 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
29 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
30 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
31 void loadEvent(unsigned int iEvent) override
32 {
33 m_input[0] = m_data[iEvent]; m_target = 1.0 * iEvent / 10; m_isSignal = m_target == 1;
34 };
35 float getSignalFraction() override { return 0.1; };
36 std::vector<float> getFeature(unsigned int) override { return m_data; }
37
38 std::vector<float> m_data;
39 };
40
41
42 TEST(RegressionFastBDTTest, RegressionFastBDTInterface)
43 {
45
46 MVA::GeneralOptions general_options;
47 general_options.m_variables = {"A"};
48 MVA::RegressionFastBDTOptions specific_options;
49 specific_options.setMaximalBinNumber(3);
50 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0});
51
52 auto teacher = interface.getTeacher(general_options, specific_options);
53 auto weightfile = teacher->train(dataset);
54
55 auto expert = interface.getExpert();
56 expert->load(weightfile);
57 auto probabilities = expert->apply(dataset);
58 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
59 EXPECT_EQ(probabilities.size(), 10);
60
61 for (unsigned int i = 0; i < 5; ++i) {
62 EXPECT_LE(probabilities[i], 0.8);
63 }
64 for (unsigned int i = 5; i < 10; ++i) {
65 EXPECT_GE(probabilities[i], 0.2);
66 }
67 }
68}
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:125
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:116
Explicit template specification for FastBDTs for regression options.
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24