9#include <mva/interface/Dataset.h> 
   11#include <framework/utilities/MakeROOTCompatible.h> 
   12#include <framework/logging/Logger.h> 
   13#include <framework/io/RootIOUtilities.h> 
   38      double signal_weight_sum = 0;
 
   39      double weight_sum = 0;
 
   46      return signal_weight_sum / weight_sum;
 
 
   55        B2ERROR(
"Unknown feature named " << feature);
 
 
   67        B2ERROR(
"Unknown spectator named " << spectator);
 
 
   80        result[iEvent] = 
m_input[iFeature];
 
 
  136                                 const std::vector<float>& spectators) : 
Dataset(general_options)
 
 
  146                               const std::vector<std::vector<float>>& spectators,
 
  147                               const std::vector<float>& targets, 
const std::vector<float>& weights) : 
Dataset(general_options),  
m_matrix(input),
 
  153        B2ERROR(
"Feature matrix and target vector need same number of elements in MultiDataset, got " << m_targets.size() << 
" and " <<
 
  157        B2ERROR(
"Feature matrix and weight vector need same number of elements in MultiDataset, got " << m_weights.size() << 
" and " <<
 
  161        B2ERROR(
"Feature matrix and spectator matrix need same number of elements in MultiDataset, got " << m_spectator_matrix.size() <<
 
 
  191        auto it = std::find(m_dataset.m_general_options.m_variables.begin(), m_dataset.m_general_options.m_variables.end(), v);
 
  192        if (it == m_dataset.m_general_options.m_variables.end()) {
 
  193          B2ERROR(
"Couldn't find variable " << v << 
" in GeneralOptions");
 
  194          throw std::runtime_error(
"Couldn't find variable " + v + 
" in GeneralOptions");
 
  199      for (
auto& v : m_general_options.m_spectators) {
 
  200        auto it = std::find(m_dataset.m_general_options.m_spectators.begin(), m_dataset.m_general_options.m_spectators.end(), v);
 
  201        if (it == m_dataset.m_general_options.m_spectators.end()) {
 
  202          B2ERROR(
"Couldn't find spectator " << v << 
" in GeneralOptions");
 
  203          throw std::runtime_error(
"Couldn't find spectator " + v + 
" in GeneralOptions");
 
  205        m_spectator_indices.push_back(it - m_dataset.m_general_options.m_spectators.begin());
 
  208      if (events.size() > 0)
 
  209        m_use_event_indices = 
true;
 
  211      if (m_use_event_indices) {
 
  212        m_event_indices.resize(dataset.getNumberOfEvents());
 
  213        unsigned int n_events = 0;
 
  214        for (
unsigned int iEvent = 0; iEvent < dataset.getNumberOfEvents(); ++iEvent) {
 
  215          if (events.size() == 0 or events[iEvent]) {
 
  216            m_event_indices[n_events] = iEvent;
 
  220        m_event_indices.resize(n_events);
 
 
  227      unsigned int index = iEvent;
 
  235      for (
unsigned int iFeature = 0; iFeature < 
m_input.size(); ++iFeature) {
 
  239      for (
unsigned int iSpectator = 0; iSpectator < 
m_spectators.size(); ++iSpectator) {
 
 
  301      s.insert(s.end(), b.begin(), b.end());
 
 
  311      s.insert(s.end(), b.begin(), b.end());
 
 
  323      for (
const auto& variable : general_options.m_variables)
 
  324        for (
const auto& spectator : general_options.m_spectators)
 
  325          if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
 
  326            B2ERROR(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
 
  327            throw std::runtime_error(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
 
  330      std::vector<std::string> filenames;
 
  332        if (std::filesystem::exists(filename)) {
 
  333          filenames.push_back(filename);
 
  335          auto temp = RootIOUtilities::expandWordExpansions(m_general_options.m_datafiles);
 
  336          filenames.insert(filenames.end(), temp.begin(), temp.end());
 
  339      if (filenames.empty()) {
 
  340        B2ERROR(
"Found no valid filenames in GeneralOptions");
 
  341        throw std::runtime_error(
"Found no valid filenames in GeneralOptions");
 
  345      TDirectory* dir = gDirectory;
 
  346      for (
const auto& filename : filenames) {
 
  347        TFile* f = TFile::Open(filename.c_str(), 
"READ");
 
  348        if (!f or f->IsZombie() or not f->IsOpen()) {
 
  349          B2ERROR(
"Error during open of ROOT file named " << filename);
 
  350          throw std::runtime_error(
"Error during open of ROOT file named " + filename);
 
  356      m_tree = 
new TChain(m_general_options.m_treename.c_str());
 
  357      for (
const auto& filename : filenames) {
 
  359        if (!m_tree->AddFile(filename.c_str(), -1)) {
 
  360          B2ERROR(
"Error during open of ROOT file named " << filename << 
" cannot retrieve tree named " <<
 
  361                  m_general_options.m_treename);
 
  362          throw std::runtime_error(
"Error during open of ROOT file named " + filename + 
" cannot retrieve tree named " +
 
  363                                   m_general_options.m_treename);
 
  367      setBranchAddresses();
 
 
  373      if (std::holds_alternative<double>(variant))
 
  374        return static_cast<float>(std::get<double>(variant));
 
  375      else if (std::holds_alternative<float>(variant))
 
  376        return std::get<float>(variant);
 
  377      else if (std::holds_alternative<int>(variant))
 
  378        return static_cast<float>(std::get<int>(variant));
 
  379      else if (std::holds_alternative<bool>(variant))
 
  380        return static_cast<float>(std::get<bool>(variant));
 
  382        B2FATAL(
"Unsupported variable type");
 
 
  388      if (
m_tree->GetEntry(event, 0) == 0) {
 
  389        B2ERROR(
"Error during loading entry from chain");
 
 
  406      if (branchName.empty()) {
 
  407        B2INFO(
"No TBranch name given for weights. Using 1s as default weights.");
 
  409        std::vector<float> values(nentries, 1.);
 
  412      if (branchName == 
"__weight__") {
 
  414          B2INFO(
"No default weight branch with name __weight__ found. Using 1s as default weights.");
 
  416          std::vector<float> values(nentries, 1.);
 
  420      std::string typeLabel = 
"weights";
 
 
  427        B2ERROR(
"Feature index " << iFeature << 
" is out of bounds of given number of features: " 
  431      std::string typeLabel = 
"features";
 
 
  438        B2ERROR(
"Spectator index " << iSpectator << 
" is out of bounds of given number of spectators: " 
  443      std::string typeLabel = 
"spectators";
 
 
  456      if (std::holds_alternative<double>(memberVariableTarget))
 
  457        return getVectorFromTTree(variableType, branchName, std::get<double>(memberVariableTarget));
 
  458      else if (std::holds_alternative<float>(memberVariableTarget))
 
  459        return getVectorFromTTree(variableType, branchName, std::get<float>(memberVariableTarget));
 
  460      else if (std::holds_alternative<int>(memberVariableTarget))
 
  461        return getVectorFromTTree(variableType, branchName, std::get<int>(memberVariableTarget));
 
  462      else if (std::holds_alternative<bool>(memberVariableTarget))
 
  463        return getVectorFromTTree(variableType, branchName, std::get<bool>(memberVariableTarget));
 
  465        B2FATAL(
"Input type of " << variableType << 
" variable " << branchName << 
" is not supported");
 
 
  470                                                       T& memberVariableTarget)
 
  473      std::vector<float> values(nentries);
 
  477      auto currentTreeNumber = 
m_tree->GetTreeNumber();
 
  478      TBranch* branch = 
m_tree->GetBranch(branchName.c_str());
 
  480        B2ERROR(
"TBranch for " + variableType + 
" named '" << branchName.c_str()  << 
"' does not exist!");
 
  482      branch->SetAddress(&
object);
 
  483      for (
int i = 0; i < nentries; ++i) {
 
  484        auto entry = 
m_tree->LoadTree(i);
 
  486          B2ERROR(
"Error during loading root tree from chain, error code: " << entry);
 
  489        if (currentTreeNumber != 
m_tree->GetTreeNumber()) {
 
  490          currentTreeNumber = 
m_tree->GetTreeNumber();
 
  491          branch = 
m_tree->GetBranch(branchName.c_str());
 
  492          branch->SetAddress(&
object);
 
  494        branch->GetEntry(entry);
 
  498      m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
 
 
  504      auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
 
  505      return branch != 
nullptr;
 
 
  513      if (not variableName.empty()) {
 
  515          m_tree->SetBranchStatus(variableName.c_str(), 
true);
 
  516          m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
 
  522            B2ERROR(
"Couldn't find given " << variableType << 
" variable named " << variableName <<
 
  523                    " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
 
  524            throw std::runtime_error(
"Couldn't find given " + variableType + 
" variable named " + variableName +
 
  525                                     " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
 
 
  534      if (std::holds_alternative<double>(varVariantTarget))
 
  536      else if (std::holds_alternative<float>(varVariantTarget))
 
  538      else if (std::holds_alternative<int>(varVariantTarget))
 
  540      else if (std::holds_alternative<bool>(varVariantTarget))
 
  543        B2FATAL(
"Variable type for branch " << variableName <<  
" not supported!");
 
 
  550      for (
unsigned int i = 0; i < variableNames.size(); ++i)
 
 
  556                                                      std::vector<RootDatasetVarVariant>& varVariantTargets)
 
  558      for (
unsigned int i = 0; i < variableNames.size(); ++i) {
 
 
  566      m_tree->SetBranchStatus(
"*", 
false);
 
  571            m_tree->SetBranchStatus(
"__weight__", 
true);
 
  572            std::string typeLabel_weight = 
"weight";
 
  573            std::string weight_string =  
"__weight__";
 
  579          std::string typeLabel_weight = 
"weight";
 
  584      std::string typeLabel_target = 
"target";
 
  586      std::string typeLabel_feature = 
"feature";
 
  588      std::string typeLabel_spectator = 
"spectator";
 
 
  595      if (type == 
"Double_t")
 
  596        varVariantTarget = 0.0;
 
  597      else if (type == 
"Float_t")
 
  598        varVariantTarget = 0.0f;
 
  599      else if (type == 
"Int_t")
 
  600        varVariantTarget = 0;
 
  601      else if (type == 
"Bool_t")
 
  602        varVariantTarget = 
false;
 
  604        B2FATAL(
"Unknown root input type: " << type);
 
  605        throw std::runtime_error(
"Unknown root input type: " + type);
 
 
  615        TBranch* branch = 
m_tree->GetBranch(branch_name.c_str());
 
  616        TLeaf* leaf = branch->GetLeaf(branch_name.c_str());
 
  617        std::string type_name = leaf->GetTypeName();
 
  620        TBranch* branch = 
m_tree->GetBranch(compatible_branch_name.c_str());
 
  621        TLeaf* leaf = branch->GetLeaf(compatible_branch_name.c_str());
 
  622        std::string type_name = leaf->GetTypeName();
 
 
  647        B2INFO(
"No weight variable provided. The weight will be set to 1.");
 
  651            m_tree->SetBranchStatus(
"__weight__", 
true);
 
  654            B2INFO(
"Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the " 
  655                   "weight variable to an empty string if you don't need it.");
 
 
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual std::vector< bool > getSignals()
Returns the isSignal flags of all events.
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
std::vector< float > m_spectators
Contains all spectators values of the currently loaded event.
virtual std::vector< float > getTargets()
Returns all targets.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights()
Returns all weights.
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
bool m_isSignal
Defines if the currently loaded event is signal or background.
float m_weight
Contains the weight of the currently loaded event.
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
float m_target
Contains the target value of the currently loaded event.
General options which are shared by all MVA trainings.
std::vector< float > m_weights
weight vector
std::vector< std::vector< float > > m_matrix
Feature matrix.
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
std::vector< float > m_targets
target vector
virtual void loadEvent(unsigned int iEvent) override
Does nothing in the case of a single dataset, because the only event is already loaded.
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float > > &input, const std::vector< std::vector< float > > &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
void setScalarVariableAddress(const std::string &variableType, const std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
void setScalarVariableAddressVariant(const std::string &variableType, const std::string &variableName, RootDatasetVarVariant &variableTarget)
sets the branch address for a scalar variable to a given target
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
void initialiseVarVariantForBranch(const std::string, RootDatasetVarVariant &)
Infers the type (double,float,int,bool) from the TTree and initialises the VarVariant with the correc...
TChain * m_tree
Pointer to the TChain containing the data.
virtual void loadEvent(unsigned int event) override
Load the event with the given number from the TTree.
void setVectorVariableAddressVariant(const std::string &variableType, const std::vector< std::string > &variableName, std::vector< RootDatasetVarVariant > &varVariantTargets)
sets the branch address for a vector of VarVariant to a given target
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
std::vector< float > getVectorFromTTree(const std::string &variableType, const std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
void initialiseVarVariantType(const std::string, RootDatasetVarVariant &)
Initialises the VarVariant.
std::vector< RootDatasetVarVariant > m_spectators_variant
Contains all spectators values of the currently loaded event.
RootDatasetVarVariant m_target_variant
Contains the target value of the currently loaded event.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
std::vector< RootDatasetVarVariant > m_input_variant
Contains all feature values of the currently loaded event.
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
void setRootInputType()
Tries to infer the data-type of the spectator and feature variables in a root file.
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
void setVectorVariableAddress(const std::string &variableType, const std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
virtual ~ROOTDataset()
Virtual destructor.
float castVarVariantToFloat(RootDatasetVarVariant &) const
Casts a VarVariant which can contain <double,int,bool,float> to float.
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
std::variant< double, float, int, bool > RootDatasetVarVariant
Typedef for variable types supported by the mva ROOTDataset, can be one of double,...
std::vector< float > getVectorFromTTreeVariant(const std::string &variableType, const std::string &branchName, RootDatasetVarVariant &memberVariableTarget)
Returns all values for a specified variableType and branchName.
RootDatasetVarVariant m_weight_variant
Contains the weight of the currently loaded event.
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Dataset & m_dataset
Reference to the wrapped dataset.
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
std::vector< unsigned int > m_event_indices
Mapping from the position of an event in the given subset to its position in the wrapped dataset.
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
Abstract base class for different kinds of events.