11 #include <mva/methods/FastBDT.h>
13 #include <framework/logging/Logger.h>
// The helper declared below is guarded by a check that requires BOTH the major
// version to be <= 3 and the minor version to be <= 2.
// NOTE(review): this is not a lexicographic "version <= 3.2" comparison -- a
// release such as 2.5 would fail the minor test. Confirm which releases are meant.
18 #if FastBDT_VERSION_MAJOR <= 3 && FastBDT_VERSION_MINOR <= 2
// Predicate over two floats returning bool. Its body is not visible in this
// excerpt; going by the name it presumably compares values in a way that
// also handles NaN -- TODO confirm.
21 bool compareIncludingNaN(
float i,
float j)
/**
 * Checks whether a vector of class flags contains two different classes.
 *
 * The training code rejects its input with "It only contains one class
 * instead of two" when this function returns false.
 *
 * @param Signals per-event class flags of the training sample
 * @return true if at least one entry differs from the first entry;
 *         false if all entries are equal or the vector is empty
 */
bool isValidSignal(const std::vector<bool>& Signals)
{
  // front() on an empty vector is undefined behaviour, and an empty sample
  // cannot contain two classes anyway.
  if (Signals.empty()) {
    return false;
  }
  const bool first = Signals.front();
  for (const bool value : Signals) {
    if (value != first) {
      return true;
    }
  }
  return false;
}
// Restores the FastBDT-specific options from `pt`.
// Most keys are fetched with no second argument; the two "number_*" counters
// pass 0 as second argument (presumably the fallback used when the key is
// missing -- TODO confirm against the type of `pt`).
57 int version = pt.get<
int>(
"FastBDT_version");
// Version gate when built against FastBDT >= 5: only versions 1 and 2 pass,
// anything else is logged with B2ERROR and raised as std::runtime_error.
58 #if FastBDT_VERSION_MAJOR >= 5
59 if (version != 1 and version != 2) {
60 B2ERROR(
"Unkown weightfile version " << std::to_string(version));
61 throw std::runtime_error(
"Unkown weightfile version " + std::to_string(version));
// Same log-and-throw pair on a second path; the condition guarding it is not
// visible in this excerpt.
65 B2ERROR(
"Unkown weightfile version " << std::to_string(version));
66 throw std::runtime_error(
"Unkown weightfile version " + std::to_string(version));
// Core hyper-parameters.
69 m_nTrees = pt.get<
int>(
"FastBDT_nTrees");
70 m_nCuts = pt.get<
int>(
"FastBDT_nCuts");
71 m_nLevels = pt.get<
int>(
"FastBDT_nLevels");
// Options that are only read when built against FastBDT >= 5.
75 #if FastBDT_VERSION_MAJOR >= 5
78 m_flatnessLoss = pt.get<
double>(
"FastBDT_flatnessLoss");
79 m_sPlot = pt.get<
bool>(
"FastBDT_sPlot");
// Per-feature nCuts: a counter key followed by one key per entry, named
// "FastBDT_individual_nCuts" + index.
81 unsigned int numberOfIndividualNCuts = pt.get<
unsigned int>(
"FastBDT_number_individual_nCuts", 0);
82 m_individual_nCuts.resize(numberOfIndividualNCuts);
83 for (
unsigned int i = 0; i < numberOfIndividualNCuts; ++i) {
84 m_individual_nCuts[i] = pt.get<
unsigned int>(std::string(
"FastBDT_individual_nCuts") + std::to_string(i));
// Purity transformation: a global flag plus per-feature flags stored with the
// same counter-plus-indexed-keys layout as above.
87 m_purityTransformation = pt.get<
bool>(
"FastBDT_purityTransformation");
88 unsigned int numberOfIndividualPurityTransformation = pt.get<
unsigned int>(
"FastBDT_number_individualPurityTransformation", 0);
89 m_individualPurityTransformation.resize(numberOfIndividualPurityTransformation);
90 for (
unsigned int i = 0; i < numberOfIndividualPurityTransformation; ++i) {
91 m_individualPurityTransformation[i] = pt.get<
bool>(std::string(
"FastBDT_individualPurityTransformation") + std::to_string(i));
// Sets the flatness loss to a negative value; the option help text elsewhere in
// this file states that negative values deactivate flatness loss. The condition
// guarding this assignment is not visible in this excerpt.
95 m_flatnessLoss = -1.0;
// Stores the FastBDT-specific options into `pt`, using the same key names the
// loading code reads back.
// The weightfile version written is 2 when built against FastBDT >= 5; the
// value 1 is written on another path whose guard is not visible here.
103 #if FastBDT_VERSION_MAJOR >= 5
104 pt.put(
"FastBDT_version", 2);
106 pt.put(
"FastBDT_version", 1);
109 pt.put(
"FastBDT_nCuts",
m_nCuts);
// Options that are only written when built against FastBDT >= 5.
113 #if FastBDT_VERSION_MAJOR >= 5
114 pt.put(
"FastBDT_flatnessLoss", m_flatnessLoss);
115 pt.put(
"FastBDT_sPlot", m_sPlot);
// Vector-valued options are flattened: first the element count, then one key
// per element named <prefix><index>.
116 pt.put(
"FastBDT_number_individual_nCuts", m_individual_nCuts.size());
117 for (
unsigned int i = 0; i < m_individual_nCuts.size(); ++i) {
118 pt.put(std::string(
"FastBDT_individual_nCuts") + std::to_string(i), m_individual_nCuts[i]);
120 pt.put(
"FastBDT_purityTransformation", m_purityTransformation);
121 pt.put(
"FastBDT_number_individualPurityTransformation", m_individualPurityTransformation.size());
122 for (
unsigned int i = 0; i < m_individualPurityTransformation.size(); ++i) {
123 pt.put(std::string(
"FastBDT_individualPurityTransformation") + std::to_string(i), m_individualPurityTransformation[i]);
// Builds the option description titled "FastBDT options".
// Every option stores its parsed value directly into the corresponding member
// through the pointer handed to po::value<T>(); several options additionally
// attach a check_bounds / check_bounds_vector notifier with the limits shown.
130 po::options_description description(
"FastBDT options");
131 description.add_options()
// Options available for every FastBDT version.
132 (
"nTrees", po::value<unsigned int>(&
m_nTrees),
"Number of trees in the forest. Reasonable values are between 10 and 1000")
133 (
"nLevels", po::value<unsigned int>(&
m_nLevels)->notifier(check_bounds<unsigned int>(0, 20,
"nLevels")),
134 "Depth d of trees. The last layer of the tree will contain 2^d bins. Maximum is 20. Resonable values are 2 and 6.")
135 (
"shrinkage", po::value<double>(&
m_shrinkage)->notifier(check_bounds<double>(0.0, 1.0,
"shrinkage")),
136 "Shrinkage of the boosting algorithm. Reasonable values are between 0.01 and 1.0.")
137 (
"nCutLevels", po::value<unsigned int>(&
m_nCuts)->notifier(check_bounds<unsigned int>(0, 20,
"nCutLevels")),
138 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12.")
// Options that are only registered when built against FastBDT >= 5. The two
// vector-valued ones use multitoken().
139 #
if FastBDT_VERSION_MAJOR >= 5
140 (
"individualNCutLevels", po::value<std::vector<unsigned int>>(&m_individual_nCuts)->multitoken()->notifier(
141 check_bounds_vector<unsigned int>(0, 20,
"individualNCutLevels")),
142 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12. One value per feature (including spectators) should be provided, if parameter is not set the global value specified by nCutLevels is used for all features.")
143 (
"sPlot", po::value<bool>(&m_sPlot),
144 "Since in sPlot each event enters twice, this option modifies the sampling algorithm so that the matching signal and background events are selected together.")
145 (
"flatnessLoss", po::value<double>(&m_flatnessLoss),
146 "Activate Flatness Loss, all spectator variables are assumed to be variables in which the signal and background efficiency should be flat. negative values deactivates flatness loss.")
147 (
"purityTransformation", po::value<bool>(&m_purityTransformation),
148 "Activates purity transformation on all features: Add the purity transformed of all features in addition to the training. This will double the number of features and slow down the inference considerably")
149 (
"individualPurityTransformation", po::value<std::vector<bool>>(&m_individualPurityTransformation)->multitoken(),
150 "Activates purity transformation for each feature: Vector of boolean values which decide if the purity transformed of the feature should be added in addition to this training.")
// randRatio is bounds-checked against an upper limit of 1.0001, not 1.0.
152 (
"randRatio", po::value<double>(&
m_randRatio)->notifier(check_bounds<double>(0.0, 1.0001,
"randRatio")),
153 "Fraction of the data sampled each training iteration. Reasonable values are between 0.1 and 1.0.");
160 m_specific_options(specific_options) { }
// Training, first part: determine the sizes, validate the per-feature options
// and prepare the inputs. Many lines of the enclosing function are not visible
// in this excerpt, so only the visible statements are described.
165 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
// Spectators are only queried when built against FastBDT >= 4; the alternative
// declaration below fixes their number to 0.
166 #if FastBDT_VERSION_MAJOR >= 4
167 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
170 unsigned int numberOfSpectators = 0;
// FastBDT >= 5: if the number of individual nCut values does not equal
// features + spectators an error is reported (first half of the condition is
// not visible here).
174 #if FastBDT_VERSION_MAJOR >= 5
176 and
m_specific_options.m_individual_nCuts.size() != numberOfFeatures + numberOfSpectators) {
177 B2ERROR(
"You provided individual nCut values for each feature and spectator, but the total number of provided cuts is not same as as the total number of features and spectators.");
// Work on a local copy of the per-feature purity flags; when none were given,
// one `true` entry per feature is appended.
180 std::vector<bool> individualPurityTransformation =
m_specific_options.m_individualPurityTransformation;
182 if (individualPurityTransformation.size() == 0) {
183 for (
unsigned int i = 0; i < numberOfFeatures; ++i) {
184 individualPurityTransformation.push_back(
true);
// When no individual nCuts were given, iterate over all features and
// spectators (loop body not visible; presumably fills in the global value --
// TODO confirm).
190 if (individual_nCuts.size() == 0) {
191 for (
unsigned int i = 0; i < numberOfFeatures + numberOfSpectators; ++i) {
// Tail of a call whose beginning is not visible; its last two arguments are
// the spectator count and `true`.
199 numberOfSpectators,
true);
// Path using `classifier.fit`: X holds one inner vector per feature, with the
// spectators appended after the features; y are the signal flags and w the
// weights. Training is refused via B2FATAL when isValidSignal(y) is false.
201 std::vector<std::vector<float>> X(numberOfFeatures + numberOfSpectators);
202 const auto& y = training_data.getSignals();
203 if (not isValidSignal(y)) {
204 B2FATAL(
"The training data is not valid. It only contains one class instead of two.");
206 const auto& w = training_data.getWeights();
207 for (
unsigned int i = 0; i < numberOfFeatures; ++i) {
208 X[i] = training_data.getFeature(i);
210 for (
unsigned int i = 0; i < numberOfSpectators; ++i) {
211 X[i + numberOfFeatures] = training_data.getSpectator(i);
213 classifier.fit(X, y, w);
// Second path (y is declared again, so this lives in a different scope): the
// same validity check, followed by explicit feature binning.
215 const auto& y = training_data.getSignals();
216 if (not isValidSignal(y)) {
217 B2FATAL(
"The training data is not valid. It only contains one class instead of two.");
// One FeatureBinning and one binning level per feature, then per spectator,
// appended in that order. With FastBDT >= 3 the binning is constructed from
// the whole vector, otherwise from an iterator pair. `nCuts` is defined on
// lines not visible here.
219 std::vector<FastBDT::FeatureBinning<float>> featureBinnings;
220 std::vector<unsigned int> nBinningLevels;
221 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
222 auto feature = training_data.getFeature(iFeature);
225 #if FastBDT_VERSION_MAJOR >= 3
226 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature));
228 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature.begin(), feature.end()));
230 nBinningLevels.push_back(nCuts);
233 for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
234 auto feature = training_data.getSpectator(iSpectator);
237 #if FastBDT_VERSION_MAJOR >= 3
238 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature));
240 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature.begin(), feature.end()));
242 nBinningLevels.push_back(nCuts);
// Event sample sized for all events; the spectator count is only passed to the
// constructor when built against FastBDT >= 4.
245 unsigned int numberOfEvents = training_data.getNumberOfEvents();
247 #if FastBDT_VERSION_MAJOR >= 4
248 FastBDT::EventSample eventSample(numberOfEvents, numberOfFeatures, numberOfSpectators, nBinningLevels);
250 FastBDT::EventSample eventSample(numberOfEvents, numberOfFeatures, nBinningLevels);
// Scratch buffer reused for every event: feature bins first, spectator bins after.
252 std::vector<unsigned int> bins(numberOfFeatures + numberOfSpectators);
253 for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
254 training_data.loadEvent(iEvent);
255 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures + numberOfSpectators; ++iFeature) {
256 bins[iFeature] = featureBinnings[iFeature].ValueToBin(training_data.m_input[iFeature]);
258 for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
259 bins[iSpectator + numberOfFeatures] = featureBinnings[iSpectator + numberOfFeatures].ValueToBin(
260 training_data.m_spectators[iSpectator]);
262 eventSample.AddEvent(bins, training_data.m_weight, training_data.m_isSignal);
// Training, last part: assemble the forest, write the custom weightfile and
// collect the variable ranking. `dt`, `custom_weightfile` and `classifier` are
// defined on lines not visible in this excerpt.
// With FastBDT >= 3 a Forest<float> is created with `true` as third constructor
// argument; the other constructor call passes only shrinkage and F0.
267 #if FastBDT_VERSION_MAJOR >= 3
268 FastBDT::Forest<float> forest(dt.GetShrinkage(), dt.GetF0(),
true);
270 FastBDT::Forest forest(dt.GetShrinkage(), dt.GetF0());
// Copy the trees of `dt` into the forest; with FastBDT >= 3 each tree first has
// the feature-binning transformation removed.
272 for (
auto t : dt.GetForest()) {
273 #if FastBDT_VERSION_MAJOR >= 3
274 auto tree = FastBDT::removeFeatureBinningTransformationFromTree(t, featureBinnings);
275 forest.AddTree(tree);
// Open the custom weightfile for writing, discarding previous content.
286 std::fstream file(custom_weightfile, std::ios_base::out | std::ios_base::trunc);
// What is streamed depends on the FastBDT version: the classifier (>= 5), the
// forest (>= 3), or -- on the remaining path -- the feature binnings followed
// by the forest.
288 #if FastBDT_VERSION_MAJOR >= 5
289 file << classifier << std::endl;
291 #if FastBDT_VERSION_MAJOR >= 3
292 file << forest << std::endl;
294 file << featureBinnings << std::endl;
295 file << forest << std::endl;
// Attach the written file to the weightfile under the key "FastBDT_Weightfile".
302 weightfile.
addFile(
"FastBDT_Weightfile", custom_weightfile);
// Variable importance, taken from the classifier's ranking with FastBDT >= 5
// and from the forest's ranking otherwise (loop bodies not visible here).
305 std::map<std::string, float> importance;
306 #if FastBDT_VERSION_MAJOR >= 5
307 for (
auto& pair : classifier.GetVariableRanking()) {
311 for (
auto& pair : forest.GetVariableRanking()) {
// Loading: fetch the stored file under the key "FastBDT_Weightfile" into
// `custom_weightfile`, open it for reading and restore the model according to
// the stored version and the FastBDT version this code is built against.
// Several declarations (e.g. `s`, `dummy`, `shrinkage`, `F0`, `size`, `t`, `f`)
// are on lines not visible in this excerpt.
325 weightfile.
getFile(
"FastBDT_Weightfile", custom_weightfile);
326 std::fstream file(custom_weightfile, std::ios_base::in);
// Stored weightfile version; 0 is passed as second argument (presumably the
// default for files without this element -- TODO confirm).
328 int version = weightfile.
getElement<
int>(
"FastBDT_version", 0);
329 B2DEBUG(100,
"FastBDT Weightfile Version " << version);
// FastBDT >= 3: a second stream on the same file is opened to probe the format.
331 #if FastBDT_VERSION_MAJOR >= 3
335 std::fstream file2(custom_weightfile, std::ios_base::in);
// If two values cannot be extracted from `s`, the file is treated as the new
// format and nothing needs converting.
341 if (!(s >> dummy >> dummy)) {
342 B2DEBUG(100,
"FastBDT: I read a new weightfile of FastBDT using the new FastBDT version 3. Everythings fine!");
// Otherwise the file is treated as an old-format weightfile and converted.
346 B2INFO(
"FastBDT: I read an old weightfile of FastBDT using the new FastBDT version 3."
347 "I will convert your FastBDT on-the-fly to the new version."
348 "Retrain the classifier to get rid of this message");
// Old layout: the feature binnings are read first, then the trees are read into
// a temporary Forest<unsigned int>.
351 std::vector<FastBDT::FeatureBinning<float>> feature_binnings;
352 file >> feature_binnings;
358 bool transform2probability =
true;
359 FastBDT::Forest<unsigned int> temp_forest(shrinkage, F0, transform2probability);
362 for (
unsigned int i = 0; i < size; ++i) {
363 temp_forest.AddTree(FastBDT::readTreeFromStream<unsigned int>(file));
// Build a Forest<float> with the same shrinkage, F0 and probability setting and
// add every tree with the feature-binning transformation removed.
366 FastBDT::Forest<float> cleaned_forest(temp_forest.GetShrinkage(), temp_forest.GetF0(), temp_forest.GetTransform2Probability());
367 for (
auto& tree : temp_forest.GetForest()) {
368 cleaned_forest.AddTree(FastBDT::removeFeatureBinningTransformationFromTree(tree, feature_binnings));
// Path for an old FastBDT library: the file is processed line by line and every
// textual "inf" / "nan" is replaced by "0.0", with a warning per replacement.
373 B2INFO(
"FastBDT: I read an old weightfile of FastBDT using the old FastBDT version."
374 "I try to fix the weightfile first to avoid problems due to NaN and inf values."
375 "Consider to switch to the newer version of FastBDT (newer externals)");
379 while (getline(file, t)) {
// After each replacement the search position is advanced past the inserted
// "0.0" before searching again.
382 while ((f = t.find(
"inf", f)) != std::string::npos) {
383 t.replace(f, std::string(
"inf").length(), std::string(
"0.0"));
384 f += std::string(
"0.0").length();
385 B2WARNING(
"Found infinity in FastBDT weightfile, I replace it with 0 to prevent horrible crashes, this is fixed in the newer version");
388 while ((f = t.find(
"nan", f)) != std::string::npos) {
389 t.replace(f, std::string(
"nan").length(), std::string(
"0.0"));
390 f += std::string(
"0.0").length();
391 B2WARNING(
"Found nan in FastBDT weightfile, I replace it with 0 to prevent horrible crashes, this is fixed in the newer version");
// FastBDT >= 5: switch to the simplified interface and construct the classifier
// directly from the file stream.
399 #if FastBDT_VERSION_MAJOR >= 5
401 m_use_simplified_interface =
true;
402 m_classifier = FastBDT::Classifier(file);
// Reported on a path whose guard is not visible here; going by the message it
// covers version-2 weightfiles that this build cannot read.
406 B2ERROR(
"Unknown Version 2 of Weightfile, please use a more recent FastBDT version");
// Applies the loaded model to every event of `test_data` and returns one value
// per event, in event order.
417 std::vector<float> probabilities(test_data.getNumberOfEvents());
418 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
// loadEvent is called before m_input is read for this event.
419 test_data.loadEvent(iEvent);
420 #if FastBDT_VERSION_MAJOR >= 3
421 #if FastBDT_VERSION_MAJOR >= 5
// With the simplified interface the classifier's predict() is used; the
// alternative branches are on lines not visible in this excerpt.
422 if (m_use_simplified_interface)
423 probabilities[iEvent] = m_classifier.predict(test_data.m_input);
438 return probabilities;