Belle II Software development
FastBDTTeacher Class Reference

Teacher for the FastBDT MVA method.

#include <FastBDT.h>

Inheritance diagram for FastBDTTeacher: FastBDTTeacher inherits from Teacher.

Public Member Functions

 FastBDTTeacher (const GeneralOptions &general_options, const FastBDTOptions &specific_options)
 Constructs a new teacher using the GeneralOptions and specific options of this training.
 
virtual Weightfile train (Dataset &training_data) const override
 Trains an MVA method using the given dataset and returns a Weightfile.
 

Protected Attributes

GeneralOptions m_general_options
 GeneralOptions containing all shared options.
 

Private Attributes

FastBDTOptions m_specific_options
 Method specific options.
 

Detailed Description

Teacher for the FastBDT MVA method.

Definition at line 80 of file FastBDT.h.

Constructor & Destructor Documentation

◆ FastBDTTeacher()

FastBDTTeacher ( const GeneralOptions & general_options,
const FastBDTOptions & specific_options
)

Constructs a new teacher using the GeneralOptions and specific options of this training.

Parameters
general_options    defining all shared options
specific_options   defining all method-specific options

Definition at line 117 of file FastBDT.cc.

  : Teacher(general_options),
    m_specific_options(specific_options) { }
Teacher(const GeneralOptions &general_options)
Constructs a new teacher using the GeneralOptions for this training.
Definition: Teacher.cc:18

Member Function Documentation

◆ train()

Weightfile train ( Dataset & training_data ) const
override virtual

Trains an MVA method using the given dataset and returns a Weightfile.

Parameters
training_data   used to train the method

Implements Teacher.

Definition at line 121 of file FastBDT.cc.

{
  if (training_data.getNumberOfEvents() > 5e+6) {
    B2WARNING("Number of events for training exceeds 5 million. FastBDT performance starts getting worse when the number reaches O(10^7).");
  }

  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();

  // Individual nCuts, if given, must cover every feature and spectator.
  if (m_specific_options.m_individual_nCuts.size() != 0
      and m_specific_options.m_individual_nCuts.size() != numberOfFeatures + numberOfSpectators) {
    B2ERROR("You provided individual nCut values for each feature and spectator, but the total number of provided cuts is not same as as the total number of features and spectators.");
  }

  // With the global switch on and no individual flags, enable the purity transformation for every feature.
  std::vector<bool> individualPurityTransformation = m_specific_options.m_individualPurityTransformation;
  if (m_specific_options.m_purityTransformation) {
    if (individualPurityTransformation.size() == 0) {
      for (unsigned int i = 0; i < numberOfFeatures; ++i) {
        individualPurityTransformation.push_back(true);
      }
    }
  }

  // Without individual nCuts, use the global m_nCuts for every feature and spectator.
  std::vector<unsigned int> individual_nCuts = m_specific_options.m_individual_nCuts;
  if (individual_nCuts.size() == 0) {
    for (unsigned int i = 0; i < numberOfFeatures + numberOfSpectators; ++i) {
      individual_nCuts.push_back(m_specific_options.m_nCuts);
    }
  }

  FastBDT::Classifier classifier(m_specific_options.m_nTrees, m_specific_options.m_nLevels, individual_nCuts,
                                 m_specific_options.m_shrinkage, m_specific_options.m_randRatio,
                                 m_specific_options.m_sPlot, m_specific_options.m_flatnessLoss, individualPurityTransformation,
                                 numberOfSpectators, true);

  // Column-wise input: all features first, then all spectators.
  std::vector<std::vector<float>> X(numberOfFeatures + numberOfSpectators);
  const auto& y = training_data.getSignals();
  if (not isValidSignal(y)) {
    B2FATAL("The training data is not valid. It only contains one class instead of two.");
  }
  const auto& w = training_data.getWeights();
  for (unsigned int i = 0; i < numberOfFeatures; ++i) {
    X[i] = training_data.getFeature(i);
  }
  for (unsigned int i = 0; i < numberOfSpectators; ++i) {
    X[i + numberOfFeatures] = training_data.getSpectator(i);
  }
  classifier.fit(X, y, w);

  // Serialize the trained classifier into a file stored inside the Weightfile.
  Weightfile weightfile;
  std::string custom_weightfile = weightfile.generateFileName();
  std::fstream file(custom_weightfile, std::ios_base::out | std::ios_base::trunc);

  file << classifier << std::endl;
  file.close();

  weightfile.addOptions(m_general_options);
  weightfile.addOptions(m_specific_options);
  weightfile.addFile("FastBDT_Weightfile", custom_weightfile);
  weightfile.addSignalFraction(training_data.getSignalFraction());

  // Translate the index-based variable ranking into a map keyed by variable name.
  std::map<std::string, float> importance;
  for (auto& pair : classifier.GetVariableRanking()) {
    importance[m_general_options.m_variables[pair.first]] = pair.second;
  }
  weightfile.addFeatureImportance(importance);

  return weightfile;
}
Referenced members:
std::vector< unsigned int > m_individual_nCuts
Number of cut levels = log_2(number of cuts) for each provided feature.
Definition: FastBDT.h:68
bool m_sPlot
Activates sPlot sampling.
Definition: FastBDT.h:70
double m_randRatio
Fraction of data to use in the stochastic training.
Definition: FastBDT.h:66
double m_flatnessLoss
Flatness loss constant.
Definition: FastBDT.h:69
double m_shrinkage
Shrinkage during the boosting step.
Definition: FastBDT.h:65
bool m_purityTransformation
Activates purity transformation globally for all features.
Definition: FastBDT.h:71
unsigned int m_nLevels
Depth of tree.
Definition: FastBDT.h:64
std::vector< bool > m_individualPurityTransformation
Vector that decides, for each feature individually, whether the purity transformation is used.
Definition: FastBDT.h:73
unsigned int m_nCuts
Number of cut levels = log_2(number of cuts).
Definition: FastBDT.h:63
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
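
The input handling in train() can be illustrated in isolation. The following is a minimal sketch, not the Belle II implementation: hasBothClasses and buildColumns are hypothetical helper names, and the label type std::vector<bool> is an assumption about what getSignals() returns. It shows the two-class check that guards the fit and the column layout passed to the classifier (features first, spectators appended).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the validity check applied before fitting:
// the labels must contain at least one signal and one background event.
bool hasBothClasses(const std::vector<bool>& signals)
{
  bool seenSignal = false;
  bool seenBackground = false;
  for (bool isSignal : signals) {
    if (isSignal) {
      seenSignal = true;
    } else {
      seenBackground = true;
    }
  }
  return seenSignal && seenBackground;
}

// Column-wise layout: one inner vector per feature, followed by one inner
// vector per spectator, so spectator i ends up at index nFeatures + i.
std::vector<std::vector<float>> buildColumns(const std::vector<std::vector<float>>& features,
                                             const std::vector<std::vector<float>>& spectators)
{
  std::vector<std::vector<float>> X(features.size() + spectators.size());
  for (std::size_t i = 0; i < features.size(); ++i) {
    X[i] = features[i];
  }
  for (std::size_t i = 0; i < spectators.size(); ++i) {
    X[i + features.size()] = spectators[i];
  }
  return X;
}
```

Because spectators occupy the trailing columns, a classifier only needs the spectator count to tell them apart from the features.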

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
protected inherited

GeneralOptions containing all shared options.

Definition at line 49 of file Teacher.h.
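
In train(), m_general_options.m_variables is used to translate the classifier's index-based variable ranking into a feature-importance map keyed by variable name. A minimal sketch of that translation follows; nameImportance is a hypothetical helper, and the ranking type (feature index to score) is an assumption about what the classifier returns.

```cpp
#include <map>
#include <string>
#include <vector>

// Translate an index-keyed ranking into a name-keyed importance map.
// Assumption: the ranking maps a feature index to a floating-point score.
std::map<std::string, float> nameImportance(const std::map<unsigned int, double>& ranking,
                                            const std::vector<std::string>& variables)
{
  std::map<std::string, float> importance;
  for (const auto& pair : ranking) {
    // at() throws std::out_of_range for an index without a variable name,
    // unlike operator[], which would read out of bounds.
    importance[variables.at(pair.first)] = static_cast<float>(pair.second);
  }
  return importance;
}
```

The sketch uses at() for bounds checking, whereas the listing above indexes m_variables directly. That is a deliberate difference for illustration, not a description of the original code.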

◆ m_specific_options

FastBDTOptions m_specific_options
private

Method specific options.

Definition at line 97 of file FastBDT.h.
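
Before constructing the classifier, train() expands the method-specific options into per-column settings. The sketch below illustrates that expansion under stated assumptions: OptionsSketch is a hypothetical, trimmed-down stand-in for FastBDTOptions with illustrative defaults, expandCuts and expandPurity are hypothetical helpers, and the size mismatch is reported by throwing an exception where the original logs an error with B2ERROR. It also assumes that the global m_purityTransformation switch gates the default fill.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for FastBDTOptions; only the members needed here,
// with illustrative default values.
struct OptionsSketch {
  unsigned int m_nCuts = 8;
  bool m_purityTransformation = false;
  std::vector<unsigned int> m_individual_nCuts;
  std::vector<bool> m_individualPurityTransformation;
};

// One nCuts entry per column (features + spectators): use the individual
// values if given, otherwise repeat the global m_nCuts.
std::vector<unsigned int> expandCuts(const OptionsSketch& options,
                                     std::size_t nFeatures, std::size_t nSpectators)
{
  const std::size_t nColumns = nFeatures + nSpectators;
  if (!options.m_individual_nCuts.empty()) {
    if (options.m_individual_nCuts.size() != nColumns) {
      throw std::invalid_argument("individual nCuts must match features + spectators");
    }
    return options.m_individual_nCuts;
  }
  return std::vector<unsigned int>(nColumns, options.m_nCuts);
}

// One purity flag per feature (spectators excluded): filled with true only
// when the global switch is on and no individual flags were given.
std::vector<bool> expandPurity(const OptionsSketch& options, std::size_t nFeatures)
{
  std::vector<bool> flags = options.m_individualPurityTransformation;
  if (options.m_purityTransformation && flags.empty()) {
    flags.assign(nFeatures, true);
  }
  return flags;
}
```

Note the asymmetry: the cut vector covers features and spectators, while the purity flags cover features only.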


The documentation for this class was generated from the following files: FastBDT.h and FastBDT.cc.