9#include <mva/methods/FastBDT.h>
11#include <framework/logging/Logger.h>
21 bool isValidSignal(
const std::vector<bool>& Signals)
23 const auto first = Signals.front();
24 for (
const auto& value : Signals) {
33 int version = pt.get<
int>(
"FastBDT_version");
34 if (version != 1 and version != 2) {
35 B2ERROR(
"Unknown weightfile version " << std::to_string(version));
36 throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
38 m_nTrees = pt.get<
int>(
"FastBDT_nTrees");
39 m_nCuts = pt.get<
int>(
"FastBDT_nCuts");
40 m_nLevels = pt.get<
int>(
"FastBDT_nLevels");
47 m_sPlot = pt.get<
bool>(
"FastBDT_sPlot");
49 unsigned int numberOfIndividualNCuts = pt.get<
unsigned int>(
"FastBDT_number_individual_nCuts", 0);
51 for (
unsigned int i = 0; i < numberOfIndividualNCuts; ++i) {
52 m_individual_nCuts[i] = pt.get<
unsigned int>(std::string(
"FastBDT_individual_nCuts") + std::to_string(i));
56 unsigned int numberOfIndividualPurityTransformation = pt.get<
unsigned int>(
"FastBDT_number_individualPurityTransformation", 0);
58 for (
unsigned int i = 0; i < numberOfIndividualPurityTransformation; ++i) {
70 pt.put(
"FastBDT_version", 2);
72 pt.put(
"FastBDT_nCuts",
m_nCuts);
77 pt.put(
"FastBDT_sPlot",
m_sPlot);
80 pt.put(std::string(
"FastBDT_individual_nCuts") + std::to_string(i),
m_individual_nCuts[i]);
91 po::options_description description(
"FastBDT options");
92 description.add_options()
93 (
"nTrees", po::value<unsigned int>(&
m_nTrees),
"Number of trees in the forest. Reasonable values are between 10 and 1000")
94 (
"nLevels", po::value<unsigned int>(&
m_nLevels)->notifier(check_bounds<unsigned int>(0, 20,
"nLevels")),
95 "Depth d of trees. The last layer of the tree will contain 2^d bins. Maximum is 20. Reasonable values are 2 and 6.")
96 (
"shrinkage", po::value<double>(&
m_shrinkage)->notifier(check_bounds<double>(0.0, 1.0,
"shrinkage")),
97 "Shrinkage of the boosting algorithm. Reasonable values are between 0.01 and 1.0.")
98 (
"nCutLevels", po::value<unsigned int>(&
m_nCuts)->notifier(check_bounds<unsigned int>(0, 20,
"nCutLevels")),
99 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12.")
100 (
"individualNCutLevels", po::value<std::vector<unsigned int>>(&
m_individual_nCuts)->multitoken()->notifier(
101 check_bounds_vector<unsigned int>(0, 20,
"individualNCutLevels")),
102 "Number of cut levels N per feature. 2^N Bins will be used per feature. Reasonable values are between 6 and 12. One value per feature (including spectators) should be provided, if parameter is not set the global value specified by nCutLevels is used for all features.")
103 (
"sPlot", po::value<bool>(&
m_sPlot),
104 "Since in sPlot each event enters twice, this option modifies the sampling algorithm so that the matching signal and background events are selected together.")
106 "Activate Flatness Loss, all spectator variables are assumed to be variables in which the signal and background efficiency should be flat. negative values deactivates flatness loss.")
108 "Activates purity transformation on all features: Add the purity transformed of all features in addition to the training. This will double the number of features and slow down the inference considerably")
110 "Activates purity transformation for each feature: Vector of boolean values which decide if the purity transformed of the feature should be added in addition to this training.")
111 (
"randRatio", po::value<double>(&
m_randRatio)->notifier(check_bounds<double>(0.0, 1.0001,
"randRatio")),
112 "Fraction of the data sampled each training iteration. Reasonable values are between 0.1 and 1.0.");
119 m_specific_options(specific_options) { }
124 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
125 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
129 B2ERROR(
"You provided individual nCut values for each feature and spectator, but the total number of provided cuts is not same as as the total number of features and spectators.");
134 if (individualPurityTransformation.size() == 0) {
135 for (
unsigned int i = 0; i < numberOfFeatures; ++i) {
136 individualPurityTransformation.push_back(
true);
142 if (individual_nCuts.size() == 0) {
143 for (
unsigned int i = 0; i < numberOfFeatures + numberOfSpectators; ++i) {
151 numberOfSpectators,
true);
153 std::vector<std::vector<float>> X(numberOfFeatures + numberOfSpectators);
154 const auto& y = training_data.getSignals();
155 if (not isValidSignal(y)) {
156 B2FATAL(
"The training data is not valid. It only contains one class instead of two.");
158 const auto& w = training_data.getWeights();
159 for (
unsigned int i = 0; i < numberOfFeatures; ++i) {
160 X[i] = training_data.getFeature(i);
162 for (
unsigned int i = 0; i < numberOfSpectators; ++i) {
163 X[i + numberOfFeatures] = training_data.getSpectator(i);
165 classifier.fit(X, y, w);
169 std::fstream file(custom_weightfile, std::ios_base::out | std::ios_base::trunc);
171 file << classifier << std::endl;
176 weightfile.
addFile(
"FastBDT_Weightfile", custom_weightfile);
179 std::map<std::string, float> importance;
180 for (
auto& pair : classifier.GetVariableRanking()) {
193 weightfile.
getFile(
"FastBDT_Weightfile", custom_weightfile);
194 std::fstream file(custom_weightfile, std::ios_base::in);
196 int version = weightfile.
getElement<
int>(
"FastBDT_version", 0);
197 B2DEBUG(100,
"FastBDT Weightfile Version " << version);
202 std::fstream file2(custom_weightfile, std::ios_base::in);
208 if (!(s >> dummy >> dummy)) {
209 B2DEBUG(100,
"FastBDT: I read a new weightfile of FastBDT using the new FastBDT version 3. Everything fine!");
213 B2INFO(
"FastBDT: I read an old weightfile of FastBDT using the new FastBDT version 3."
214 "I will convert your FastBDT on-the-fly to the new version."
215 "Retrain the classifier to get rid of this message");
218 std::vector<FastBDT::FeatureBinning<float>> feature_binnings;
219 file >> feature_binnings;
225 bool transform2probability =
true;
226 FastBDT::Forest<unsigned int> temp_forest(shrinkage, F0, transform2probability);
229 for (
unsigned int i = 0; i < size; ++i) {
230 temp_forest.AddTree(FastBDT::readTreeFromStream<unsigned int>(file));
233 FastBDT::Forest<float> cleaned_forest(temp_forest.GetShrinkage(), temp_forest.GetF0(), temp_forest.GetTransform2Probability());
234 for (
auto& tree : temp_forest.GetForest()) {
235 cleaned_forest.AddTree(FastBDT::removeFeatureBinningTransformationFromTree(tree, feature_binnings));
251 std::vector<float> probabilities(test_data.getNumberOfEvents());
252 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
253 test_data.loadEvent(iEvent);
255 probabilities[iEvent] =
m_classifier.predict(test_data.m_input);
260 return probabilities;
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
FastBDT::Forest< float > m_expert_forest
Forest Expert -> used in case of no purity transformation.
FastBDT::Classifier m_classifier
Simplified FastBDT interface: classifier combines preprocessing and forest.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
bool m_use_simplified_interface
Use the simplified FastBDT interface of version 4.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
FastBDTOptions m_specific_options
Method specific options.
Options for the FANN MVA method.
std::vector< unsigned int > m_individual_nCuts
Number of cut Levels = log_2(Number of Cuts) for each provided feature.
bool m_sPlot
Activates sPlot sampling.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
double m_randRatio
Fraction of data to use in the stochastic training.
double m_flatnessLoss
Flatness Loss constant.
double m_shrinkage
Shrinkage during the boosting step.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
bool m_purityTransformation
Activates purity transformation globally for all features.
unsigned int m_nLevels
Depth of tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
std::vector< bool > m_individualPurityTransformation
Vector which decided for each feature individually if the purity transformation should be used.
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
unsigned int m_nTrees
Number of trees.
FastBDTTeacher(const GeneralOptions &general_options, const FastBDTOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
FastBDTOptions m_specific_options
Method specific options.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
General options which are shared by all MVA trainings.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Abstract base class of all Teachers Each MVA library has its own implementation of this class,...
GeneralOptions m_general_options
GeneralOptions containing all shared options.
The Weightfile class serializes all information about a training into an xml tree.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
void addOptions(const Options &options)
Add an Option object to the xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Abstract base class for different kinds of events.