9 #include <mva/methods/RegressionFastBDT.h>
10 #include <mva/interface/Interface.h>
11 #include <mva/interface/Dataset.h>
13 #include <gtest/gtest.h>
/// Minimal one-feature dataset for the regression test: \p data is copied into
/// m_data and provides the value of the single input feature for each event.
/// The MVA::Dataset base is constructed from default MVA::GeneralOptions.
/// NOTE(review): the constructor body is not visible in this listing; presumably it
/// initialises m_input to one element, since loadEvent writes m_input[0] — confirm.
20 explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
28 [[nodiscard]]
unsigned int getNumberOfFeatures()
const override {
return 1; }
29 [[nodiscard]]
unsigned int getNumberOfSpectators()
const override {
return 0; }
30 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
31 void loadEvent(
unsigned int iEvent)
override
33 m_input[0] = m_data[iEvent]; m_target = 1.0 * iEvent / 10; m_isSignal = m_target == 1;
35 float getSignalFraction()
override {
return 0.1; };
36 std::vector<float> getFeature(
unsigned int)
override {
return m_data; }
/// Feature values, one per event: loadEvent reads entry iEvent, getNumberOfEvents
/// returns its size and getFeature returns a copy of the whole vector.
38 std::vector<float> m_data;
// Trains a FastBDT regression on the single feature "A" and sanity-checks its output.
// The dataset holds feature value 1.0 for events 0-4 and 2.0 for events 5-9.
// The regression target is iEvent / 10 (see TestDataset::loadEvent), i.e. 0.0-0.4
// for the first group and 0.5-0.9 for the second.
42 TEST(RegressionFastBDTTest, RegressionFastBDTInterface)
47 general_options.m_variables = {
"A"};
// NOTE(review): presumably keeps the binning coarse for this tiny two-valued dataset;
// the exact meaning is defined by the regression FastBDT options — confirm.
49 specific_options.setMaximalBinNumber(3);
// Ten events: five with feature value 1.0 followed by five with 2.0.
50 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0});
// Train on the dataset; the teacher returns the trained weightfile.
52 auto teacher = interface.
getTeacher(general_options, specific_options);
53 auto weightfile = teacher->train(dataset);
// Load the trained weightfile into the expert and evaluate every event.
// Despite the name, "probabilities" holds the regression predictions.
56 expert->load(weightfile);
57 auto probabilities = expert->apply(dataset);
// Exactly one prediction per event.
// NOTE(review): the literal 10 is an int compared with a std::size_t; 10u would avoid
// a sign-compare warning inside EXPECT_EQ.
58 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
59 EXPECT_EQ(probabilities.size(), 10);
// Loose sanity bounds rather than exact values. A regressor that separates the two
// groups would presumably predict near the group means (about 0.2 for events 0-4
// and about 0.7 for events 5-9).
61 for (
unsigned int i = 0; i < 5; ++i) {
62 EXPECT_LE(probabilities[i], 0.8);
64 for (
unsigned int i = 5; i < 10; ++i) {
65 EXPECT_GE(probabilities[i], 0.2);
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Explicit template specification for FastBDTs for regression options.
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.