Belle II Software light-2406-ragdoll
FastBDTTeacher Class Reference

Teacher for the FastBDT MVA method.

#include <FastBDT.h>


Public Member Functions

 FastBDTTeacher (const GeneralOptions &general_options, const FastBDTOptions &specific_options)
 Constructs a new teacher using the GeneralOptions and specific options of this training.
 
virtual Weightfile train (Dataset &training_data) const override
 Trains an MVA method on the given dataset and returns a Weightfile.
 

Protected Attributes

GeneralOptions m_general_options
 GeneralOptions containing all shared options.
 

Private Attributes

FastBDTOptions m_specific_options
 Method specific options.
 

Detailed Description

Teacher for the FastBDT MVA method.

Definition at line 80 of file FastBDT.h.

Constructor & Destructor Documentation

◆ FastBDTTeacher()

FastBDTTeacher (const GeneralOptions &general_options, const FastBDTOptions &specific_options)

Constructs a new teacher using the GeneralOptions and specific options of this training.

Parameters
general_options  defining all shared options
specific_options  defining all method-specific options

Definition at line 117 of file FastBDT.cc.

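FastBDTTeacher::FastBDTTeacher(const GeneralOptions& general_options, const FastBDTOptions& specific_options)
  : Teacher(general_options),
    m_specific_options(specific_options) { }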
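The constructor forwards the shared GeneralOptions to the Teacher base class, which keeps them in its protected m_general_options member, while the FastBDT-specific options are stored in the private m_specific_options. The sketch below reproduces only that pattern. GeneralOptions, FastBDTOptions, Dataset and Weightfile are hypothetical, stripped-down stand-ins for the Belle II types, and the accessors nTrees() and nVariables() are invented for illustration; they are not part of the Belle II API.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-ins: the real Belle II types carry many more members.
struct GeneralOptions {
  std::vector<std::string> m_variables;  // branch names used in the training
};

struct FastBDTOptions {
  unsigned int m_nTrees = 0;   // number of trees
  unsigned int m_nLevels = 0;  // depth of each tree
};

struct Dataset {};     // placeholder for the real training dataset
struct Weightfile {};  // placeholder for the real weightfile

// Base class: keeps the shared options in a protected member and declares
// train() as pure virtual, so every concrete teacher must implement it.
class Teacher {
public:
  explicit Teacher(const GeneralOptions& general_options)
    : m_general_options(general_options) {}
  virtual ~Teacher() = default;
  virtual Weightfile train(Dataset& training_data) const = 0;

protected:
  GeneralOptions m_general_options;
};

// Concrete teacher: forwards the shared options to the base class and
// stores its method-specific options privately.
class FastBDTTeacher : public Teacher {
public:
  FastBDTTeacher(const GeneralOptions& general_options, const FastBDTOptions& specific_options)
    : Teacher(general_options), m_specific_options(specific_options) {}

  // Stub only: the real train() fits a FastBDT::Classifier.
  Weightfile train(Dataset& /*training_data*/) const override { return Weightfile{}; }

  // Hypothetical accessors, added only to make the stored state observable.
  unsigned int nTrees() const { return m_specific_options.m_nTrees; }
  std::size_t nVariables() const { return m_general_options.m_variables.size(); }

private:
  FastBDTOptions m_specific_options;
};
```

Keeping GeneralOptions in the base class is what allows the derived train() to read m_general_options.m_variables without declaring that member itself.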

Member Function Documentation

◆ train()

virtual Weightfile train (Dataset &training_data) const override

Trains an MVA method on the given dataset and returns a Weightfile.

Parameters
training_data  dataset used to train the method

Implements Teacher.

Definition at line 121 of file FastBDT.cc.

Weightfile FastBDTTeacher::train(Dataset& training_data) const
{
  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();

  // Individual nCuts values, if given, must cover every feature and spectator
  if (m_specific_options.m_individual_nCuts.size() != 0
      and m_specific_options.m_individual_nCuts.size() != numberOfFeatures + numberOfSpectators) {
    B2ERROR("You provided individual nCut values for each feature and spectator, but the total number of provided cuts is not the same as the total number of features and spectators.");
  }

  // Global purity transformation: enable it for every feature unless per-feature flags were given
  std::vector<bool> individualPurityTransformation = m_specific_options.m_individualPurityTransformation;
  if (m_specific_options.m_purityTransformation) {
    if (individualPurityTransformation.size() == 0) {
      for (unsigned int i = 0; i < numberOfFeatures; ++i) {
        individualPurityTransformation.push_back(true);
      }
    }
  }

  // Fall back to the global nCuts for every feature and spectator
  std::vector<unsigned int> individual_nCuts = m_specific_options.m_individual_nCuts;
  if (individual_nCuts.size() == 0) {
    for (unsigned int i = 0; i < numberOfFeatures + numberOfSpectators; ++i) {
      individual_nCuts.push_back(m_specific_options.m_nCuts);
    }
  }

  FastBDT::Classifier classifier(m_specific_options.m_nTrees, m_specific_options.m_nLevels, individual_nCuts,
                                 m_specific_options.m_shrinkage, m_specific_options.m_randRatio,
                                 m_specific_options.m_sPlot, m_specific_options.m_flatnessLoss, individualPurityTransformation,
                                 numberOfSpectators, true);

  // Column-wise input matrix: all features first, followed by all spectators
  std::vector<std::vector<float>> X(numberOfFeatures + numberOfSpectators);
  const auto& y = training_data.getSignals();
  if (not isValidSignal(y)) {
    B2FATAL("The training data is not valid. It only contains one class instead of two.");
  }
  const auto& w = training_data.getWeights();
  for (unsigned int i = 0; i < numberOfFeatures; ++i) {
    X[i] = training_data.getFeature(i);
  }
  for (unsigned int i = 0; i < numberOfSpectators; ++i) {
    X[i + numberOfFeatures] = training_data.getSpectator(i);
  }
  classifier.fit(X, y, w);

  // Serialize the trained classifier to a file and attach it to the weightfile
  Weightfile weightfile;
  std::string custom_weightfile = weightfile.generateFileName();
  std::fstream file(custom_weightfile, std::ios_base::out | std::ios_base::trunc);

  file << classifier << std::endl;
  file.close();

  weightfile.addOptions(m_general_options);
  weightfile.addOptions(m_specific_options);
  weightfile.addFile("FastBDT_Weightfile", custom_weightfile);
  weightfile.addSignalFraction(training_data.getSignalFraction());

  // Store the variable ranking keyed by branch name
  std::map<std::string, float> importance;
  for (auto& pair : classifier.GetVariableRanking()) {
    importance[m_general_options.m_variables[pair.first]] = pair.second;
  }
  weightfile.addFeatureImportance(importance);

  return weightfile;
}
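Most of train() is bookkeeping before the call to FastBDT::Classifier: empty per-variable option lists are expanded to defaults, feature and spectator columns are concatenated, and a sample that lacks one of the two classes is rejected. The standalone sketch below mirrors those steps with standard-library types only. All function names are hypothetical, the labels are modelled as std::vector<bool> (an assumption about what getSignals() returns), hasBothClasses() is an illustrative stand-in for isValidSignal(), whose implementation is not shown on this page, and, unlike the listing, expandNCuts() throws on a mismatched list instead of logging a B2ERROR.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helpers mirroring the preparation steps of train();
// plain parameters stand in for the FastBDTOptions fields.

// Empty list: use the global nCuts for every feature and spectator.
// Non-empty list: it must contain exactly one entry per column.
std::vector<unsigned int> expandNCuts(const std::vector<unsigned int>& individual,
                                      unsigned int globalNCuts,
                                      std::size_t nFeatures, std::size_t nSpectators)
{
  const std::size_t nColumns = nFeatures + nSpectators;
  if (individual.empty()) {
    return std::vector<unsigned int>(nColumns, globalNCuts);
  }
  if (individual.size() != nColumns) {
    // The listing only reports a B2ERROR here; this sketch throws instead.
    throw std::invalid_argument("expected one nCuts value per feature and spectator");
  }
  return individual;
}

// With the global flag set and no per-feature flags given, enable the
// purity transformation for every feature (spectators are not included).
std::vector<bool> expandPurity(const std::vector<bool>& individual,
                               bool global, std::size_t nFeatures)
{
  if (global && individual.empty()) {
    return std::vector<bool>(nFeatures, true);
  }
  return individual;
}

// Column-wise input matrix: all feature columns first, then spectators.
std::vector<std::vector<float>> buildColumns(const std::vector<std::vector<float>>& features,
                                             const std::vector<std::vector<float>>& spectators)
{
  std::vector<std::vector<float>> X = features;
  X.insert(X.end(), spectators.begin(), spectators.end());
  return X;
}

// Hypothetical stand-in for isValidSignal(): training needs both classes.
bool hasBothClasses(const std::vector<bool>& y)
{
  bool sawSignal = false;
  bool sawBackground = false;
  for (bool isSignal : y) {
    if (isSignal) {
      sawSignal = true;
    } else {
      sawBackground = true;
    }
  }
  return sawSignal && sawBackground;
}
```

Throwing is a deliberate choice for the sketch: in the listing, a mismatched m_individual_nCuts list is logged and then passed to the classifier unchanged.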
FastBDTOptions and GeneralOptions members used by train():

unsigned int m_nTrees
 Number of trees.
 Definition: FastBDT.h:62
unsigned int m_nCuts
 Number of cut levels = log_2(Number of Cuts).
 Definition: FastBDT.h:63
unsigned int m_nLevels
 Depth of each tree.
 Definition: FastBDT.h:64
double m_shrinkage
 Shrinkage during the boosting step.
 Definition: FastBDT.h:65
double m_randRatio
 Fraction of data to use in the stochastic training.
 Definition: FastBDT.h:66
std::vector< unsigned int > m_individual_nCuts
 Number of cut levels = log_2(Number of Cuts) for each provided feature.
 Definition: FastBDT.h:68
double m_flatnessLoss
 Flatness loss constant.
 Definition: FastBDT.h:69
bool m_sPlot
 Activates sPlot sampling.
 Definition: FastBDT.h:70
bool m_purityTransformation
 Activates the purity transformation globally for all features.
 Definition: FastBDT.h:71
std::vector< bool > m_individualPurityTransformation
 Per-feature flags that decide whether the purity transformation is applied to each feature.
 Definition: FastBDT.h:73
std::vector< std::string > m_variables
 Vector of all variables (branch names) used in the training.
 Definition: Options.h:86
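Two of these definitions are easy to misread: m_nCuts (and each entry of m_individual_nCuts) is a base-2 logarithm rather than a count, so a value of n stands for 2^n cuts, and the feature-importance map written by train() is keyed by the branch names in m_variables. The sketch below illustrates both with hypothetical helper functions. The std::map<unsigned int, double> ranking type is an assumption standing in for the return value of classifier.GetVariableRanking().

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// m_nCuts is documented as "Number of cut levels = log_2(Number of Cuts)",
// so n cut levels correspond to 2^n cuts. Valid for cutLevels < 64.
unsigned long long cutsFromLevels(unsigned int cutLevels)
{
  return 1ULL << cutLevels;
}

// Translate (variable index, importance) pairs into a map keyed by branch
// name, as train() does with m_general_options.m_variables.
std::map<std::string, float> nameImportances(
  const std::map<unsigned int, double>& ranking,  // assumed ranking type
  const std::vector<std::string>& variables)
{
  std::map<std::string, float> importance;
  for (const auto& pair : ranking) {
    // at() throws std::out_of_range for an index outside the variable list.
    importance[variables.at(pair.first)] = static_cast<float>(pair.second);
  }
  return importance;
}
```

For example, m_nCuts = 8 corresponds to 2^8 = 256 cuts per variable. The sketch uses at() instead of the unchecked operator[] from the listing, so an out-of-range index fails loudly rather than reading past the end of m_variables.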

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
protected, inherited

GeneralOptions containing all shared options.

Definition at line 49 of file Teacher.h.

◆ m_specific_options

FastBDTOptions m_specific_options
private

Method specific options.

Definition at line 97 of file FastBDT.h.

