9 #include <mva/interface/Dataset.h>
11 #include <framework/utilities/MakeROOTCompatible.h>
12 #include <framework/logging/Logger.h>
13 #include <framework/io/RootIOUtilities.h>
17 #include <boost/filesystem/operations.hpp>
38 double signal_weight_sum = 0;
39 double weight_sum = 0;
46 return signal_weight_sum / weight_sum;
55 B2ERROR(
"Unknown feature named " << feature);
67 B2ERROR(
"Unknown spectator named " << spectator);
80 result[iEvent] =
m_input[iFeature];
136 const std::vector<float>& spectators) :
Dataset(general_options)
146 const std::vector<std::vector<float>>& spectators,
147 const std::vector<float>& targets,
const std::vector<float>& weights) :
Dataset(general_options), m_matrix(input),
148 m_spectator_matrix(spectators),
149 m_targets(targets), m_weights(weights)
153 B2ERROR(
"Feature matrix and target vector need same number of elements in MultiDataset, got " <<
m_targets.size() <<
" and " <<
157 B2ERROR(
"Feature matrix and weight vector need same number of elements in MultiDataset, got " <<
m_weights.size() <<
" and " <<
161 B2ERROR(
"Feature matrix and spectator matrix need same number of elements in MultiDataset, got " <<
m_spectator_matrix.size() <<
193 B2ERROR(
"Couldn't find variable " << v <<
" in GeneralOptions");
194 throw std::runtime_error(
"Couldn't find variable " + v +
" in GeneralOptions");
202 B2ERROR(
"Couldn't find spectator " << v <<
" in GeneralOptions");
203 throw std::runtime_error(
"Couldn't find spectator " + v +
" in GeneralOptions");
208 if (events.size() > 0)
213 unsigned int n_events = 0;
215 if (events.size() == 0 or events[iEvent]) {
227 unsigned int index = iEvent;
235 for (
unsigned int iFeature = 0; iFeature <
m_input.size(); ++iFeature) {
239 for (
unsigned int iSpectator = 0; iSpectator <
m_spectators.size(); ++iSpectator) {
274 Dataset& background_dataset) :
Dataset(general_options), m_signal_dataset(signal_dataset),
275 m_background_dataset(background_dataset) { }
301 s.insert(s.end(), b.begin(), b.end());
311 s.insert(s.end(), b.begin(), b.end());
325 for (
const auto& variable : general_options.m_variables)
326 for (
const auto& spectator : general_options.m_spectators)
327 if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
328 B2ERROR(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
329 throw std::runtime_error(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
332 std::vector<std::string> filenames;
334 if (boost::filesystem::exists(filename)) {
335 filenames.push_back(filename);
338 filenames.insert(filenames.end(), temp.begin(), temp.end());
341 if (filenames.empty()) {
342 B2ERROR(
"Found no valid filenames in GeneralOptions");
343 throw std::runtime_error(
"Found no valid filenames in GeneralOptions");
347 TDirectory* dir = gDirectory;
348 for (
const auto& filename : filenames) {
349 if (not boost::filesystem::exists(filename)) {
350 B2ERROR(
"Error given ROOT file does not exist " << filename);
351 throw std::runtime_error(
"Error during open of ROOT file named " + filename);
354 TFile* f = TFile::Open(filename.c_str(),
"READ");
355 if (!f or f->IsZombie() or not f->IsOpen()) {
356 B2ERROR(
"Error during open of ROOT file named " << filename);
357 throw std::runtime_error(
"Error during open of ROOT file named " + filename);
364 for (
const auto& filename : filenames) {
366 if (!
m_tree->AddFile(filename.c_str(), -1)) {
367 B2ERROR(
"Error during open of ROOT file named " << filename <<
" cannot retrieve tree named " <<
369 throw std::runtime_error(
"Error during open of ROOT file named " + filename +
" cannot retrieve tree named " +
380 if (
m_tree->GetEntry(event, 0) == 0) {
381 B2ERROR(
"Error during loading entry from chain");
418 if (branchName.empty()) {
419 B2INFO(
"No TBranch name given for weights. Using 1s as default weights.");
421 std::vector<float> values(nentries, 1.);
424 if (branchName ==
"__weight__") {
426 B2INFO(
"No default weight branch with name __weight__ found. Using 1s as default weights.");
428 std::vector<float> values(nentries, 1.);
433 std::string typeName =
"weights";
446 B2ERROR(
"Feature index " << iFeature <<
" is out of bounds of given number of features: "
451 std::string typeName =
"features";
468 B2ERROR(
"Spectator index " << iSpectator <<
" is out of bounds of given number of spectators: "
473 std::string typeName =
"spectators";
495 T& memberVariableTarget)
498 std::vector<float> values(nentries);
502 auto currentTreeNumber =
m_tree->GetTreeNumber();
503 TBranch* branch =
m_tree->GetBranch(branchName.c_str());
505 B2ERROR(
"TBranch for " + variableType +
" named '" << branchName.c_str() <<
"' does not exist!");
507 branch->SetAddress(&
object);
508 for (
int i = 0; i < nentries; ++i) {
509 auto entry =
m_tree->LoadTree(i);
511 B2ERROR(
"Error during loading root tree from chain, error code: " << entry);
514 if (currentTreeNumber !=
m_tree->GetTreeNumber()) {
515 currentTreeNumber =
m_tree->GetTreeNumber();
516 branch =
m_tree->GetBranch(branchName.c_str());
517 branch->SetAddress(&
object);
519 branch->GetEntry(entry);
523 m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
529 auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
530 return branch !=
nullptr;
538 if (not variableName.empty()) {
540 m_tree->SetBranchStatus(variableName.c_str(),
true);
541 m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
547 B2ERROR(
"Couldn't find given " << variableType <<
" variable named " << variableName <<
548 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
549 throw std::runtime_error(
"Couldn't find given " + variableType +
" variable named " + variableName +
550 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
560 for (
unsigned int i = 0; i < variableNames.size(); ++i)
565 std::vector<Variable::Manager::VarVariant>& varVariantTargets)
567 for (
unsigned int i = 0; i < variableNames.size(); ++i) {
568 if (std::holds_alternative<double>(varVariantTargets[i]))
570 else if (std::holds_alternative<int>(varVariantTargets[i]))
572 else if (std::holds_alternative<bool>(varVariantTargets[i]))
580 m_tree->SetBranchStatus(
"*",
false);
581 std::string typeName;
586 B2INFO(
"No weight variable provided. The weight will be set to 1.");
591 m_tree->SetBranchStatus(
"__weight__",
true);
597 B2INFO(
"Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
598 "weight variable to an empty string if you don't need it.");
622 typeName =
"feature";
624 typeName =
"spectator";
627 typeName =
"feature";
629 typeName =
"spectator";
642 TBranch* branch =
m_tree->GetBranch(branchName.c_str());
643 TLeaf* leaf = branch->GetLeaf(branchName.c_str());
644 std::string type_name = leaf->GetTypeName();
648 if (type_name ==
"Float_t")
654 if (type_name ==
"Float_t") {
658 B2ERROR(
"There is a mix of float and basf2 variable types (double, int, bool)");
659 }
else if (type_name ==
"Double_t" or type_name ==
"Int_t" or type_name ==
"Bool_t") {
661 B2ERROR(
"There is a mix of float and basf2 variable types (double, int, bool)");
663 if (type_name ==
"Double_t")
665 else if (type_name ==
"Int_t")
667 else if (type_name ==
"Bool_t")
671 B2FATAL(
"Unknown root input type: " << type_name);
672 throw std::runtime_error(
"Unknown root input type: " + type_name);
682 TBranch* branch =
m_tree->GetBranch(branchName.c_str());
683 TLeaf* leaf = branch->GetLeaf(branchName.c_str());
684 std::string type_name = leaf->GetTypeName();
685 if (type_name ==
"Float_t") {
689 B2ERROR(
"There is a mix of float and basf2 variable types (double, int, bool)");
690 }
else if (type_name ==
"Double_t" or type_name ==
"Int_t" or type_name ==
"Bool_t") {
692 B2ERROR(
"There is a mix of float and basf2 variable types (double, int, bool)");
694 if (type_name ==
"Double_t")
696 else if (type_name ==
"Int_t")
698 else if (type_name ==
"Bool_t")
702 B2FATAL(
"Unknown root input type: " << type_name);
703 throw std::runtime_error(
"Unknown root input type: " + type_name);
714 TBranch* branch =
m_tree->GetBranch(branchName.c_str());
715 TLeaf* leaf = branch->GetLeaf(branchName.c_str());
716 std::string target_type_name = leaf->GetTypeName();
717 if (target_type_name ==
"Double_t")
719 else if (target_type_name ==
"Int_t")
721 else if (target_type_name ==
"Bool_t")
724 B2FATAL(
"Input type " << target_type_name <<
" for target variable is not supported");
725 throw std::runtime_error(
"Unsupported target input type: " + target_type_name);
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual std::vector< bool > getSignals()
Returns the isSignal flags of all events.
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
std::vector< float > m_spectators
Contains all spectators values of the currently loaded event.
virtual std::vector< float > getTargets()
Returns all targets.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights()
Returns all weights.
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
bool m_isSignal
Defines if the currently loaded event is signal or background.
float m_weight
Contains the weight of the currently loaded event.
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
float m_target
Contains the target value of the currently loaded event.
General options which are shared by all MVA trainings.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
int m_signal_class
Signal class which is used as signal in a classification problem.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_treename
Name of the TTree inside the datafile containing the training data.
std::string m_target_variable
Target variable (branch name) defining the target.
std::vector< float > m_weights
weight vector
std::vector< std::vector< float > > m_matrix
Feature matrix.
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float >> &input, const std::vector< std::vector< float >> &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
std::vector< float > m_targets
target vector
virtual void loadEvent(unsigned int iEvent) override
Does nothing in the case of a single dataset, because the only event is already loaded.
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
void setVectorVariableAddress(std::string &variableType, std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
void setTargetRootInputType()
Determines the data type of the target variable and sets it to m_target_data_type.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
Variable::Manager::VariableDataType m_target_data_type
Data type of target variable.
TChain * m_tree
Pointer to the TChain containing the data.
double m_target_double
Contains the target value of the currently loaded event.
virtual void loadEvent(unsigned int event) override
Load the event number iEvent from the TTree.
int m_target_int
Contains the target value of the currently loaded event.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
double m_weight_double
Contains the weight of the currently loaded event.
bool m_isFloatInputType
Defines the expected datatype in the ROOT file.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
void setScalarVariableAddress(std::string &variableType, std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
void setRootInputType()
Tries to infer the data-type of a root file and sets m_isFloatInputType.
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
std::vector< float > getVectorFromTTree(std::string &variableType, std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
virtual ~ROOTDataset()
Virtual destructor.
std::vector< Variable::Manager::VarVariant > m_spectators_variant
Contains all spectators values of the currently loaded event.
bool m_target_bool
Contains the target value of the currently loaded event.
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
std::vector< Variable::Manager::VarVariant > m_input_variant
Contains all feature values of the currently loaded event.
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Dataset & m_dataset
Reference to the wrapped dataset.
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
std::vector< unsigned int > m_event_indices
Mapping from the position of a event in the given subset to its position in the wrapped dataset.
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
Abstract base class for different kinds of events.