Belle II Software light-2405-quaxo
FastBDT.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/FastBDT.h>
10
11#include <framework/logging/Logger.h>
12#include <sstream>
13#include <vector>
14
15// Template specialization to fix NAN sort bug of FastBDT in up to Version 3.2
16#if FastBDT_VERSION_MAJOR <= 3 && FastBDT_VERSION_MINOR <= 2
17namespace FastBDT {
18 template<>
19 bool compareIncludingNaN(float i, float j)
20 {
21 if (std::isnan(i)) {
22 if (std::isnan(j)) {
23 // If both are NAN i is NOT smaller
24 return false;
25 } else {
26 // In all other cases i is smaller
27 return true;
28 }
29 }
30 // If j is NaN the following line will return false,
31 // which is fine in our case.
32 return i < j;
33 }
34}
35#endif
36
37namespace Belle2 {
42 namespace MVA {
/**
 * Check whether a vector of signal flags contains both classes.
 *
 * @param Signals per-event signal flags (true = signal, false = background)
 * @return true if at least two entries differ, i.e. signal and background are both present;
 *         false if the vector is empty or every entry has the same value
 */
bool isValidSignal(const std::vector<bool>& Signals)
{
  // front() on an empty vector is undefined behaviour; an empty sample contains no class at all.
  if (Signals.empty()) {
    return false;
  }
  const bool first = Signals.front();
  for (const bool value : Signals) {
    if (value != first) {
      return true;
    }
  }
  return false;
}
52
53 void FastBDTOptions::load(const boost::property_tree::ptree& pt)
54 {
55 int version = pt.get<int>("FastBDT_version");
56#if FastBDT_VERSION_MAJOR >= 5
57 if (version != 1 and version != 2) {
58 B2ERROR("Unknown weightfile version " << std::to_string(version));
59 throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
60 }
61#else
62 if (version != 1) {
63 B2ERROR("Unknown weightfile version " << std::to_string(version));
64 throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
65 }
66#endif
67 m_nTrees = pt.get<int>("FastBDT_nTrees");
68 m_nCuts = pt.get<int>("FastBDT_nCuts");
69 m_nLevels = pt.get<int>("FastBDT_nLevels");
70 m_shrinkage = pt.get<double>("FastBDT_shrinkage");
71 m_randRatio = pt.get<double>("FastBDT_randRatio");
72
73#if FastBDT_VERSION_MAJOR >= 5
74 if (version > 1) {
75
76 m_flatnessLoss = pt.get<double>("FastBDT_flatnessLoss");
77 m_sPlot = pt.get<bool>("FastBDT_sPlot");
78
79 unsigned int numberOfIndividualNCuts = pt.get<unsigned int>("FastBDT_number_individual_nCuts", 0);
80 m_individual_nCuts.resize(numberOfIndividualNCuts);
81 for (unsigned int i = 0; i < numberOfIndividualNCuts; ++i) {
82 m_individual_nCuts[i] = pt.get<unsigned int>(std::string("FastBDT_individual_nCuts") + std::to_string(i));
83 }
84
85 m_purityTransformation = pt.get<bool>("FastBDT_purityTransformation");
86 unsigned int numberOfIndividualPurityTransformation = pt.get<unsigned int>("FastBDT_number_individualPurityTransformation", 0);
87 m_individualPurityTransformation.resize(numberOfIndividualPurityTransformation);
88 for (unsigned int i = 0; i < numberOfIndividualPurityTransformation; ++i) {
89 m_individualPurityTransformation[i] = pt.get<bool>(std::string("FastBDT_individualPurityTransformation") + std::to_string(i));
90 }
91
92 } else {
93 m_flatnessLoss = -1.0;
94 m_sPlot = false;
95 }
96#endif
97 }
98
99 void FastBDTOptions::save(boost::property_tree::ptree& pt) const
100 {
101#if FastBDT_VERSION_MAJOR >= 5
102 pt.put("FastBDT_version", 2);
103#else
104 pt.put("FastBDT_version", 1);
105#endif
106 pt.put("FastBDT_nTrees", m_nTrees);
107 pt.put("FastBDT_nCuts", m_nCuts);
108 pt.put("FastBDT_nLevels", m_nLevels);
109 pt.put("FastBDT_shrinkage", m_shrinkage);
110 pt.put("FastBDT_randRatio", m_randRatio);
111#if FastBDT_VERSION_MAJOR >= 5
112 pt.put("FastBDT_flatnessLoss", m_flatnessLoss);
113 pt.put("FastBDT_sPlot", m_sPlot);
114 pt.put("FastBDT_number_individual_nCuts", m_individual_nCuts.size());
115 for (unsigned int i = 0; i < m_individual_nCuts.size(); ++i) {
116 pt.put(std::string("FastBDT_individual_nCuts") + std::to_string(i), m_individual_nCuts[i]);
117 }
118 pt.put("FastBDT_purityTransformation", m_purityTransformation);
119 pt.put("FastBDT_number_individualPurityTransformation", m_individualPurityTransformation.size());
120 for (unsigned int i = 0; i < m_individualPurityTransformation.size(); ++i) {
121 pt.put(std::string("FastBDT_individualPurityTransformation") + std::to_string(i), m_individualPurityTransformation[i]);
122 }
123#endif
124 }
125
126 po::options_description FastBDTOptions::getDescription()
127 {
128 po::options_description description("FastBDT options");
129 description.add_options()
130 ("nTrees", po::value<unsigned int>(&m_nTrees), "Number of trees in the forest. Reasonable values are between 10 and 1000")
131 ("nLevels", po::value<unsigned int>(&m_nLevels)->notifier(check_bounds<unsigned int>(0, 20, "nLevels")),
132 "Depth d of trees. The last layer of the tree will contain 2^d bins. Maximum is 20. Reasonable values are 2 and 6.")
133 ("shrinkage", po::value<double>(&m_shrinkage)->notifier(check_bounds<double>(0.0, 1.0, "shrinkage")),
134 "Shrinkage of the boosting algorithm. Reasonable values are between 0.01 and 1.0.")
135 ("nCutLevels", po::value<unsigned int>(&m_nCuts)->notifier(check_bounds<unsigned int>(0, 20, "nCutLevels")),
136 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12.")
137#if FastBDT_VERSION_MAJOR >= 5
138 ("individualNCutLevels", po::value<std::vector<unsigned int>>(&m_individual_nCuts)->multitoken()->notifier(
139 check_bounds_vector<unsigned int>(0, 20, "individualNCutLevels")),
140 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12. One value per feature (including spectators) should be provided, if parameter is not set the global value specified by nCutLevels is used for all features.")
141 ("sPlot", po::value<bool>(&m_sPlot),
142 "Since in sPlot each event enters twice, this option modifies the sampling algorithm so that the matching signal and background events are selected together.")
143 ("flatnessLoss", po::value<double>(&m_flatnessLoss),
144 "Activate Flatness Loss, all spectator variables are assumed to be variables in which the signal and background efficiency should be flat. negative values deactivates flatness loss.")
145 ("purityTransformation", po::value<bool>(&m_purityTransformation),
146 "Activates purity transformation on all features: Add the purity transformed of all features in addition to the training. This will double the number of features and slow down the inference considerably")
147 ("individualPurityTransformation", po::value<std::vector<bool>>(&m_individualPurityTransformation)->multitoken(),
148 "Activates purity transformation for each feature: Vector of boolean values which decide if the purity transformed of the feature should be added in addition to this training.")
149#endif
150 ("randRatio", po::value<double>(&m_randRatio)->notifier(check_bounds<double>(0.0, 1.0001, "randRatio")),
151 "Fraction of the data sampled each training iteration. Reasonable values are between 0.1 and 1.0.");
152 return description;
153 }
154
155
157 const FastBDTOptions& specific_options) : Teacher(general_options),
158 m_specific_options(specific_options) { }
159
161 {
162
163 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
164#if FastBDT_VERSION_MAJOR >= 4
165 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
166#else
167 // Deactivate support for spectators below version 4!
168 unsigned int numberOfSpectators = 0;
169#endif
170
171 // FastBDT Version 4 has a simplified interface with a sklearn style Classifier
172#if FastBDT_VERSION_MAJOR >= 5
173 if (m_specific_options.m_individual_nCuts.size() != 0
174 and m_specific_options.m_individual_nCuts.size() != numberOfFeatures + numberOfSpectators) {
175 B2ERROR("You provided individual nCut values for each feature and spectator, but the total number of provided cuts is not same as as the total number of features and spectators.");
176 }
177
178 std::vector<bool> individualPurityTransformation = m_specific_options.m_individualPurityTransformation;
179 if (m_specific_options.m_purityTransformation) {
180 if (individualPurityTransformation.size() == 0) {
181 for (unsigned int i = 0; i < numberOfFeatures; ++i) {
182 individualPurityTransformation.push_back(true);
183 }
184 }
185 }
186
187 std::vector<unsigned int> individual_nCuts = m_specific_options.m_individual_nCuts;
188 if (individual_nCuts.size() == 0) {
189 for (unsigned int i = 0; i < numberOfFeatures + numberOfSpectators; ++i) {
190 individual_nCuts.push_back(m_specific_options.m_nCuts);
191 }
192 }
193
194 FastBDT::Classifier classifier(m_specific_options.m_nTrees, m_specific_options.m_nLevels, individual_nCuts,
196 m_specific_options.m_sPlot, m_specific_options.m_flatnessLoss, individualPurityTransformation,
197 numberOfSpectators, true);
198
199 std::vector<std::vector<float>> X(numberOfFeatures + numberOfSpectators);
200 const auto& y = training_data.getSignals();
201 if (not isValidSignal(y)) {
202 B2FATAL("The training data is not valid. It only contains one class instead of two.");
203 }
204 const auto& w = training_data.getWeights();
205 for (unsigned int i = 0; i < numberOfFeatures; ++i) {
206 X[i] = training_data.getFeature(i);
207 }
208 for (unsigned int i = 0; i < numberOfSpectators; ++i) {
209 X[i + numberOfFeatures] = training_data.getSpectator(i);
210 }
211 classifier.fit(X, y, w);
212#else
213 const auto& y = training_data.getSignals();
214 if (not isValidSignal(y)) {
215 B2FATAL("The training data is not valid. It only contains one class instead of two.");
216 }
217 std::vector<FastBDT::FeatureBinning<float>> featureBinnings;
218 std::vector<unsigned int> nBinningLevels;
219 for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
220 auto feature = training_data.getFeature(iFeature);
221
222 unsigned int nCuts = m_specific_options.m_nCuts;
223#if FastBDT_VERSION_MAJOR >= 3
224 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature));
225#else
226 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature.begin(), feature.end()));
227#endif
228 nBinningLevels.push_back(nCuts);
229 }
230
231 for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
232 auto feature = training_data.getSpectator(iSpectator);
233
234 unsigned int nCuts = m_specific_options.m_nCuts;
235#if FastBDT_VERSION_MAJOR >= 3
236 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature));
237#else
238 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature.begin(), feature.end()));
239#endif
240 nBinningLevels.push_back(nCuts);
241 }
242
243 unsigned int numberOfEvents = training_data.getNumberOfEvents();
244 if (numberOfEvents > 5e+6) {
245 B2WARNING("Number of events for training exceeds 5 million. FastBDT performance starts getting worse when the number reaches O(10^7).");
246 }
247
248#if FastBDT_VERSION_MAJOR >= 4
249 FastBDT::EventSample eventSample(numberOfEvents, numberOfFeatures, numberOfSpectators, nBinningLevels);
250#else
251 FastBDT::EventSample eventSample(numberOfEvents, numberOfFeatures, nBinningLevels);
252#endif
253 std::vector<unsigned int> bins(numberOfFeatures + numberOfSpectators);
254 for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
255 training_data.loadEvent(iEvent);
256 for (unsigned int iFeature = 0; iFeature < numberOfFeatures + numberOfSpectators; ++iFeature) {
257 bins[iFeature] = featureBinnings[iFeature].ValueToBin(training_data.m_input[iFeature]);
258 }
259 for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
260 bins[iSpectator + numberOfFeatures] = featureBinnings[iSpectator + numberOfFeatures].ValueToBin(
261 training_data.m_spectators[iSpectator]);
262 }
263 eventSample.AddEvent(bins, training_data.m_weight, training_data.m_isSignal);
264 }
265
268#if FastBDT_VERSION_MAJOR >= 3
269 FastBDT::Forest<float> forest(dt.GetShrinkage(), dt.GetF0(), true);
270#else
271 FastBDT::Forest forest(dt.GetShrinkage(), dt.GetF0());
272#endif
273 for (auto t : dt.GetForest()) {
274#if FastBDT_VERSION_MAJOR >= 3
275 auto tree = FastBDT::removeFeatureBinningTransformationFromTree(t, featureBinnings);
276 forest.AddTree(tree);
277#else
278 forest.AddTree(t);
279#endif
280 }
281
282#endif
283
284
285 Weightfile weightfile;
286 std::string custom_weightfile = weightfile.generateFileName();
287 std::fstream file(custom_weightfile, std::ios_base::out | std::ios_base::trunc);
288
289#if FastBDT_VERSION_MAJOR >= 5
290 file << classifier << std::endl;
291#else
292#if FastBDT_VERSION_MAJOR >= 3
293 file << forest << std::endl;
294#else
295 file << featureBinnings << std::endl;
296 file << forest << std::endl;
297#endif
298#endif
299 file.close();
300
301 weightfile.addOptions(m_general_options);
302 weightfile.addOptions(m_specific_options);
303 weightfile.addFile("FastBDT_Weightfile", custom_weightfile);
304 weightfile.addSignalFraction(training_data.getSignalFraction());
305
306 std::map<std::string, float> importance;
307#if FastBDT_VERSION_MAJOR >= 5
308 for (auto& pair : classifier.GetVariableRanking()) {
309 importance[m_general_options.m_variables[pair.first]] = pair.second;
310 }
311#else
312 for (auto& pair : forest.GetVariableRanking()) {
313 importance[m_general_options.m_variables[pair.first]] = pair.second;
314 }
315#endif
316 weightfile.addFeatureImportance(importance);
317
318 return weightfile;
319
320 }
321
323 {
324
325 std::string custom_weightfile = weightfile.generateFileName();
326 weightfile.getFile("FastBDT_Weightfile", custom_weightfile);
327 std::fstream file(custom_weightfile, std::ios_base::in);
328
329 int version = weightfile.getElement<int>("FastBDT_version", 0);
330 B2DEBUG(100, "FastBDT Weightfile Version " << version);
331 if (version < 2) {
332#if FastBDT_VERSION_MAJOR >= 3
333 std::stringstream s;
334 {
335 std::string t;
336 std::fstream file2(custom_weightfile, std::ios_base::in);
337 getline(file2, t);
338 s << t;
339 }
340 int dummy;
341 // Try to read to integers, if this is successful we have a old weightfile with a Feature Binning before the Tree.
342 if (!(s >> dummy >> dummy)) {
343 B2DEBUG(100, "FastBDT: I read a new weightfile of FastBDT using the new FastBDT version 3. Everything fine!");
344 // New format since version 3
345 m_expert_forest = FastBDT::readForestFromStream<float>(file);
346 } else {
347 B2INFO("FastBDT: I read an old weightfile of FastBDT using the new FastBDT version 3."
348 "I will convert your FastBDT on-the-fly to the new version."
349 "Retrain the classifier to get rid of this message");
350 // Old format before version 3
351 // We read in first the feature binnings and than rewrite the tree
352 std::vector<FastBDT::FeatureBinning<float>> feature_binnings;
353 file >> feature_binnings;
354 double F0;
355 file >> F0;
356 double shrinkage;
357 file >> shrinkage;
358 // This parameter was not available in the old version
359 bool transform2probability = true;
360 FastBDT::Forest<unsigned int> temp_forest(shrinkage, F0, transform2probability);
361 unsigned int size;
362 file >> size;
363 for (unsigned int i = 0; i < size; ++i) {
364 temp_forest.AddTree(FastBDT::readTreeFromStream<unsigned int>(file));
365 }
366
367 FastBDT::Forest<float> cleaned_forest(temp_forest.GetShrinkage(), temp_forest.GetF0(), temp_forest.GetTransform2Probability());
368 for (auto& tree : temp_forest.GetForest()) {
369 cleaned_forest.AddTree(FastBDT::removeFeatureBinningTransformationFromTree(tree, feature_binnings));
370 }
371 m_expert_forest = cleaned_forest;
372 }
373#else
374 B2INFO("FastBDT: I read an old weightfile of FastBDT using the old FastBDT version."
375 "I try to fix the weightfile first to avoid problems due to NaN and inf values."
376 "Consider to switch to the newer version of FastBDT (newer externals)");
377 // Check for nan or inf in file and replace with 0
378 std::stringstream s;
379 std::string t;
380 while (getline(file, t)) {
381 size_t f = 0;
382
383 while ((f = t.find("inf", f)) != std::string::npos) {
384 t.replace(f, std::string("inf").length(), std::string("0.0"));
385 f += std::string("0.0").length();
386 B2WARNING("Found infinity in FastBDT weightfile, I replace it with 0 to prevent horrible crashes, this is fixed in the newer version");
387 }
388 f = 0;
389 while ((f = t.find("nan", f)) != std::string::npos) {
390 t.replace(f, std::string("nan").length(), std::string("0.0"));
391 f += std::string("0.0").length();
392 B2WARNING("Found nan in FastBDT weightfile, I replace it with 0 to prevent horrible crashes, this is fixed in the newer version");
393 }
394 s << t + '\n';
395 }
397 m_expert_forest = FastBDT::readForestFromStream(s);
398#endif
399 }
400#if FastBDT_VERSION_MAJOR >= 5
401 else {
402 m_use_simplified_interface = true;
403 m_classifier = FastBDT::Classifier(file);
404 }
405#else
406 else {
407 B2ERROR("Unknown Version 2 of Weightfile, please use a more recent FastBDT version");
408 }
409#endif
410 file.close();
411
412 weightfile.getOptions(m_specific_options);
413 }
414
415 std::vector<float> FastBDTExpert::apply(Dataset& test_data) const
416 {
417
418 std::vector<float> probabilities(test_data.getNumberOfEvents());
419 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
420 test_data.loadEvent(iEvent);
421#if FastBDT_VERSION_MAJOR >= 3
422#if FastBDT_VERSION_MAJOR >= 5
423 if (m_use_simplified_interface)
424 probabilities[iEvent] = m_classifier.predict(test_data.m_input);
425 else
426 probabilities[iEvent] = m_expert_forest.Analyse(test_data.m_input);
427#else
428 probabilities[iEvent] = m_expert_forest.Analyse(test_data.m_input);
429#endif
430#else
431 std::vector<unsigned int> bins(m_expert_feature_binning.size());
432 for (unsigned int iFeature = 0; iFeature < m_expert_feature_binning.size(); ++iFeature) {
433 bins[iFeature] = m_expert_feature_binning[iFeature].ValueToBin(test_data.m_input[iFeature]);
434 }
435 probabilities[iEvent] = m_expert_forest.Analyse(bins);
436#endif
437 }
438
439 return probabilities;
440
441 }
442
443 }
445}
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
std::vector< FastBDT::FeatureBinning< float > > m_expert_feature_binning
Forest feature binning.
Definition: FastBDT.h:147
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FastBDT.cc:415
FastBDT::Forest m_expert_forest
Forest Expert.
Definition: FastBDT.h:146
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FastBDT.cc:322
FastBDTOptions m_specific_options
Method specific options.
Definition: FastBDT.h:138
Options for the FastBDT MVA method.
Definition: FastBDT.h:53
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: FastBDT.cc:126
double m_randRatio
Fraction of data to use in the stochastic training.
Definition: FastBDT.h:82
double m_shrinkage
Shrinkage during the boosting step.
Definition: FastBDT.h:81
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: FastBDT.cc:53
unsigned int m_nLevels
Depth of tree.
Definition: FastBDT.h:80
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: FastBDT.cc:99
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
Definition: FastBDT.h:79
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:78
FastBDTTeacher(const GeneralOptions &general_options, const FastBDTOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: FastBDT.cc:156
FastBDTOptions m_specific_options
Method specific options.
Definition: FastBDT.h:115
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: FastBDT.cc:160
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
Abstract base class of all Teachers Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:151
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:72
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24