9#include <mva/interface/Dataset.h>
11#include <framework/utilities/MakeROOTCompatible.h>
12#include <framework/logging/Logger.h>
13#include <framework/io/RootIOUtilities.h>
38 double signal_weight_sum = 0;
39 double weight_sum = 0;
46 return signal_weight_sum / weight_sum;
55 B2ERROR(
"Unknown feature named " << feature);
67 B2ERROR(
"Unknown spectator named " << spectator);
80 result[iEvent] =
m_input[iFeature];
136 const std::vector<float>& spectators) :
Dataset(general_options)
146 const std::vector<std::vector<float>>& spectators,
147 const std::vector<float>& targets,
const std::vector<float>& weights) :
Dataset(general_options),
m_matrix(input),
153 B2ERROR(
"Feature matrix and target vector need same number of elements in MultiDataset, got " << m_targets.size() <<
" and " <<
157 B2ERROR(
"Feature matrix and weight vector need same number of elements in MultiDataset, got " << m_weights.size() <<
" and " <<
161 B2ERROR(
"Feature matrix and spectator matrix need same number of elements in MultiDataset, got " << m_spectator_matrix.size() <<
191 auto it = std::find(m_dataset.m_general_options.m_variables.begin(), m_dataset.m_general_options.m_variables.end(), v);
192 if (it == m_dataset.m_general_options.m_variables.end()) {
193 B2ERROR(
"Couldn't find variable " << v <<
" in GeneralOptions");
194 throw std::runtime_error(
"Couldn't find variable " + v +
" in GeneralOptions");
199 for (
auto& v : m_general_options.m_spectators) {
200 auto it = std::find(m_dataset.m_general_options.m_spectators.begin(), m_dataset.m_general_options.m_spectators.end(), v);
201 if (it == m_dataset.m_general_options.m_spectators.end()) {
202 B2ERROR(
"Couldn't find spectator " << v <<
" in GeneralOptions");
203 throw std::runtime_error(
"Couldn't find spectator " + v +
" in GeneralOptions");
205 m_spectator_indices.push_back(it - m_dataset.m_general_options.m_spectators.begin());
208 if (events.size() > 0)
209 m_use_event_indices =
true;
211 if (m_use_event_indices) {
212 m_event_indices.resize(dataset.getNumberOfEvents());
213 unsigned int n_events = 0;
214 for (
unsigned int iEvent = 0; iEvent < dataset.getNumberOfEvents(); ++iEvent) {
215 if (events.size() == 0 or events[iEvent]) {
216 m_event_indices[n_events] = iEvent;
220 m_event_indices.resize(n_events);
227 unsigned int index = iEvent;
235 for (
unsigned int iFeature = 0; iFeature <
m_input.size(); ++iFeature) {
239 for (
unsigned int iSpectator = 0; iSpectator <
m_spectators.size(); ++iSpectator) {
301 s.insert(s.end(), b.begin(), b.end());
311 s.insert(s.end(), b.begin(), b.end());
323 for (
const auto& variable : general_options.m_variables)
324 for (
const auto& spectator : general_options.m_spectators)
325 if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
326 B2ERROR(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
327 throw std::runtime_error(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
330 std::vector<std::string> filenames;
332 if (std::filesystem::exists(filename)) {
333 filenames.push_back(filename);
335 auto temp = RootIOUtilities::expandWordExpansions(m_general_options.m_datafiles);
336 filenames.insert(filenames.end(), temp.begin(), temp.end());
339 if (filenames.empty()) {
340 B2ERROR(
"Found no valid filenames in GeneralOptions");
341 throw std::runtime_error(
"Found no valid filenames in GeneralOptions");
345 TDirectory* dir = gDirectory;
346 for (
const auto& filename : filenames) {
347 TFile* f = TFile::Open(filename.c_str(),
"READ");
348 if (!f or f->IsZombie() or not f->IsOpen()) {
349 B2ERROR(
"Error during open of ROOT file named " << filename);
350 throw std::runtime_error(
"Error during open of ROOT file named " + filename);
356 m_tree =
new TChain(m_general_options.m_treename.c_str());
357 for (
const auto& filename : filenames) {
359 if (!m_tree->AddFile(filename.c_str(), -1)) {
360 B2ERROR(
"Error during open of ROOT file named " << filename <<
" cannot retrieve tree named " <<
361 m_general_options.m_treename);
362 throw std::runtime_error(
"Error during open of ROOT file named " + filename +
" cannot retrieve tree named " +
363 m_general_options.m_treename);
367 setBranchAddresses();
373 if (std::holds_alternative<double>(variant))
374 return static_cast<float>(std::get<double>(variant));
375 else if (std::holds_alternative<float>(variant))
376 return std::get<float>(variant);
377 else if (std::holds_alternative<int>(variant))
378 return static_cast<float>(std::get<int>(variant));
379 else if (std::holds_alternative<bool>(variant))
380 return static_cast<float>(std::get<bool>(variant));
382 B2FATAL(
"Unsupported variable type");
388 if (
m_tree->GetEntry(event, 0) == 0) {
389 B2ERROR(
"Error during loading entry from chain");
406 if (branchName.empty()) {
407 B2INFO(
"No TBranch name given for weights. Using 1s as default weights.");
409 std::vector<float> values(nentries, 1.);
412 if (branchName ==
"__weight__") {
414 B2INFO(
"No default weight branch with name __weight__ found. Using 1s as default weights.");
416 std::vector<float> values(nentries, 1.);
420 std::string typeLabel =
"weights";
427 B2ERROR(
"Feature index " << iFeature <<
" is out of bounds of given number of features: "
431 std::string typeLabel =
"features";
438 B2ERROR(
"Spectator index " << iSpectator <<
" is out of bounds of given number of spectators: "
443 std::string typeLabel =
"spectators";
456 if (std::holds_alternative<double>(memberVariableTarget))
457 return getVectorFromTTree(variableType, branchName, std::get<double>(memberVariableTarget));
458 else if (std::holds_alternative<float>(memberVariableTarget))
459 return getVectorFromTTree(variableType, branchName, std::get<float>(memberVariableTarget));
460 else if (std::holds_alternative<int>(memberVariableTarget))
461 return getVectorFromTTree(variableType, branchName, std::get<int>(memberVariableTarget));
462 else if (std::holds_alternative<bool>(memberVariableTarget))
463 return getVectorFromTTree(variableType, branchName, std::get<bool>(memberVariableTarget));
465 B2FATAL(
"Input type of " << variableType <<
" variable " << branchName <<
" is not supported");
470 T& memberVariableTarget)
473 std::vector<float> values(nentries);
477 auto currentTreeNumber =
m_tree->GetTreeNumber();
478 TBranch* branch =
m_tree->GetBranch(branchName.c_str());
480 B2ERROR(
"TBranch for " + variableType +
" named '" << branchName.c_str() <<
"' does not exist!");
482 branch->SetAddress(&
object);
483 for (
int i = 0; i < nentries; ++i) {
484 auto entry =
m_tree->LoadTree(i);
486 B2ERROR(
"Error during loading root tree from chain, error code: " << entry);
489 if (currentTreeNumber !=
m_tree->GetTreeNumber()) {
490 currentTreeNumber =
m_tree->GetTreeNumber();
491 branch =
m_tree->GetBranch(branchName.c_str());
492 branch->SetAddress(&
object);
494 branch->GetEntry(entry);
498 m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
504 auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
505 return branch !=
nullptr;
513 if (not variableName.empty()) {
515 m_tree->SetBranchStatus(variableName.c_str(),
true);
516 m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
522 B2ERROR(
"Couldn't find given " << variableType <<
" variable named " << variableName <<
523 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
524 throw std::runtime_error(
"Couldn't find given " + variableType +
" variable named " + variableName +
525 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
534 if (std::holds_alternative<double>(varVariantTarget))
536 else if (std::holds_alternative<float>(varVariantTarget))
538 else if (std::holds_alternative<int>(varVariantTarget))
540 else if (std::holds_alternative<bool>(varVariantTarget))
543 B2FATAL(
"Variable type for branch " << variableName <<
" not supported!");
550 for (
unsigned int i = 0; i < variableNames.size(); ++i)
556 std::vector<RootDatasetVarVariant>& varVariantTargets)
558 for (
unsigned int i = 0; i < variableNames.size(); ++i) {
566 m_tree->SetBranchStatus(
"*",
false);
571 m_tree->SetBranchStatus(
"__weight__",
true);
572 std::string typeLabel_weight =
"weight";
573 std::string weight_string =
"__weight__";
579 std::string typeLabel_weight =
"weight";
584 std::string typeLabel_target =
"target";
586 std::string typeLabel_feature =
"feature";
588 std::string typeLabel_spectator =
"spectator";
595 if (type ==
"Double_t")
596 varVariantTarget = 0.0;
597 else if (type ==
"Float_t")
598 varVariantTarget = 0.0f;
599 else if (type ==
"Int_t")
600 varVariantTarget = 0;
601 else if (type ==
"Bool_t")
602 varVariantTarget =
false;
604 B2FATAL(
"Unknown root input type: " << type);
605 throw std::runtime_error(
"Unknown root input type: " + type);
615 TBranch* branch =
m_tree->GetBranch(branch_name.c_str());
616 TLeaf* leaf = branch->GetLeaf(branch_name.c_str());
617 std::string type_name = leaf->GetTypeName();
620 TBranch* branch =
m_tree->GetBranch(compatible_branch_name.c_str());
621 TLeaf* leaf = branch->GetLeaf(compatible_branch_name.c_str());
622 std::string type_name = leaf->GetTypeName();
647 B2INFO(
"No weight variable provided. The weight will be set to 1.");
651 m_tree->SetBranchStatus(
"__weight__",
true);
654 B2INFO(
"Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
655 "weight variable to an empty string if you don't need it.");
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual std::vector< bool > getSignals()
Returns the isSignal flags of all events.
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
std::vector< float > m_spectators
Contains all spectator values of the currently loaded event.
virtual std::vector< float > getTargets()
Returns all targets.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights()
Returns all weights.
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
bool m_isSignal
Defines if the currently loaded event is signal or background.
float m_weight
Contains the weight of the currently loaded event.
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
float m_target
Contains the target value of the currently loaded event.
General options which are shared by all MVA trainings.
std::vector< float > m_weights
weight vector
std::vector< std::vector< float > > m_matrix
Feature matrix.
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
std::vector< float > m_targets
target vector
virtual void loadEvent(unsigned int iEvent) override
Does nothing in the case of a single dataset, because the only event is already loaded.
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float > > &input, const std::vector< std::vector< float > > &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
void setScalarVariableAddress(const std::string &variableType, const std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
void setScalarVariableAddressVariant(const std::string &variableType, const std::string &variableName, RootDatasetVarVariant &variableTarget)
sets the branch address for a scalar variable to a given target
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
void initialiseVarVariantForBranch(const std::string, RootDatasetVarVariant &)
Infers the type (double, float, int, bool) from the TTree and initialises the VarVariant with the correct type.
TChain * m_tree
Pointer to the TChain containing the data.
virtual void loadEvent(unsigned int event) override
Load the event number iEvent from the TTree.
void setVectorVariableAddressVariant(const std::string &variableType, const std::vector< std::string > &variableName, std::vector< RootDatasetVarVariant > &varVariantTargets)
sets the branch address for a vector of VarVariant to a given target
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
std::vector< float > getVectorFromTTree(const std::string &variableType, const std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
void initialiseVarVariantType(const std::string, RootDatasetVarVariant &)
Initialises the VarVariant.
std::vector< RootDatasetVarVariant > m_spectators_variant
Contains all spectator values of the currently loaded event.
RootDatasetVarVariant m_target_variant
Contains the target value of the currently loaded event.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
std::vector< RootDatasetVarVariant > m_input_variant
Contains all feature values of the currently loaded event.
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
void setRootInputType()
Tries to infer the data-type of the spectator and feature variables in a root file.
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
void setVectorVariableAddress(const std::string &variableType, const std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
virtual ~ROOTDataset()
Virtual destructor.
float castVarVariantToFloat(RootDatasetVarVariant &) const
Casts a VarVariant which can contain <double,int,bool,float> to float.
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
std::variant< double, float, int, bool > RootDatasetVarVariant
Typedef for variable types supported by the mva ROOTDataset; can be one of double, float, int or bool.
std::vector< float > getVectorFromTTreeVariant(const std::string &variableType, const std::string &branchName, RootDatasetVarVariant &memberVariableTarget)
Returns all values for a specified variableType and branchName.
RootDatasetVarVariant m_weight_variant
Contains the weight of the currently loaded event.
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Dataset & m_dataset
Reference to the wrapped dataset.
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
std::vector< unsigned int > m_event_indices
Mapping from the position of an event in the given subset to its position in the wrapped dataset.
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names.
Abstract base class for different kinds of events.