9 #include <mva/methods/FastBDT.h>
11 #include <framework/logging/Logger.h>
16 #if FastBDT_VERSION_MAJOR <= 3 && FastBDT_VERSION_MINOR <= 2
19 bool compareIncludingNaN(
float i,
float j)
43 bool isValidSignal(
const std::vector<bool>& Signals)
45 const auto first = Signals.front();
46 for (
const auto& value : Signals) {
55 int version = pt.get<
int>(
"FastBDT_version");
56 #if FastBDT_VERSION_MAJOR >= 5
57 if (version != 1 and version != 2) {
58 B2ERROR(
"Unknown weightfile version " << std::to_string(version));
59 throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
63 B2ERROR(
"Unknown weightfile version " << std::to_string(version));
64 throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
67 m_nTrees = pt.get<
int>(
"FastBDT_nTrees");
68 m_nCuts = pt.get<
int>(
"FastBDT_nCuts");
69 m_nLevels = pt.get<
int>(
"FastBDT_nLevels");
73 #if FastBDT_VERSION_MAJOR >= 5
76 m_flatnessLoss = pt.get<
double>(
"FastBDT_flatnessLoss");
77 m_sPlot = pt.get<
bool>(
"FastBDT_sPlot");
79 unsigned int numberOfIndividualNCuts = pt.get<
unsigned int>(
"FastBDT_number_individual_nCuts", 0);
80 m_individual_nCuts.resize(numberOfIndividualNCuts);
81 for (
unsigned int i = 0; i < numberOfIndividualNCuts; ++i) {
82 m_individual_nCuts[i] = pt.get<
unsigned int>(std::string(
"FastBDT_individual_nCuts") + std::to_string(i));
85 m_purityTransformation = pt.get<
bool>(
"FastBDT_purityTransformation");
86 unsigned int numberOfIndividualPurityTransformation = pt.get<
unsigned int>(
"FastBDT_number_individualPurityTransformation", 0);
87 m_individualPurityTransformation.resize(numberOfIndividualPurityTransformation);
88 for (
unsigned int i = 0; i < numberOfIndividualPurityTransformation; ++i) {
89 m_individualPurityTransformation[i] = pt.get<
bool>(std::string(
"FastBDT_individualPurityTransformation") + std::to_string(i));
93 m_flatnessLoss = -1.0;
101 #if FastBDT_VERSION_MAJOR >= 5
102 pt.put(
"FastBDT_version", 2);
104 pt.put(
"FastBDT_version", 1);
107 pt.put(
"FastBDT_nCuts",
m_nCuts);
111 #if FastBDT_VERSION_MAJOR >= 5
112 pt.put(
"FastBDT_flatnessLoss", m_flatnessLoss);
113 pt.put(
"FastBDT_sPlot", m_sPlot);
114 pt.put(
"FastBDT_number_individual_nCuts", m_individual_nCuts.size());
115 for (
unsigned int i = 0; i < m_individual_nCuts.size(); ++i) {
116 pt.put(std::string(
"FastBDT_individual_nCuts") + std::to_string(i), m_individual_nCuts[i]);
118 pt.put(
"FastBDT_purityTransformation", m_purityTransformation);
119 pt.put(
"FastBDT_number_individualPurityTransformation", m_individualPurityTransformation.size());
120 for (
unsigned int i = 0; i < m_individualPurityTransformation.size(); ++i) {
121 pt.put(std::string(
"FastBDT_individualPurityTransformation") + std::to_string(i), m_individualPurityTransformation[i]);
128 po::options_description description(
"FastBDT options");
129 description.add_options()
130 (
"nTrees", po::value<unsigned int>(&
m_nTrees),
"Number of trees in the forest. Reasonable values are between 10 and 1000")
131 (
"nLevels", po::value<unsigned int>(&
m_nLevels)->notifier(check_bounds<unsigned int>(0, 20,
"nLevels")),
132 "Depth d of trees. The last layer of the tree will contain 2^d bins. Maximum is 20. Reasonable values are 2 and 6.")
133 (
"shrinkage", po::value<double>(&
m_shrinkage)->notifier(check_bounds<double>(0.0, 1.0,
"shrinkage")),
134 "Shrinkage of the boosting algorithm. Reasonable values are between 0.01 and 1.0.")
135 (
"nCutLevels", po::value<unsigned int>(&
m_nCuts)->notifier(check_bounds<unsigned int>(0, 20,
"nCutLevels")),
136 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12.")
137 #
if FastBDT_VERSION_MAJOR >= 5
138 (
"individualNCutLevels", po::value<std::vector<unsigned int>>(&m_individual_nCuts)->multitoken()->notifier(
139 check_bounds_vector<unsigned int>(0, 20,
"individualNCutLevels")),
140 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12. One value per feature (including spectators) should be provided, if parameter is not set the global value specified by nCutLevels is used for all features.")
141 (
"sPlot", po::value<bool>(&m_sPlot),
142 "Since in sPlot each event enters twice, this option modifies the sampling algorithm so that the matching signal and background events are selected together.")
143 (
"flatnessLoss", po::value<double>(&m_flatnessLoss),
144 "Activate Flatness Loss, all spectator variables are assumed to be variables in which the signal and background efficiency should be flat. negative values deactivates flatness loss.")
145 (
"purityTransformation", po::value<bool>(&m_purityTransformation),
146 "Activates purity transformation on all features: Add the purity transformed of all features in addition to the training. This will double the number of features and slow down the inference considerably")
147 (
"individualPurityTransformation", po::value<std::vector<bool>>(&m_individualPurityTransformation)->multitoken(),
148 "Activates purity transformation for each feature: Vector of boolean values which decide if the purity transformed of the feature should be added in addition to this training.")
150 (
"randRatio", po::value<double>(&
m_randRatio)->notifier(check_bounds<double>(0.0, 1.0001,
"randRatio")),
151 "Fraction of the data sampled each training iteration. Reasonable values are between 0.1 and 1.0.");
158 m_specific_options(specific_options) { }
163 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
164 #if FastBDT_VERSION_MAJOR >= 4
165 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
168 unsigned int numberOfSpectators = 0;
172 #if FastBDT_VERSION_MAJOR >= 5
174 and
m_specific_options.m_individual_nCuts.size() != numberOfFeatures + numberOfSpectators) {
175 B2ERROR(
"You provided individual nCut values for each feature and spectator, but the total number of provided cuts is not same as as the total number of features and spectators.");
178 std::vector<bool> individualPurityTransformation =
m_specific_options.m_individualPurityTransformation;
180 if (individualPurityTransformation.size() == 0) {
181 for (
unsigned int i = 0; i < numberOfFeatures; ++i) {
182 individualPurityTransformation.push_back(
true);
188 if (individual_nCuts.size() == 0) {
189 for (
unsigned int i = 0; i < numberOfFeatures + numberOfSpectators; ++i) {
197 numberOfSpectators,
true);
199 std::vector<std::vector<float>> X(numberOfFeatures + numberOfSpectators);
200 const auto& y = training_data.getSignals();
201 if (not isValidSignal(y)) {
202 B2FATAL(
"The training data is not valid. It only contains one class instead of two.");
204 const auto& w = training_data.getWeights();
205 for (
unsigned int i = 0; i < numberOfFeatures; ++i) {
206 X[i] = training_data.getFeature(i);
208 for (
unsigned int i = 0; i < numberOfSpectators; ++i) {
209 X[i + numberOfFeatures] = training_data.getSpectator(i);
211 classifier.fit(X, y, w);
213 const auto& y = training_data.getSignals();
214 if (not isValidSignal(y)) {
215 B2FATAL(
"The training data is not valid. It only contains one class instead of two.");
217 std::vector<FastBDT::FeatureBinning<float>> featureBinnings;
218 std::vector<unsigned int> nBinningLevels;
219 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
220 auto feature = training_data.getFeature(iFeature);
223 #if FastBDT_VERSION_MAJOR >= 3
224 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature));
226 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature.begin(), feature.end()));
228 nBinningLevels.push_back(nCuts);
231 for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
232 auto feature = training_data.getSpectator(iSpectator);
235 #if FastBDT_VERSION_MAJOR >= 3
236 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature));
238 featureBinnings.push_back(FastBDT::FeatureBinning<float>(nCuts, feature.begin(), feature.end()));
240 nBinningLevels.push_back(nCuts);
243 unsigned int numberOfEvents = training_data.getNumberOfEvents();
244 if (numberOfEvents > 5e+6) {
245 B2WARNING(
"Number of events for training exceeds 5 million. FastBDT performance starts getting worse when the number reaches O(10^7).");
248 #if FastBDT_VERSION_MAJOR >= 4
249 FastBDT::EventSample eventSample(numberOfEvents, numberOfFeatures, numberOfSpectators, nBinningLevels);
251 FastBDT::EventSample eventSample(numberOfEvents, numberOfFeatures, nBinningLevels);
253 std::vector<unsigned int> bins(numberOfFeatures + numberOfSpectators);
254 for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
255 training_data.loadEvent(iEvent);
256 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures + numberOfSpectators; ++iFeature) {
257 bins[iFeature] = featureBinnings[iFeature].ValueToBin(training_data.m_input[iFeature]);
259 for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
260 bins[iSpectator + numberOfFeatures] = featureBinnings[iSpectator + numberOfFeatures].ValueToBin(
261 training_data.m_spectators[iSpectator]);
263 eventSample.AddEvent(bins, training_data.m_weight, training_data.m_isSignal);
268 #if FastBDT_VERSION_MAJOR >= 3
269 FastBDT::Forest<float> forest(dt.GetShrinkage(), dt.GetF0(),
true);
271 FastBDT::Forest forest(dt.GetShrinkage(), dt.GetF0());
273 for (
auto t : dt.GetForest()) {
274 #if FastBDT_VERSION_MAJOR >= 3
275 auto tree = FastBDT::removeFeatureBinningTransformationFromTree(t, featureBinnings);
276 forest.AddTree(tree);
287 std::fstream file(custom_weightfile, std::ios_base::out | std::ios_base::trunc);
289 #if FastBDT_VERSION_MAJOR >= 5
290 file << classifier << std::endl;
292 #if FastBDT_VERSION_MAJOR >= 3
293 file << forest << std::endl;
295 file << featureBinnings << std::endl;
296 file << forest << std::endl;
303 weightfile.
addFile(
"FastBDT_Weightfile", custom_weightfile);
306 std::map<std::string, float> importance;
307 #if FastBDT_VERSION_MAJOR >= 5
308 for (
auto& pair : classifier.GetVariableRanking()) {
312 for (
auto& pair : forest.GetVariableRanking()) {
326 weightfile.
getFile(
"FastBDT_Weightfile", custom_weightfile);
327 std::fstream file(custom_weightfile, std::ios_base::in);
329 int version = weightfile.
getElement<
int>(
"FastBDT_version", 0);
330 B2DEBUG(100,
"FastBDT Weightfile Version " << version);
332 #if FastBDT_VERSION_MAJOR >= 3
336 std::fstream file2(custom_weightfile, std::ios_base::in);
342 if (!(s >> dummy >> dummy)) {
343 B2DEBUG(100,
"FastBDT: I read a new weightfile of FastBDT using the new FastBDT version 3. Everything fine!");
347 B2INFO(
"FastBDT: I read an old weightfile of FastBDT using the new FastBDT version 3."
348 "I will convert your FastBDT on-the-fly to the new version."
349 "Retrain the classifier to get rid of this message");
352 std::vector<FastBDT::FeatureBinning<float>> feature_binnings;
353 file >> feature_binnings;
359 bool transform2probability =
true;
360 FastBDT::Forest<unsigned int> temp_forest(shrinkage, F0, transform2probability);
363 for (
unsigned int i = 0; i < size; ++i) {
364 temp_forest.AddTree(FastBDT::readTreeFromStream<unsigned int>(file));
367 FastBDT::Forest<float> cleaned_forest(temp_forest.GetShrinkage(), temp_forest.GetF0(), temp_forest.GetTransform2Probability());
368 for (
auto& tree : temp_forest.GetForest()) {
369 cleaned_forest.AddTree(FastBDT::removeFeatureBinningTransformationFromTree(tree, feature_binnings));
374 B2INFO(
"FastBDT: I read an old weightfile of FastBDT using the old FastBDT version."
375 "I try to fix the weightfile first to avoid problems due to NaN and inf values."
376 "Consider to switch to the newer version of FastBDT (newer externals)");
380 while (getline(file, t)) {
383 while ((f = t.find(
"inf", f)) != std::string::npos) {
384 t.replace(f, std::string(
"inf").length(), std::string(
"0.0"));
385 f += std::string(
"0.0").length();
386 B2WARNING(
"Found infinity in FastBDT weightfile, I replace it with 0 to prevent horrible crashes, this is fixed in the newer version");
389 while ((f = t.find(
"nan", f)) != std::string::npos) {
390 t.replace(f, std::string(
"nan").length(), std::string(
"0.0"));
391 f += std::string(
"0.0").length();
392 B2WARNING(
"Found nan in FastBDT weightfile, I replace it with 0 to prevent horrible crashes, this is fixed in the newer version");
400 #if FastBDT_VERSION_MAJOR >= 5
402 m_use_simplified_interface =
true;
403 m_classifier = FastBDT::Classifier(file);
407 B2ERROR(
"Unknown Version 2 of Weightfile, please use a more recent FastBDT version");
418 std::vector<float> probabilities(test_data.getNumberOfEvents());
419 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
420 test_data.loadEvent(iEvent);
421 #if FastBDT_VERSION_MAJOR >= 3
422 #if FastBDT_VERSION_MAJOR >= 5
423 if (m_use_simplified_interface)
424 probabilities[iEvent] = m_classifier.predict(test_data.m_input);
439 return probabilities;
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
std::vector< FastBDT::FeatureBinning< float > > m_expert_feature_binning
Forest feature binning.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
FastBDT::Forest m_expert_forest
Forest Expert.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
FastBDTOptions m_specific_options
Method specific options.
Options for the FastBDT MVA method.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
double m_randRatio
Fraction of data to use in the stochastic training.
double m_shrinkage
Shrinkage during the boosting step.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
unsigned int m_nLevels
Depth of tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
unsigned int m_nCuts
Number of cut levels = log_2(number of cuts)
unsigned int m_nTrees
Number of trees.
FastBDTTeacher(const GeneralOptions &general_options, const FastBDTOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
FastBDTOptions m_specific_options
Method specific options.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
General options which are shared by all MVA trainings.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
GeneralOptions m_general_options
GeneralOptions containing all shared options.
The Weightfile class serializes all information about a training into an xml tree.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
void addOptions(const Options &options)
Add an Option object to the xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Abstract base class for different kinds of events.