9 #include <mva/interface/Dataset.h> 
   11 #include <framework/utilities/MakeROOTCompatible.h> 
   12 #include <framework/logging/Logger.h> 
   13 #include <framework/io/RootIOUtilities.h> 
   38       double signal_weight_sum = 0;
 
   39       double weight_sum = 0;
 
   46       return signal_weight_sum / weight_sum;
 
   55         B2ERROR(
"Unknown feature named " << feature);
 
   67         B2ERROR(
"Unknown spectator named " << spectator);
 
   80         result[iEvent] = 
m_input[iFeature];
 
  136                                  const std::vector<float>& spectators) : 
Dataset(general_options)
 
  146                                const std::vector<std::vector<float>>& spectators,
 
  147                                const std::vector<float>& targets, 
const std::vector<float>& weights) : 
Dataset(general_options),  m_matrix(input),
 
  148       m_spectator_matrix(spectators),
 
  149       m_targets(targets), m_weights(weights)
 
  153         B2ERROR(
"Feature matrix and target vector need same number of elements in MultiDataset, got " << 
m_targets.size() << 
" and " <<
 
  157         B2ERROR(
"Feature matrix and weight vector need same number of elements in MultiDataset, got " << 
m_weights.size() << 
" and " <<
 
  161         B2ERROR(
"Feature matrix and spectator matrix need same number of elements in MultiDataset, got " << 
m_spectator_matrix.size() <<
 
  193           B2ERROR(
"Couldn't find variable " << v << 
" in GeneralOptions");
 
  194           throw std::runtime_error(
"Couldn't find variable " + v + 
" in GeneralOptions");
 
  202           B2ERROR(
"Couldn't find spectator " << v << 
" in GeneralOptions");
 
  203           throw std::runtime_error(
"Couldn't find spectator " + v + 
" in GeneralOptions");
 
  208       if (events.size() > 0)
 
  213         unsigned int n_events = 0;
 
  215           if (events.size() == 0 or events[iEvent]) {
 
  227       unsigned int index = iEvent;
 
  235       for (
unsigned int iFeature = 0; iFeature < 
m_input.size(); ++iFeature) {
 
  239       for (
unsigned int iSpectator = 0; iSpectator < 
m_spectators.size(); ++iSpectator) {
 
  274                                      Dataset& background_dataset) : 
Dataset(general_options), m_signal_dataset(signal_dataset),
 
  275       m_background_dataset(background_dataset) { }
 
  301       s.insert(s.end(), b.begin(), b.end());
 
  311       s.insert(s.end(), b.begin(), b.end());
 
  323       for (
const auto& variable : general_options.m_variables)
 
  324         for (
const auto& spectator : general_options.m_spectators)
 
  325           if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
 
  326             B2ERROR(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
 
  327             throw std::runtime_error(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
 
  330       std::vector<std::string> filenames;
 
  332         if (std::filesystem::exists(filename)) {
 
  333           filenames.push_back(filename);
 
  336           filenames.insert(filenames.end(), temp.begin(), temp.end());
 
  339       if (filenames.empty()) {
 
  340         B2ERROR(
"Found no valid filenames in GeneralOptions");
 
  341         throw std::runtime_error(
"Found no valid filenames in GeneralOptions");
 
  345       TDirectory* dir = gDirectory;
 
  346       for (
const auto& filename : filenames) {
 
  347         if (not std::filesystem::exists(filename)) {
 
  348           B2ERROR(
"Error given ROOT file does not exist " << filename);
 
  349           throw std::runtime_error(
"Error during open of ROOT file named " + filename);
 
  352         TFile* f = TFile::Open(filename.c_str(), 
"READ");
 
  353         if (!f or f->IsZombie() or not f->IsOpen()) {
 
  354           B2ERROR(
"Error during open of ROOT file named " << filename);
 
  355           throw std::runtime_error(
"Error during open of ROOT file named " + filename);
 
  362       for (
const auto& filename : filenames) {
 
  364         if (!
m_tree->AddFile(filename.c_str(), -1)) {
 
  365           B2ERROR(
"Error during open of ROOT file named " << filename << 
" cannot retrieve tree named " <<
 
  367           throw std::runtime_error(
"Error during open of ROOT file named " + filename + 
" cannot retrieve tree named " +
 
  378       if (std::holds_alternative<double>(variant))
 
  379         return static_cast<float>(std::get<double>(variant));
 
  380       else if (std::holds_alternative<float>(variant))
 
  381         return std::get<float>(variant);
 
  382       else if (std::holds_alternative<int>(variant))
 
  383         return static_cast<float>(std::get<int>(variant));
 
  384       else if (std::holds_alternative<bool>(variant))
 
  385         return static_cast<float>(std::get<bool>(variant));
 
  387         B2FATAL(
"Unsupported variable type");
 
  393       if (
m_tree->GetEntry(event, 0) == 0) {
 
  394         B2ERROR(
"Error during loading entry from chain");
 
  411       if (branchName.empty()) {
 
  412         B2INFO(
"No TBranch name given for weights. Using 1s as default weights.");
 
  414         std::vector<float> values(nentries, 1.);
 
  417       if (branchName == 
"__weight__") {
 
  419           B2INFO(
"No default weight branch with name __weight__ found. Using 1s as default weights.");
 
  421           std::vector<float> values(nentries, 1.);
 
  425       std::string typeLabel = 
"weights";
 
  432         B2ERROR(
"Feature index " << iFeature << 
" is out of bounds of given number of features: " 
  436       std::string typeLabel = 
"features";
 
  443         B2ERROR(
"Spectator index " << iSpectator << 
" is out of bounds of given number of spectators: " 
  448       std::string typeLabel = 
"spectators";
 
  461       if (std::holds_alternative<double>(memberVariableTarget))
 
  462         return getVectorFromTTree(variableType, branchName, std::get<double>(memberVariableTarget));
 
  463       else if (std::holds_alternative<float>(memberVariableTarget))
 
  464         return getVectorFromTTree(variableType, branchName, std::get<float>(memberVariableTarget));
 
  465       else if (std::holds_alternative<int>(memberVariableTarget))
 
  466         return getVectorFromTTree(variableType, branchName, std::get<int>(memberVariableTarget));
 
  467       else if (std::holds_alternative<bool>(memberVariableTarget))
 
  468         return getVectorFromTTree(variableType, branchName, std::get<bool>(memberVariableTarget));
 
  470         B2FATAL(
"Input type of " << variableType << 
" variable " << branchName << 
" is not supported");
 
  475                                                        T& memberVariableTarget)
 
  478       std::vector<float> values(nentries);
 
  482       auto currentTreeNumber = 
m_tree->GetTreeNumber();
 
  483       TBranch* branch = 
m_tree->GetBranch(branchName.c_str());
 
  485         B2ERROR(
"TBranch for " + variableType + 
" named '" << branchName.c_str()  << 
"' does not exist!");
 
  487       branch->SetAddress(&
object);
 
  488       for (
int i = 0; i < nentries; ++i) {
 
  489         auto entry = 
m_tree->LoadTree(i);
 
  491           B2ERROR(
"Error during loading root tree from chain, error code: " << entry);
 
  494         if (currentTreeNumber != 
m_tree->GetTreeNumber()) {
 
  495           currentTreeNumber = 
m_tree->GetTreeNumber();
 
  496           branch = 
m_tree->GetBranch(branchName.c_str());
 
  497           branch->SetAddress(&
object);
 
  499         branch->GetEntry(entry);
 
  503       m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
 
  509       auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
 
  510       return branch != 
nullptr;
 
  518       if (not variableName.empty()) {
 
  520           m_tree->SetBranchStatus(variableName.c_str(), 
true);
 
  521           m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
 
  527             B2ERROR(
"Couldn't find given " << variableType << 
" variable named " << variableName <<
 
  528                     " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
 
  529             throw std::runtime_error(
"Couldn't find given " + variableType + 
" variable named " + variableName +
 
  530                                      " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
 
  539       if (std::holds_alternative<double>(varVariantTarget))
 
  541       else if (std::holds_alternative<float>(varVariantTarget))
 
  543       else if (std::holds_alternative<int>(varVariantTarget))
 
  545       else if (std::holds_alternative<bool>(varVariantTarget))
 
  548         B2FATAL(
"Variable type for branch " << variableName <<  
" not supported!");
 
  555       for (
unsigned int i = 0; i < variableNames.size(); ++i)
 
  561                                                       std::vector<RootDatasetVarVariant>& varVariantTargets)
 
  563       for (
unsigned int i = 0; i < variableNames.size(); ++i) {
 
  571       m_tree->SetBranchStatus(
"*", 
false);
 
  576             m_tree->SetBranchStatus(
"__weight__", 
true);
 
  577             std::string typeLabel_weight = 
"weight";
 
  578             std::string weight_string =  
"__weight__";
 
  584           std::string typeLabel_weight = 
"weight";
 
  589       std::string typeLabel_target = 
"target";
 
  591       std::string typeLabel_feature = 
"feature";
 
  593       std::string typeLabel_spectator = 
"spectator";
 
  600       if (type == 
"Double_t")
 
  601         varVariantTarget = 0.0;
 
  602       else if (type == 
"Float_t")
 
  603         varVariantTarget = 0.0f;
 
  604       else if (type == 
"Int_t")
 
  605         varVariantTarget = 0;
 
  606       else if (type == 
"Bool_t")
 
  607         varVariantTarget = 
false;
 
  609         B2FATAL(
"Unknown root input type: " << type);
 
  610         throw std::runtime_error(
"Unknown root input type: " + type);
 
  620         TBranch* branch = 
m_tree->GetBranch(branch_name.c_str());
 
  621         TLeaf* leaf = branch->GetLeaf(branch_name.c_str());
 
  622         std::string type_name = leaf->GetTypeName();
 
  625         TBranch* branch = 
m_tree->GetBranch(compatible_branch_name.c_str());
 
  626         TLeaf* leaf = branch->GetLeaf(compatible_branch_name.c_str());
 
  627         std::string type_name = leaf->GetTypeName();
 
  652         B2INFO(
"No weight variable provided. The weight will be set to 1.");
 
  656             m_tree->SetBranchStatus(
"__weight__", 
true);
 
  659             B2INFO(
"Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the " 
  660                    "weight variable to an empty string if you don't need it.");
 
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual std::vector< bool > getSignals()
Returns the isSignal flag of every event.
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
std::vector< float > m_spectators
Contains all spectators values of the currently loaded event.
virtual std::vector< float > getTargets()
Returns all targets.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights()
Returns all weights.
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
bool m_isSignal
Defines if the currently loaded event is signal or background.
float m_weight
Contains the weight of the currently loaded event.
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
float m_target
Contains the target value of the currently loaded event.
General options which are shared by all MVA trainings.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
int m_signal_class
Signal class which is used as signal in a classification problem.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_treename
Name of the TTree inside the datafile containing the training data.
std::string m_target_variable
Target variable (branch name) defining the target.
std::vector< float > m_weights
weight vector
std::vector< std::vector< float > > m_matrix
Feature matrix.
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float >> &input, const std::vector< std::vector< float >> &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
std::vector< float > m_targets
target vector
virtual void loadEvent(unsigned int iEvent) override
Does nothing in the case of a single dataset, because the only event is already loaded.
void setScalarVariableAddress(const std::string &variableType, const std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
void setScalarVariableAddressVariant(const std::string &variableType, const std::string &variableName, RootDatasetVarVariant &variableTarget)
Sets the branch address for a scalar variable held in a RootDatasetVarVariant to a given target.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
void initialiseVarVariantForBranch(const std::string, RootDatasetVarVariant &)
Infers the type (double,float,int,bool) from the TTree and initialises the VarVariant with the correc...
TChain * m_tree
Pointer to the TChain containing the data.
virtual void loadEvent(unsigned int event) override
Load the event with the given number (parameter event) from the TTree.
void setVectorVariableAddressVariant(const std::string &variableType, const std::vector< std::string > &variableName, std::vector< RootDatasetVarVariant > &varVariantTargets)
sets the branch address for a vector of VarVariant to a given target
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
std::vector< float > getVectorFromTTree(const std::string &variableType, const std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
void initialiseVarVariantType(const std::string, RootDatasetVarVariant &)
Initialises the VarVariant.
std::vector< RootDatasetVarVariant > m_spectators_variant
Contains all spectators values of the currently loaded event.
RootDatasetVarVariant m_target_variant
Contains the target value of the currently loaded event.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
std::vector< RootDatasetVarVariant > m_input_variant
Contains all feature values of the currently loaded event.
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
void setRootInputType()
Tries to infer the data-type of the spectator and feature variables in a root file.
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
void setVectorVariableAddress(const std::string &variableType, const std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
virtual ~ROOTDataset()
Virtual destructor.
float castVarVariantToFloat(RootDatasetVarVariant &) const
Casts a VarVariant which can contain <double,int,bool,float> to float.
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
std::variant< double, float, int, bool > RootDatasetVarVariant
Typedef for variable types supported by the mva ROOTDataset, can be one of double,...
std::vector< float > getVectorFromTTreeVariant(const std::string &variableType, const std::string &branchName, RootDatasetVarVariant &memberVariableTarget)
Returns all values for a specified variableType and branchName.
RootDatasetVarVariant m_weight_variant
Contains the weight of the currently loaded event.
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Dataset & m_dataset
Reference to the wrapped dataset.
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
std::vector< unsigned int > m_event_indices
Mapping from the position of an event in the given subset to its position in the wrapped dataset.
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
Abstract base class for different kinds of events.