9#include <mva/interface/Dataset.h>
11#include <framework/utilities/MakeROOTCompatible.h>
12#include <framework/logging/Logger.h>
13#include <framework/io/RootIOUtilities.h>
38 double signal_weight_sum = 0;
39 double weight_sum = 0;
46 return signal_weight_sum / weight_sum;
55 B2ERROR(
"Unknown feature named " << feature);
67 B2ERROR(
"Unknown spectator named " << spectator);
80 result[iEvent] =
m_input[iFeature];
136 const std::vector<float>& spectators) :
Dataset(general_options)
146 const std::vector<std::vector<float>>& spectators,
147 const std::vector<float>& targets,
const std::vector<float>& weights) :
Dataset(general_options), m_matrix(input),
148 m_spectator_matrix(spectators),
149 m_targets(targets), m_weights(weights)
153 B2ERROR(
"Feature matrix and target vector need same number of elements in MultiDataset, got " <<
m_targets.size() <<
" and " <<
157 B2ERROR(
"Feature matrix and weight vector need same number of elements in MultiDataset, got " <<
m_weights.size() <<
" and " <<
161 B2ERROR(
"Feature matrix and spectator matrix need same number of elements in MultiDataset, got " <<
m_spectator_matrix.size() <<
193 B2ERROR(
"Couldn't find variable " << v <<
" in GeneralOptions");
194 throw std::runtime_error(
"Couldn't find variable " + v +
" in GeneralOptions");
202 B2ERROR(
"Couldn't find spectator " << v <<
" in GeneralOptions");
203 throw std::runtime_error(
"Couldn't find spectator " + v +
" in GeneralOptions");
208 if (events.size() > 0)
213 unsigned int n_events = 0;
215 if (events.size() == 0 or events[iEvent]) {
227 unsigned int index = iEvent;
235 for (
unsigned int iFeature = 0; iFeature <
m_input.size(); ++iFeature) {
239 for (
unsigned int iSpectator = 0; iSpectator <
m_spectators.size(); ++iSpectator) {
274 Dataset& background_dataset) :
Dataset(general_options), m_signal_dataset(signal_dataset),
275 m_background_dataset(background_dataset) { }
301 s.insert(s.end(), b.begin(), b.end());
311 s.insert(s.end(), b.begin(), b.end());
323 for (
const auto& variable : general_options.m_variables)
324 for (
const auto& spectator : general_options.m_spectators)
325 if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
326 B2ERROR(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
327 throw std::runtime_error(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
330 std::vector<std::string> filenames;
332 if (std::filesystem::exists(filename)) {
333 filenames.push_back(filename);
336 filenames.insert(filenames.end(), temp.begin(), temp.end());
339 if (filenames.empty()) {
340 B2ERROR(
"Found no valid filenames in GeneralOptions");
341 throw std::runtime_error(
"Found no valid filenames in GeneralOptions");
345 TDirectory* dir = gDirectory;
346 for (
const auto& filename : filenames) {
347 TFile* f = TFile::Open(filename.c_str(),
"READ");
348 if (!f or f->IsZombie() or not f->IsOpen()) {
349 B2ERROR(
"Error during open of ROOT file named " << filename);
350 throw std::runtime_error(
"Error during open of ROOT file named " + filename);
357 for (
const auto& filename : filenames) {
359 if (!
m_tree->AddFile(filename.c_str(), -1)) {
360 B2ERROR(
"Error during open of ROOT file named " << filename <<
" cannot retrieve tree named " <<
362 throw std::runtime_error(
"Error during open of ROOT file named " + filename +
" cannot retrieve tree named " +
373 if (std::holds_alternative<double>(variant))
374 return static_cast<float>(std::get<double>(variant));
375 else if (std::holds_alternative<float>(variant))
376 return std::get<float>(variant);
377 else if (std::holds_alternative<int>(variant))
378 return static_cast<float>(std::get<int>(variant));
379 else if (std::holds_alternative<bool>(variant))
380 return static_cast<float>(std::get<bool>(variant));
382 B2FATAL(
"Unsupported variable type");
388 if (
m_tree->GetEntry(event, 0) == 0) {
389 B2ERROR(
"Error during loading entry from chain");
406 if (branchName.empty()) {
407 B2INFO(
"No TBranch name given for weights. Using 1s as default weights.");
409 std::vector<float> values(nentries, 1.);
412 if (branchName ==
"__weight__") {
414 B2INFO(
"No default weight branch with name __weight__ found. Using 1s as default weights.");
416 std::vector<float> values(nentries, 1.);
420 std::string typeLabel =
"weights";
427 B2ERROR(
"Feature index " << iFeature <<
" is out of bounds of given number of features: "
431 std::string typeLabel =
"features";
438 B2ERROR(
"Spectator index " << iSpectator <<
" is out of bounds of given number of spectators: "
443 std::string typeLabel =
"spectators";
456 if (std::holds_alternative<double>(memberVariableTarget))
457 return getVectorFromTTree(variableType, branchName, std::get<double>(memberVariableTarget));
458 else if (std::holds_alternative<float>(memberVariableTarget))
459 return getVectorFromTTree(variableType, branchName, std::get<float>(memberVariableTarget));
460 else if (std::holds_alternative<int>(memberVariableTarget))
461 return getVectorFromTTree(variableType, branchName, std::get<int>(memberVariableTarget));
462 else if (std::holds_alternative<bool>(memberVariableTarget))
463 return getVectorFromTTree(variableType, branchName, std::get<bool>(memberVariableTarget));
465 B2FATAL(
"Input type of " << variableType <<
" variable " << branchName <<
" is not supported");
470 T& memberVariableTarget)
473 std::vector<float> values(nentries);
477 auto currentTreeNumber =
m_tree->GetTreeNumber();
478 TBranch* branch =
m_tree->GetBranch(branchName.c_str());
480 B2ERROR(
"TBranch for " + variableType +
" named '" << branchName.c_str() <<
"' does not exist!");
482 branch->SetAddress(&
object);
483 for (
int i = 0; i < nentries; ++i) {
484 auto entry =
m_tree->LoadTree(i);
486 B2ERROR(
"Error during loading root tree from chain, error code: " << entry);
489 if (currentTreeNumber !=
m_tree->GetTreeNumber()) {
490 currentTreeNumber =
m_tree->GetTreeNumber();
491 branch =
m_tree->GetBranch(branchName.c_str());
492 branch->SetAddress(&
object);
494 branch->GetEntry(entry);
498 m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
504 auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
505 return branch !=
nullptr;
513 if (not variableName.empty()) {
515 m_tree->SetBranchStatus(variableName.c_str(),
true);
516 m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
522 B2ERROR(
"Couldn't find given " << variableType <<
" variable named " << variableName <<
523 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
524 throw std::runtime_error(
"Couldn't find given " + variableType +
" variable named " + variableName +
525 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
534 if (std::holds_alternative<double>(varVariantTarget))
536 else if (std::holds_alternative<float>(varVariantTarget))
538 else if (std::holds_alternative<int>(varVariantTarget))
540 else if (std::holds_alternative<bool>(varVariantTarget))
543 B2FATAL(
"Variable type for branch " << variableName <<
" not supported!");
550 for (
unsigned int i = 0; i < variableNames.size(); ++i)
556 std::vector<RootDatasetVarVariant>& varVariantTargets)
558 for (
unsigned int i = 0; i < variableNames.size(); ++i) {
566 m_tree->SetBranchStatus(
"*",
false);
571 m_tree->SetBranchStatus(
"__weight__",
true);
572 std::string typeLabel_weight =
"weight";
573 std::string weight_string =
"__weight__";
579 std::string typeLabel_weight =
"weight";
584 std::string typeLabel_target =
"target";
586 std::string typeLabel_feature =
"feature";
588 std::string typeLabel_spectator =
"spectator";
595 if (type ==
"Double_t")
596 varVariantTarget = 0.0;
597 else if (type ==
"Float_t")
598 varVariantTarget = 0.0f;
599 else if (type ==
"Int_t")
600 varVariantTarget = 0;
601 else if (type ==
"Bool_t")
602 varVariantTarget =
false;
604 B2FATAL(
"Unknown root input type: " << type);
605 throw std::runtime_error(
"Unknown root input type: " + type);
615 TBranch* branch =
m_tree->GetBranch(branch_name.c_str());
616 TLeaf* leaf = branch->GetLeaf(branch_name.c_str());
617 std::string type_name = leaf->GetTypeName();
620 TBranch* branch =
m_tree->GetBranch(compatible_branch_name.c_str());
621 TLeaf* leaf = branch->GetLeaf(compatible_branch_name.c_str());
622 std::string type_name = leaf->GetTypeName();
647 B2INFO(
"No weight variable provided. The weight will be set to 1.");
651 m_tree->SetBranchStatus(
"__weight__",
true);
654 B2INFO(
"Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
655 "weight variable to an empty string if you don't need it.");
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual std::vector< bool > getSignals()
Returns the isSignal flags of all events.
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
std::vector< float > m_spectators
Contains all spectators values of the currently loaded event.
virtual std::vector< float > getTargets()
Returns all targets.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights()
Returns all weights.
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
bool m_isSignal
Defines if the currently loaded event is signal or background.
float m_weight
Contains the weight of the currently loaded event.
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
float m_target
Contains the target value of the currently loaded event.
General options which are shared by all MVA trainings.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
int m_signal_class
Signal class which is used as signal in a classification problem.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_treename
Name of the TTree inside the datafile containing the training data.
std::string m_target_variable
Target variable (branch name) defining the target.
std::vector< float > m_weights
weight vector
std::vector< std::vector< float > > m_matrix
Feature matrix.
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
std::vector< float > m_targets
target vector
virtual void loadEvent(unsigned int iEvent) override
Does nothing in the case of a single dataset, because the only event is already loaded.
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float > > &input, const std::vector< std::vector< float > > &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
void setScalarVariableAddress(const std::string &variableType, const std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
void setScalarVariableAddressVariant(const std::string &variableType, const std::string &variableName, RootDatasetVarVariant &variableTarget)
sets the branch address for a scalar variable to a given target
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
void initialiseVarVariantForBranch(const std::string, RootDatasetVarVariant &)
Infers the type (double,float,int,bool) from the TTree and initialises the VarVariant with the correc...
TChain * m_tree
Pointer to the TChain containing the data.
virtual void loadEvent(unsigned int event) override
Load the event with the given number (event) from the TTree.
void setVectorVariableAddressVariant(const std::string &variableType, const std::vector< std::string > &variableName, std::vector< RootDatasetVarVariant > &varVariantTargets)
sets the branch address for a vector of VarVariant to a given target
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
std::vector< float > getVectorFromTTree(const std::string &variableType, const std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
void initialiseVarVariantType(const std::string, RootDatasetVarVariant &)
Initialises the VarVariant.
std::vector< RootDatasetVarVariant > m_spectators_variant
Contains all spectators values of the currently loaded event.
RootDatasetVarVariant m_target_variant
Contains the target value of the currently loaded event.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
std::vector< RootDatasetVarVariant > m_input_variant
Contains all feature values of the currently loaded event.
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
void setRootInputType()
Tries to infer the data-type of the spectator and feature variables in a root file.
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
void setVectorVariableAddress(const std::string &variableType, const std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
virtual ~ROOTDataset()
Virtual destructor.
float castVarVariantToFloat(RootDatasetVarVariant &) const
Casts a VarVariant which can contain <double,int,bool,float> to float.
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
std::variant< double, float, int, bool > RootDatasetVarVariant
Typedef for variable types supported by the mva ROOTDataset, can be one of double,...
std::vector< float > getVectorFromTTreeVariant(const std::string &variableType, const std::string &branchName, RootDatasetVarVariant &memberVariableTarget)
Returns all values for a specified variableType and branchName.
RootDatasetVarVariant m_weight_variant
Contains the weight of the currently loaded event.
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Dataset & m_dataset
Reference to the wrapped dataset.
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
std::vector< unsigned int > m_event_indices
Mapping from the position of an event in the given subset to its position in the wrapped dataset.
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
Abstract base class for different kinds of events.