9#include <mva/methods/FastBDT.h>
11#include <framework/logging/Logger.h>
21 bool isValidSignal(
const std::vector<bool>& Signals)
23 const auto first = Signals.front();
24 for (
const auto& value : Signals) {
33 int version = pt.get<
int>(
"FastBDT_version");
34 if (version != 1 and version != 2) {
35 B2ERROR(
"Unknown weightfile version " << std::to_string(version));
36 throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
38 m_nTrees = pt.get<
int>(
"FastBDT_nTrees");
39 m_nCuts = pt.get<
int>(
"FastBDT_nCuts");
40 m_nLevels = pt.get<
int>(
"FastBDT_nLevels");
47 m_sPlot = pt.get<
bool>(
"FastBDT_sPlot");
49 unsigned int numberOfIndividualNCuts = pt.get<
unsigned int>(
"FastBDT_number_individual_nCuts", 0);
51 for (
unsigned int i = 0; i < numberOfIndividualNCuts; ++i) {
52 m_individual_nCuts[i] = pt.get<
unsigned int>(std::string(
"FastBDT_individual_nCuts") + std::to_string(i));
56 unsigned int numberOfIndividualPurityTransformation = pt.get<
unsigned int>(
"FastBDT_number_individualPurityTransformation", 0);
58 for (
unsigned int i = 0; i < numberOfIndividualPurityTransformation; ++i) {
70 pt.put(
"FastBDT_version", 2);
72 pt.put(
"FastBDT_nCuts",
m_nCuts);
77 pt.put(
"FastBDT_sPlot",
m_sPlot);
80 pt.put(std::string(
"FastBDT_individual_nCuts") + std::to_string(i),
m_individual_nCuts[i]);
91 po::options_description description(
"FastBDT options");
92 description.add_options()
93 (
"nTrees", po::value<unsigned int>(&
m_nTrees),
"Number of trees in the forest. Reasonable values are between 10 and 1000")
94 (
"nLevels", po::value<unsigned int>(&
m_nLevels)->notifier(check_bounds<unsigned int>(0, 20,
"nLevels")),
95 "Depth d of trees. The last layer of the tree will contain 2^d bins. Maximum is 20. Reasonable values are 2 and 6.")
96 (
"shrinkage", po::value<double>(&
m_shrinkage)->notifier(check_bounds<double>(0.0, 1.0,
"shrinkage")),
97 "Shrinkage of the boosting algorithm. Reasonable values are between 0.01 and 1.0.")
98 (
"nCutLevels", po::value<unsigned int>(&
m_nCuts)->notifier(check_bounds<unsigned int>(0, 20,
"nCutLevels")),
99 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12.")
100 (
"individualNCutLevels", po::value<std::vector<unsigned int>>(&
m_individual_nCuts)->multitoken()->notifier(
101 check_bounds_vector<unsigned int>(0, 20,
"individualNCutLevels")),
102 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12. One value per feature (including spectators) should be provided, if parameter is not set the global value specified by nCutLevels is used for all features.")
103 (
"sPlot", po::value<bool>(&
m_sPlot),
104 "Since in sPlot each event enters twice, this option modifies the sampling algorithm so that the matching signal and background events are selected together.")
106 "Activate Flatness Loss, all spectator variables are assumed to be variables in which the signal and background efficiency should be flat. negative values deactivates flatness loss.")
108 "Activates purity transformation on all features: Add the purity transformed of all features in addition to the training. This will double the number of features and slow down the inference considerably")
110 "Activates purity transformation for each feature: Vector of boolean values which decide if the purity transformed of the feature should be added in addition to this training.")
111 (
"randRatio", po::value<double>(&
m_randRatio)->notifier(check_bounds<double>(0.0, 1.0001,
"randRatio")),
112 "Fraction of the data sampled each training iteration. Reasonable values are between 0.1 and 1.0.");
119 m_specific_options(specific_options) { }
123 if (training_data.getNumberOfEvents() > 5e+6) {
124 B2WARNING(
"Number of events for training exceeds 5 million. FastBDT performance starts getting worse when the number reaches O(10^7).");
127 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
128 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
132 B2ERROR(
"You provided individual nCut values for each feature and spectator, but the total number of provided cuts is not same as as the total number of features and spectators.");
137 if (individualPurityTransformation.size() == 0) {
138 for (
unsigned int i = 0; i < numberOfFeatures; ++i) {
139 individualPurityTransformation.push_back(
true);
145 if (individual_nCuts.size() == 0) {
146 for (
unsigned int i = 0; i < numberOfFeatures + numberOfSpectators; ++i) {
154 numberOfSpectators,
true);
156 std::vector<std::vector<float>> X(numberOfFeatures + numberOfSpectators);
157 const auto& y = training_data.getSignals();
158 if (not isValidSignal(y)) {
159 B2FATAL(
"The training data is not valid. It only contains one class instead of two.");
161 const auto& w = training_data.getWeights();
162 for (
unsigned int i = 0; i < numberOfFeatures; ++i) {
163 X[i] = training_data.getFeature(i);
165 for (
unsigned int i = 0; i < numberOfSpectators; ++i) {
166 X[i + numberOfFeatures] = training_data.getSpectator(i);
168 classifier.fit(X, y, w);
172 std::fstream file(custom_weightfile, std::ios_base::out | std::ios_base::trunc);
174 file << classifier << std::endl;
179 weightfile.
addFile(
"FastBDT_Weightfile", custom_weightfile);
182 std::map<std::string, float> importance;
183 for (
auto& pair : classifier.GetVariableRanking()) {
196 weightfile.
getFile(
"FastBDT_Weightfile", custom_weightfile);
197 std::fstream file(custom_weightfile, std::ios_base::in);
199 int version = weightfile.
getElement<
int>(
"FastBDT_version", 0);
200 B2DEBUG(100,
"FastBDT Weightfile Version " << version);
205 std::fstream file2(custom_weightfile, std::ios_base::in);
211 if (!(s >> dummy >> dummy)) {
212 B2DEBUG(100,
"FastBDT: I read a new weightfile of FastBDT using the new FastBDT version 3. Everything fine!");
216 B2INFO(
"FastBDT: I read an old weightfile of FastBDT using the new FastBDT version 3."
217 "I will convert your FastBDT on-the-fly to the new version."
218 "Retrain the classifier to get rid of this message");
221 std::vector<FastBDT::FeatureBinning<float>> feature_binnings;
222 file >> feature_binnings;
228 bool transform2probability =
true;
229 FastBDT::Forest<unsigned int> temp_forest(shrinkage, F0, transform2probability);
232 for (
unsigned int i = 0; i < size; ++i) {
233 temp_forest.AddTree(FastBDT::readTreeFromStream<unsigned int>(file));
236 FastBDT::Forest<float> cleaned_forest(temp_forest.GetShrinkage(), temp_forest.GetF0(), temp_forest.GetTransform2Probability());
237 for (
auto& tree : temp_forest.GetForest()) {
238 cleaned_forest.AddTree(FastBDT::removeFeatureBinningTransformationFromTree(tree, feature_binnings));
254 std::vector<float> probabilities(test_data.getNumberOfEvents());
255 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
256 test_data.loadEvent(iEvent);
258 probabilities[iEvent] =
m_classifier.predict(test_data.m_input);
263 return probabilities;
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
FastBDT::Forest< float > m_expert_forest
Forest Expert -> used in case of no purity transformation.
FastBDT::Classifier m_classifier
Simplified FastBDT interface: classifier combines preprocessing and forest.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
bool m_use_simplified_interface
Use the simplified FastBDT interface of version 4.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
FastBDTOptions m_specific_options
Method specific options.
Options for the FastBDT MVA method.
std::vector< unsigned int > m_individual_nCuts
Number of cut levels = log_2(number of cuts) for each provided feature.
bool m_sPlot
Activates sPlot sampling.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
double m_randRatio
Fraction of data to use in the stochastic training.
double m_flatnessLoss
Flatness Loss constant.
double m_shrinkage
Shrinkage during the boosting step.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
bool m_purityTransformation
Activates purity transformation globally for all features.
unsigned int m_nLevels
Depth of tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
std::vector< bool > m_individualPurityTransformation
Vector which decides, for each feature individually, whether the purity transformation should be used.
unsigned int m_nCuts
Number of cut levels = log_2(number of cuts).
unsigned int m_nTrees
Number of trees.
FastBDTTeacher(const GeneralOptions &general_options, const FastBDTOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
FastBDTOptions m_specific_options
Method specific options.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
General options which are shared by all MVA trainings.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
GeneralOptions m_general_options
GeneralOptions containing all shared options.
The Weightfile class serializes all information about a training into an XML tree.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
void addOptions(const Options &options)
Add an Option object to the xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library).
Abstract base class for different kinds of events.