9#include <mva/interface/Dataset.h>
11#include <framework/utilities/MakeROOTCompatible.h>
12#include <framework/logging/Logger.h>
13#include <framework/io/RootIOUtilities.h>
38 double signal_weight_sum = 0;
39 double weight_sum = 0;
46 return signal_weight_sum / weight_sum;
55 B2ERROR(
"Unknown feature named " << feature);
67 B2ERROR(
"Unknown spectator named " << spectator);
80 result[iEvent] =
m_input[iFeature];
136 const std::vector<float>& spectators) :
Dataset(general_options)
146 const std::vector<std::vector<float>>& spectators,
147 const std::vector<float>& targets,
const std::vector<float>& weights) :
Dataset(general_options), m_matrix(input),
148 m_spectator_matrix(spectators),
149 m_targets(targets), m_weights(weights)
153 B2ERROR(
"Feature matrix and target vector need same number of elements in MultiDataset, got " <<
m_targets.size() <<
" and " <<
157 B2ERROR(
"Feature matrix and weight vector need same number of elements in MultiDataset, got " <<
m_weights.size() <<
" and " <<
161 B2ERROR(
"Feature matrix and spectator matrix need same number of elements in MultiDataset, got " <<
m_spectator_matrix.size() <<
193 B2ERROR(
"Couldn't find variable " << v <<
" in GeneralOptions");
194 throw std::runtime_error(
"Couldn't find variable " + v +
" in GeneralOptions");
202 B2ERROR(
"Couldn't find spectator " << v <<
" in GeneralOptions");
203 throw std::runtime_error(
"Couldn't find spectator " + v +
" in GeneralOptions");
208 if (events.size() > 0)
213 unsigned int n_events = 0;
215 if (events.size() == 0 or events[iEvent]) {
227 unsigned int index = iEvent;
235 for (
unsigned int iFeature = 0; iFeature <
m_input.size(); ++iFeature) {
239 for (
unsigned int iSpectator = 0; iSpectator <
m_spectators.size(); ++iSpectator) {
274 Dataset& background_dataset) :
Dataset(general_options), m_signal_dataset(signal_dataset),
275 m_background_dataset(background_dataset) { }
301 s.insert(s.end(), b.begin(), b.end());
311 s.insert(s.end(), b.begin(), b.end());
323 for (
const auto& variable : general_options.m_variables)
324 for (
const auto& spectator : general_options.m_spectators)
325 if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
326 B2ERROR(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
327 throw std::runtime_error(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
330 std::vector<std::string> filenames;
332 if (std::filesystem::exists(filename)) {
333 filenames.push_back(filename);
336 filenames.insert(filenames.end(), temp.begin(), temp.end());
339 if (filenames.empty()) {
340 B2ERROR(
"Found no valid filenames in GeneralOptions");
341 throw std::runtime_error(
"Found no valid filenames in GeneralOptions");
345 TDirectory* dir = gDirectory;
346 for (
const auto& filename : filenames) {
347 if (not std::filesystem::exists(filename)) {
348 B2ERROR(
"Error given ROOT file does not exist " << filename);
349 throw std::runtime_error(
"Error during open of ROOT file named " + filename);
352 TFile* f = TFile::Open(filename.c_str(),
"READ");
353 if (!f or f->IsZombie() or not f->IsOpen()) {
354 B2ERROR(
"Error during open of ROOT file named " << filename);
355 throw std::runtime_error(
"Error during open of ROOT file named " + filename);
362 for (
const auto& filename : filenames) {
364 if (!
m_tree->AddFile(filename.c_str(), -1)) {
365 B2ERROR(
"Error during open of ROOT file named " << filename <<
" cannot retrieve tree named " <<
367 throw std::runtime_error(
"Error during open of ROOT file named " + filename +
" cannot retrieve tree named " +
378 if (std::holds_alternative<double>(variant))
379 return static_cast<float>(std::get<double>(variant));
380 else if (std::holds_alternative<float>(variant))
381 return std::get<float>(variant);
382 else if (std::holds_alternative<int>(variant))
383 return static_cast<float>(std::get<int>(variant));
384 else if (std::holds_alternative<bool>(variant))
385 return static_cast<float>(std::get<bool>(variant));
387 B2FATAL(
"Unsupported variable type");
393 if (
m_tree->GetEntry(event, 0) == 0) {
394 B2ERROR(
"Error during loading entry from chain");
411 if (branchName.empty()) {
412 B2INFO(
"No TBranch name given for weights. Using 1s as default weights.");
414 std::vector<float> values(nentries, 1.);
417 if (branchName ==
"__weight__") {
419 B2INFO(
"No default weight branch with name __weight__ found. Using 1s as default weights.");
421 std::vector<float> values(nentries, 1.);
425 std::string typeLabel =
"weights";
432 B2ERROR(
"Feature index " << iFeature <<
" is out of bounds of given number of features: "
436 std::string typeLabel =
"features";
443 B2ERROR(
"Spectator index " << iSpectator <<
" is out of bounds of given number of spectators: "
448 std::string typeLabel =
"spectators";
461 if (std::holds_alternative<double>(memberVariableTarget))
462 return getVectorFromTTree(variableType, branchName, std::get<double>(memberVariableTarget));
463 else if (std::holds_alternative<float>(memberVariableTarget))
464 return getVectorFromTTree(variableType, branchName, std::get<float>(memberVariableTarget));
465 else if (std::holds_alternative<int>(memberVariableTarget))
466 return getVectorFromTTree(variableType, branchName, std::get<int>(memberVariableTarget));
467 else if (std::holds_alternative<bool>(memberVariableTarget))
468 return getVectorFromTTree(variableType, branchName, std::get<bool>(memberVariableTarget));
470 B2FATAL(
"Input type of " << variableType <<
" variable " << branchName <<
" is not supported");
475 T& memberVariableTarget)
478 std::vector<float> values(nentries);
482 auto currentTreeNumber =
m_tree->GetTreeNumber();
483 TBranch* branch =
m_tree->GetBranch(branchName.c_str());
485 B2ERROR(
"TBranch for " + variableType +
" named '" << branchName.c_str() <<
"' does not exist!");
487 branch->SetAddress(&
object);
488 for (
int i = 0; i < nentries; ++i) {
489 auto entry =
m_tree->LoadTree(i);
491 B2ERROR(
"Error during loading root tree from chain, error code: " << entry);
494 if (currentTreeNumber !=
m_tree->GetTreeNumber()) {
495 currentTreeNumber =
m_tree->GetTreeNumber();
496 branch =
m_tree->GetBranch(branchName.c_str());
497 branch->SetAddress(&
object);
499 branch->GetEntry(entry);
503 m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
509 auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
510 return branch !=
nullptr;
518 if (not variableName.empty()) {
520 m_tree->SetBranchStatus(variableName.c_str(),
true);
521 m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
527 B2ERROR(
"Couldn't find given " << variableType <<
" variable named " << variableName <<
528 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
529 throw std::runtime_error(
"Couldn't find given " + variableType +
" variable named " + variableName +
530 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
539 if (std::holds_alternative<double>(varVariantTarget))
541 else if (std::holds_alternative<float>(varVariantTarget))
543 else if (std::holds_alternative<int>(varVariantTarget))
545 else if (std::holds_alternative<bool>(varVariantTarget))
548 B2FATAL(
"Variable type for branch " << variableName <<
" not supported!");
555 for (
unsigned int i = 0; i < variableNames.size(); ++i)
561 std::vector<RootDatasetVarVariant>& varVariantTargets)
563 for (
unsigned int i = 0; i < variableNames.size(); ++i) {
571 m_tree->SetBranchStatus(
"*",
false);
576 m_tree->SetBranchStatus(
"__weight__",
true);
577 std::string typeLabel_weight =
"weight";
578 std::string weight_string =
"__weight__";
584 std::string typeLabel_weight =
"weight";
589 std::string typeLabel_target =
"target";
591 std::string typeLabel_feature =
"feature";
593 std::string typeLabel_spectator =
"spectator";
600 if (type ==
"Double_t")
601 varVariantTarget = 0.0;
602 else if (type ==
"Float_t")
603 varVariantTarget = 0.0f;
604 else if (type ==
"Int_t")
605 varVariantTarget = 0;
606 else if (type ==
"Bool_t")
607 varVariantTarget =
false;
609 B2FATAL(
"Unknown root input type: " << type);
610 throw std::runtime_error(
"Unknown root input type: " + type);
620 TBranch* branch =
m_tree->GetBranch(branch_name.c_str());
621 TLeaf* leaf = branch->GetLeaf(branch_name.c_str());
622 std::string type_name = leaf->GetTypeName();
625 TBranch* branch =
m_tree->GetBranch(compatible_branch_name.c_str());
626 TLeaf* leaf = branch->GetLeaf(compatible_branch_name.c_str());
627 std::string type_name = leaf->GetTypeName();
652 B2INFO(
"No weight variable provided. The weight will be set to 1.");
656 m_tree->SetBranchStatus(
"__weight__",
true);
659 B2INFO(
"Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
660 "weight variable to an empty string if you don't need it.");
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual std::vector< bool > getSignals()
Returns the isSignal flags of all events.
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
std::vector< float > m_spectators
Contains all spectators values of the currently loaded event.
virtual std::vector< float > getTargets()
Returns all targets.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights()
Returns all weights.
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
bool m_isSignal
Defines if the currently loaded event is signal or background.
float m_weight
Contains the weight of the currently loaded event.
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
float m_target
Contains the target value of the currently loaded event.
General options which are shared by all MVA trainings.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
int m_signal_class
Signal class which is used as signal in a classification problem.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_treename
Name of the TTree inside the datafile containing the training data.
std::string m_target_variable
Target variable (branch name) defining the target.
std::vector< float > m_weights
Weight vector.
std::vector< std::vector< float > > m_matrix
Feature matrix.
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
std::vector< float > m_targets
Target vector.
virtual void loadEvent(unsigned int iEvent) override
Does nothing in the case of a single dataset, because the only event is already loaded.
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float > > &input, const std::vector< std::vector< float > > &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
void setScalarVariableAddress(const std::string &variableType, const std::string &variableName, T &variableTarget)
Sets the branch address for a scalar variable to a given target.
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
void setScalarVariableAddressVariant(const std::string &variableType, const std::string &variableName, RootDatasetVarVariant &variableTarget)
Sets the branch address for a scalar variable to a given RootDatasetVarVariant target.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
void initialiseVarVariantForBranch(const std::string, RootDatasetVarVariant &)
Infers the type (double,float,int,bool) from the TTree and initialises the VarVariant with the correc...
TChain * m_tree
Pointer to the TChain containing the data.
virtual void loadEvent(unsigned int event) override
Load the event with the given number (event) from the TTree.
void setVectorVariableAddressVariant(const std::string &variableType, const std::vector< std::string > &variableName, std::vector< RootDatasetVarVariant > &varVariantTargets)
Sets the branch addresses for a vector of variables to the given vector of RootDatasetVarVariant targets.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
std::vector< float > getVectorFromTTree(const std::string &variableType, const std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
void initialiseVarVariantType(const std::string, RootDatasetVarVariant &)
Initialises the VarVariant.
std::vector< RootDatasetVarVariant > m_spectators_variant
Contains all spectators values of the currently loaded event.
RootDatasetVarVariant m_target_variant
Contains the target value of the currently loaded event.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
std::vector< RootDatasetVarVariant > m_input_variant
Contains all feature values of the currently loaded event.
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
void setRootInputType()
Tries to infer the data-type of the spectator and feature variables in a root file.
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
void setVectorVariableAddress(const std::string &variableType, const std::vector< std::string > &variableName, T &variableTargets)
Sets the branch address for a vector variable to a given target.
virtual ~ROOTDataset()
Virtual destructor.
float castVarVariantToFloat(RootDatasetVarVariant &) const
Casts a VarVariant which can contain <double,int,bool,float> to float.
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
std::variant< double, float, int, bool > RootDatasetVarVariant
Typedef for variable types supported by the mva ROOTDataset, can be one of double,...
std::vector< float > getVectorFromTTreeVariant(const std::string &variableType, const std::string &branchName, RootDatasetVarVariant &memberVariableTarget)
Returns all values for a specified variableType and branchName.
RootDatasetVarVariant m_weight_variant
Contains the weight of the currently loaded event.
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Dataset & m_dataset
Reference to the wrapped dataset.
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
std::vector< unsigned int > m_event_indices
Mapping from the position of an event in the given subset to its position in the wrapped dataset.
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
Abstract base class for different kinds of events.