9#include <mva/methods/RegressionFastBDT.h>
10#include <mva/interface/Interface.h>
11#include <mva/interface/Dataset.h>
13#include <gtest/gtest.h>
20 explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
30 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
31 void loadEvent(
unsigned int iEvent)
override
33 m_input[0] = m_data[iEvent]; m_target = 1.0 * iEvent / 10; m_isSignal = m_target == 1;
36 std::vector<float>
getFeature(
unsigned int)
override {
return m_data; }
38 std::vector<float> m_data;
42 TEST(RegressionFastBDTTest, RegressionFastBDTInterface)
47 general_options.m_variables = {
"A"};
49 specific_options.setMaximalBinNumber(3);
50 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0});
52 auto teacher = interface.
getTeacher(general_options, specific_options);
53 auto weightfile = teacher->train(dataset);
56 expert->load(weightfile);
57 auto probabilities = expert->apply(dataset);
58 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
59 EXPECT_EQ(probabilities.size(), 10);
61 for (
unsigned int i = 0; i < 5; ++i) {
62 EXPECT_LE(probabilities[i], 0.8);
64 for (
unsigned int i = 5; i < 10; ++i) {
65 EXPECT_GE(probabilities[i], 0.2);
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Explicit template specialization for FastBDTs for regression options.
Abstract base class for different kinds of events.