Belle II Software light-2405-quaxo
FastBDTTeacher Class Reference

Teacher for the FastBDT MVA method.

#include <FastBDT.h>

Public Member Functions

 FastBDTTeacher (const GeneralOptions &general_options, const FastBDTOptions &specific_options)
 Constructs a new teacher using the GeneralOptions and specific options of this training.
 
virtual Weightfile train (Dataset &training_data) const override
 Trains an MVA method on the given dataset and returns a Weightfile.
 

Protected Attributes

GeneralOptions m_general_options
 GeneralOptions containing all shared options.
 

Private Attributes

FastBDTOptions m_specific_options
 Method specific options.
 

Detailed Description

Teacher for the FastBDT MVA method.

Definition at line 98 of file FastBDT.h.

Constructor & Destructor Documentation

◆ FastBDTTeacher()

FastBDTTeacher (const GeneralOptions &general_options, const FastBDTOptions &specific_options)

Constructs a new teacher using the GeneralOptions and specific options of this training.

Parameters
general_options: defining all shared options
specific_options: defining all method specific options

Definition at line 156 of file FastBDT.cc.

  : Teacher(general_options),
    m_specific_options(specific_options) { }

Members referenced in the constructor:
FastBDTOptions m_specific_options: Method specific options. Definition: FastBDT.h:115
Teacher(const GeneralOptions &general_options): Constructs a new teacher using the GeneralOptions for this training. Definition: Teacher.cc:18
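
The constructor forwards the shared options to the Teacher base class and keeps its own copy of the method specific options. The following is a minimal sketch of that pattern only. The option structs are heavily reduced, hypothetical stand-ins (the real classes carry many more members), and the two accessor functions exist only in this sketch to make the stored copies observable.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, reduced stand-ins for the real option classes.
struct GeneralOptions { std::vector<std::string> m_variables; };
struct FastBDTOptions { unsigned int m_nTrees = 100; };  // arbitrary value for the sketch

class Teacher {
public:
  explicit Teacher(const GeneralOptions& general_options)
    : m_general_options(general_options) {}
  virtual ~Teacher() = default;
protected:
  GeneralOptions m_general_options;  // shared options, stored by value in the base class
};

class FastBDTTeacher : public Teacher {
public:
  FastBDTTeacher(const GeneralOptions& general_options, const FastBDTOptions& specific_options)
    : Teacher(general_options), m_specific_options(specific_options) {}

  // Sketch-only accessors, not part of the documented interface.
  std::size_t numberOfVariables() const { return m_general_options.m_variables.size(); }
  unsigned int numberOfTrees() const { return m_specific_options.m_nTrees; }
private:
  FastBDTOptions m_specific_options;  // method specific options, stored by value
};
```

Because both option objects are stored by value, changing the caller's options after construction does not affect the teacher.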

Member Function Documentation

◆ train()

virtual Weightfile train (Dataset &training_data) const override

Trains an MVA method on the given dataset and returns a Weightfile.

Parameters
training_data: used to train the method

Implements Teacher.

Definition at line 160 of file FastBDT.cc.

{

  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
#if FastBDT_VERSION_MAJOR >= 4
  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
#else
  // Deactivate support for spectators below version 4!
  unsigned int numberOfSpectators = 0;
#endif

  // FastBDT Version 4 has a simplified interface with a sklearn style Classifier
#if FastBDT_VERSION_MAJOR >= 5
  if (m_specific_options.m_individual_nCuts.size() != 0
      and m_specific_options.m_individual_nCuts.size() != numberOfFeatures + numberOfSpectators) {
    B2ERROR("You provided individual nCut values for each feature and spectator, but the total number of provided cuts is not the same as the total number of features and spectators.");
  }

  std::vector<bool> individualPurityTransformation = m_specific_options.m_individualPurityTransformation;
  if (m_specific_options.m_purityTransformation) {
    if (individualPurityTransformation.size() == 0) {
      for (unsigned int i = 0; i < numberOfFeatures; ++i) {
        individualPurityTransformation.push_back(true);
      }
    }
  }

  std::vector<unsigned int> individual_nCuts = m_specific_options.m_individual_nCuts;
  if (individual_nCuts.size() == 0) {
    for (unsigned int i = 0; i < numberOfFeatures + numberOfSpectators; ++i) {
      individual_nCuts.push_back(m_specific_options.m_nCuts);
    }
  }

  FastBDT::Classifier classifier(m_specific_options.m_nTrees, m_specific_options.m_nLevels, individual_nCuts,
                                 // (line 195 of FastBDT.cc is missing from this listing)
                                 m_specific_options.m_sPlot, m_specific_options.m_flatnessLoss, individualPurityTransformation,
                                 numberOfSpectators, true);

  std::vector<std::vector<float>> X(numberOfFeatures + numberOfSpectators);
  const auto& y = training_data.getSignals();
  if (not isValidSignal(y)) {
    B2FATAL("The training data is not valid. It only contains one class instead of two.");
  }
  const auto& w = training_data.getWeights();
  for (unsigned int i = 0; i < numberOfFeatures; ++i) {
    X[i] = training_data.getFeature(i);
  }
  for (unsigned int i = 0; i < numberOfSpectators; ++i) {
    X[i + numberOfFeatures] = training_data.getSpectator(i);
  }
  classifier.fit(X, y, w);
#else
  const auto& y = training_data.getSignals();
  if (not isValidSignal(y)) {
    B2FATAL("The training data is not valid. It only contains one class instead of two.");
  }
  std::vector<FastBDT::FeatureBinning<float>> featureBinnings;
  std::vector<unsigned int> nBinningLevels;
  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
    auto feature = training_data.getFeature(iFeature);

    unsigned int nCuts = m_specific_options.m_nCuts;
#if FastBDT_VERSION_MAJOR >= 3
    featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature));
#else
    featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature.begin(), feature.end()));
#endif
    nBinningLevels.push_back(nCuts);
  }

  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
    auto feature = training_data.getSpectator(iSpectator);

    unsigned int nCuts = m_specific_options.m_nCuts;
#if FastBDT_VERSION_MAJOR >= 3
    featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature));
#else
    featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature.begin(), feature.end()));
#endif
    nBinningLevels.push_back(nCuts);
  }

  unsigned int numberOfEvents = training_data.getNumberOfEvents();
  if (numberOfEvents > 5e+6) {
    B2WARNING("Number of events for training exceeds 5 million. FastBDT performance starts getting worse when the number reaches O(10^7).");
  }

#if FastBDT_VERSION_MAJOR >= 4
  FastBDT::EventSample eventSample(numberOfEvents, numberOfFeatures, numberOfSpectators, nBinningLevels);
#else
  FastBDT::EventSample eventSample(numberOfEvents, numberOfFeatures, nBinningLevels);
#endif
  std::vector<unsigned int> bins(numberOfFeatures + numberOfSpectators);
  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
    training_data.loadEvent(iEvent);
    for (unsigned int iFeature = 0; iFeature < numberOfFeatures + numberOfSpectators; ++iFeature) {
      bins[iFeature] = featureBinnings[iFeature].ValueToBin(training_data.m_input[iFeature]);
    }
    for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
      bins[iSpectator + numberOfFeatures] = featureBinnings[iSpectator + numberOfFeatures].ValueToBin(
                                              training_data.m_spectators[iSpectator]);
    }
    eventSample.AddEvent(bins, training_data.m_weight, training_data.m_isSignal);
  }

  // (lines 266-267 of FastBDT.cc, which introduce dt, are missing from this listing)
#if FastBDT_VERSION_MAJOR >= 3
  FastBDT::Forest<float> forest(dt.GetShrinkage(), dt.GetF0(), true);
#else
  FastBDT::Forest forest(dt.GetShrinkage(), dt.GetF0());
#endif
  for (auto t : dt.GetForest()) {
#if FastBDT_VERSION_MAJOR >= 3
    auto tree = FastBDT::removeFeatureBinningTransformationFromTree(t, featureBinnings);
    forest.AddTree(tree);
#else
    forest.AddTree(t);
#endif
  }

#endif


  Weightfile weightfile;
  std::string custom_weightfile = weightfile.generateFileName();
  std::fstream file(custom_weightfile, std::ios_base::out | std::ios_base::trunc);

#if FastBDT_VERSION_MAJOR >= 5
  file << classifier << std::endl;
#else
#if FastBDT_VERSION_MAJOR >= 3
  file << forest << std::endl;
#else
  file << featureBinnings << std::endl;
  file << forest << std::endl;
#endif
#endif
  file.close();

  weightfile.addOptions(m_general_options);
  weightfile.addOptions(m_specific_options);
  weightfile.addFile("FastBDT_Weightfile", custom_weightfile);
  weightfile.addSignalFraction(training_data.getSignalFraction());

  std::map<std::string, float> importance;
#if FastBDT_VERSION_MAJOR >= 5
  for (auto& pair : classifier.GetVariableRanking()) {
    importance[m_general_options.m_variables[pair.first]] = pair.second;
  }
#else
  for (auto& pair : forest.GetVariableRanking()) {
    importance[m_general_options.m_variables[pair.first]] = pair.second;
  }
#endif
  weightfile.addFeatureImportance(importance);

  return weightfile;

}
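
In the FastBDT >= 5 branch of train(), the per-column cut levels are resolved before the classifier is built: an empty m_individual_nCuts falls back to the global m_nCuts for every feature and spectator, and a non-empty list is expected to have one entry per feature plus spectator. A minimal standalone sketch of that step follows. The helper name resolveCuts is hypothetical, and unlike the listing, which only reports a size mismatch through B2ERROR and continues, this sketch assumes that throwing is acceptable.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper, not part of the Belle II MVA package.
// Returns one cut level per column (features first, then spectators).
std::vector<unsigned int> resolveCuts(const std::vector<unsigned int>& individualCuts,
                                      unsigned int defaultCuts,
                                      std::size_t numberOfFeatures,
                                      std::size_t numberOfSpectators)
{
  const std::size_t numberOfColumns = numberOfFeatures + numberOfSpectators;
  if (individualCuts.empty()) {
    // No per-column values given: reuse the global setting for every column.
    return std::vector<unsigned int>(numberOfColumns, defaultCuts);
  }
  if (individualCuts.size() != numberOfColumns) {
    // Assumption of this sketch: fail hard instead of only logging an error.
    throw std::invalid_argument("individual nCuts size does not match features + spectators");
  }
  return individualCuts;
}
```

For example, resolveCuts({}, 8, 3, 1) yields four entries that all equal 8, while resolveCuts({4, 5}, 8, 2, 1) throws because three entries are required.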

Members referenced in train():
double m_randRatio: Fraction of data to use in the stochastic training. Definition: FastBDT.h:82
double m_shrinkage: Shrinkage during the boosting step. Definition: FastBDT.h:81
unsigned int m_nLevels: Depth of tree. Definition: FastBDT.h:80
unsigned int m_nCuts: Number of cut levels = log_2(number of cuts). Definition: FastBDT.h:79
unsigned int m_nTrees: Number of trees. Definition: FastBDT.h:78
std::vector< std::string > m_variables: Vector of all variables (branch names) used in the training. Definition: Options.h:86
GeneralOptions m_general_options: GeneralOptions containing all shared options. Definition: Teacher.h:49
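
Before fitting, train() rejects datasets that contain only one class and, in the FastBDT >= 5 branch, arranges the inputs column-wise with all feature columns first and the spectator columns after them. The sketch below illustrates both steps in isolation. The function names hasBothClasses and assembleColumns are hypothetical, and hasBothClasses is only an assumption about what isValidSignal checks, based on the B2FATAL message in the listing.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the validity check: true only if the labels
// contain at least one signal entry and at least one background entry.
bool hasBothClasses(const std::vector<bool>& isSignal)
{
  const bool anySignal = std::find(isSignal.begin(), isSignal.end(), true) != isSignal.end();
  const bool anyBackground = std::find(isSignal.begin(), isSignal.end(), false) != isSignal.end();
  return anySignal && anyBackground;
}

// Column-wise layout: X[i] holds column i for all events; spectator
// columns are appended after the last feature column.
std::vector<std::vector<float>> assembleColumns(const std::vector<std::vector<float>>& features,
                                                const std::vector<std::vector<float>>& spectators)
{
  std::vector<std::vector<float>> X;
  X.reserve(features.size() + spectators.size());
  X.insert(X.end(), features.begin(), features.end());
  X.insert(X.end(), spectators.begin(), spectators.end());
  return X;
}
```

With two features and one spectator, the spectator values end up in X[2], which matches the index i + numberOfFeatures used in the listing.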

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
protected, inherited

GeneralOptions containing all shared options.

Definition at line 49 of file Teacher.h.

◆ m_specific_options

FastBDTOptions m_specific_options
private

Method specific options.

Definition at line 115 of file FastBDT.h.


The documentation for this class was generated from the following files: