Belle II Software development
ROOTDataset Class Reference

Provides a dataset from a ROOT file. This is the dataset usually used to provide training data to the mva methods.

#include <Dataset.h>

Base class: Dataset

Public Member Functions

 ROOTDataset (const GeneralOptions &_general_options)
 Creates a new ROOTDataset.
 
virtual unsigned int getNumberOfFeatures () const override
 Returns the number of features in this dataset.
 
virtual unsigned int getNumberOfSpectators () const override
 Returns the number of spectators in this dataset.
 
virtual unsigned int getNumberOfEvents () const override
 Returns the number of events in this dataset.
 
virtual void loadEvent (unsigned int event) override
 Loads the event with the given number from the TTree.
 
virtual std::vector< float > getFeature (unsigned int iFeature) override
 Returns all values of one feature in a std::vector<float>
 
virtual std::vector< float > getWeights () override
 Returns all values of the weights in a std::vector<float>
 
virtual std::vector< float > getSpectator (unsigned int iSpectator) override
 Returns all values of one spectator in a std::vector<float>
 
virtual ~ROOTDataset ()
 Virtual destructor.
 
virtual float getSignalFraction ()
 Returns the signal fraction of the whole sample.
 
virtual unsigned int getFeatureIndex (const std::string &feature)
 Return index of feature with the given name.
 
virtual unsigned int getSpectatorIndex (const std::string &spectator)
 Return index of spectator with the given name.
 
virtual std::vector< float > getTargets ()
 Returns all targets.
 
virtual std::vector< bool > getSignals ()
 Returns the signal flag (m_isSignal) of every event.
 

Public Attributes

GeneralOptions m_general_options
 GeneralOptions passed to this dataset.
 
std::vector< float > m_input
 Contains all feature values of the currently loaded event.
 
std::vector< float > m_spectators
 Contains all spectators values of the currently loaded event.
 
float m_weight
 Contains the weight of the currently loaded event.
 
float m_target
 Contains the target value of the currently loaded event.
 
bool m_isSignal
 Defines if the currently loaded event is signal or background.
 

Protected Types

typedef std::variant< double, float, int, bool > RootDatasetVarVariant
 Typedef for variable types supported by the mva ROOTDataset, can be one of double, float, int or bool in std::variant.
 

Protected Attributes

TChain * m_tree = nullptr
 Pointer to the TChain containing the data.
 
std::vector< RootDatasetVarVariant > m_input_variant
 Contains all feature values of the currently loaded event.
 
std::vector< RootDatasetVarVariant > m_spectators_variant
 Contains all spectators values of the currently loaded event.
 
RootDatasetVarVariant m_weight_variant
 Contains the weight of the currently loaded event.
 
RootDatasetVarVariant m_target_variant
 Contains the target value of the currently loaded event.
 

Private Member Functions

template<class T >
std::vector< float > getVectorFromTTree (const std::string &variableType, const std::string &branchName, T &memberVariableTarget)
 Returns all values for a specified variableType and branchName.
 
std::vector< float > getVectorFromTTreeVariant (const std::string &variableType, const std::string &branchName, RootDatasetVarVariant &memberVariableTarget)
 Returns all values for a specified variableType and branchName.
 
void setRootInputType ()
 Tries to infer the data-type of the spectator and feature variables in a root file.
 
template<class T >
void setScalarVariableAddress (const std::string &variableType, const std::string &variableName, T &variableTarget)
 sets the branch address for a scalar variable to a given target
 
void setScalarVariableAddressVariant (const std::string &variableType, const std::string &variableName, RootDatasetVarVariant &variableTarget)
 sets the branch address for a scalar variable to a given target
 
template<class T >
void setVectorVariableAddress (const std::string &variableType, const std::vector< std::string > &variableName, T &variableTargets)
 sets the branch address for a vector variable to a given target
 
void setVectorVariableAddressVariant (const std::string &variableType, const std::vector< std::string > &variableName, std::vector< RootDatasetVarVariant > &varVariantTargets)
 sets the branch address for a vector of VarVariant to a given target
 
void setTargetRootInputType ()
 Determines the data type of the target variable and sets it to m_target_data_type.
 
void setBranchAddresses ()
 Sets the branch addresses of all features, spectators, weight and target again.
 
bool checkForBranch (TTree *, const std::string &) const
 Checks if the given branchname exists in the TTree.
 
float castVarVariantToFloat (RootDatasetVarVariant &) const
 Casts a VarVariant which can contain <double,int,bool,float> to float.
 
void initialiseVarVariantType (const std::string, RootDatasetVarVariant &)
 Initialises the VarVariant.
 
void initialiseVarVariantForBranch (const std::string, RootDatasetVarVariant &)
 Infers the type (double,float,int,bool) from the TTree and initialises the VarVariant with the correct type.
 

Detailed Description

Provides a dataset from a ROOT file. This is the dataset usually used to provide training data to the mva methods.

Definition at line 349 of file Dataset.h.

Member Typedef Documentation

◆ RootDatasetVarVariant

typedef std::variant<double, float, int, bool> RootDatasetVarVariant
protected

Typedef for variable types supported by the mva ROOTDataset, can be one of double, float, int or bool in std::variant.

Definition at line 406 of file Dataset.h.

Constructor & Destructor Documentation

◆ ROOTDataset()

ROOTDataset ( const GeneralOptions &  _general_options)
explicit

Creates a new ROOTDataset.

Parameters
_general_options  defines the ROOT files, tree name, branches, ...

Definition at line 316 of file Dataset.cc.

316 : Dataset(general_options)
317 {
320 m_weight_variant = 1.0f;
321 m_target_variant = 0.0f;
322
323 for (const auto& variable : general_options.m_variables)
324 for (const auto& spectator : general_options.m_spectators)
325 if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
326 B2ERROR("Interface doesn't support variable more then one time in either spectators, variables or target variable");
327 throw std::runtime_error("Interface doesn't support variable more then one time in either spectators, variables or target variable");
328 }
329
330 std::vector<std::string> filenames;
331 for (const auto& filename : m_general_options.m_datafiles) {
332 if (std::filesystem::exists(filename)) {
333 filenames.push_back(filename);
334 } else {
336 filenames.insert(filenames.end(), temp.begin(), temp.end());
337 }
338 }
339 if (filenames.empty()) {
340 B2ERROR("Found no valid filenames in GeneralOptions");
341 throw std::runtime_error("Found no valid filenames in GeneralOptions");
342 }
343
344 //Open TFile
345 TDirectory* dir = gDirectory;
346 for (const auto& filename : filenames) {
347 if (not std::filesystem::exists(filename)) {
348 B2ERROR("Error given ROOT file does not exist " << filename);
349 throw std::runtime_error("Error during open of ROOT file named " + filename);
350 }
351
352 TFile* f = TFile::Open(filename.c_str(), "READ");
353 if (!f or f->IsZombie() or not f->IsOpen()) {
354 B2ERROR("Error during open of ROOT file named " << filename);
355 throw std::runtime_error("Error during open of ROOT file named " + filename);
356 }
357 delete f;
358 }
359 dir->cd();
360
361 m_tree = new TChain(m_general_options.m_treename.c_str());
362 for (const auto& filename : filenames) {
363 //nentries = -1 forces AddFile() to read headers
364 if (!m_tree->AddFile(filename.c_str(), -1)) {
365 B2ERROR("Error during open of ROOT file named " << filename << " cannot retrieve tree named " <<
367 throw std::runtime_error("Error during open of ROOT file named " + filename + " cannot retrieve tree named " +
369 }
370 }
373 }

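Before touching any ROOT object, the constructor runs two checks in plain C++. First, it rejects overlaps between the variable, spectator and target names. Second, it keeps every datafile entry that exists on disk and expands the remaining ones as wildcard patterns.

The sketch below restates that logic with the standard library and POSIX `wordexp()`. Some caveats:

- The names `hasOverlap`, `expandPattern` and `resolveFilenames` are hypothetical.
- `expandPattern` is only an assumed stand-in for `expandWordExpansions`. It silently drops results that do not exist. The constructor instead re-checks every collected name and throws for one that does not exist.
- As in the listed loop, the comparison against the target sits inside the spectator loop. It is therefore skipped entirely when no spectators are configured.

```cpp
#include <cstddef>
#include <filesystem>
#include <string>
#include <vector>
#include <wordexp.h>  // POSIX wordexp()/wordfree()

// Mirrors the nested check in the constructor: true if a variable equals a
// spectator, or either of them equals the target. As in the listed code,
// nothing is compared at all when `spectators` is empty.
bool hasOverlap(const std::vector<std::string>& variables,
                const std::vector<std::string>& spectators,
                const std::string& target)
{
  for (const auto& variable : variables)
    for (const auto& spectator : spectators)
      if (variable == spectator || variable == target || spectator == target)
        return true;
  return false;
}

// Hypothetical stand-in for expandWordExpansions(): expands one shell-style
// pattern with wordexp() (command substitution disabled via WRDE_NOCMD) and
// keeps only results that exist, because wordexp() returns an unmatched
// pattern unchanged.
std::vector<std::string> expandPattern(const std::string& pattern)
{
  std::vector<std::string> matches;
  wordexp_t result;
  const int rc = wordexp(pattern.c_str(), &result, WRDE_NOCMD);
  if (rc == 0) {
    for (std::size_t i = 0; i < result.we_wordc; ++i)
      if (std::filesystem::exists(result.we_wordv[i]))
        matches.emplace_back(result.we_wordv[i]);
  }
  if (rc == 0 || rc == WRDE_NOSPACE)
    wordfree(&result);  // release the word list allocated by wordexp()
  return matches;
}

// Same loop as the constructor: existing paths are kept as given, every other
// entry is treated as a pattern and replaced by its matches.
std::vector<std::string> resolveFilenames(const std::vector<std::string>& datafiles)
{
  std::vector<std::string> filenames;
  for (const auto& filename : datafiles) {
    if (std::filesystem::exists(filename)) {
      filenames.push_back(filename);
    } else {
      const auto expanded = expandPattern(filename);
      filenames.insert(filenames.end(), expanded.begin(), expanded.end());
    }
  }
  return filenames;
}
```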
◆ ~ROOTDataset()

~ROOTDataset ( )
virtual

Virtual destructor.

Definition at line 452 of file Dataset.cc.

453 {
454 delete m_tree;
455 m_tree = nullptr;
456 }

Member Function Documentation

◆ castVarVariantToFloat()

float castVarVariantToFloat ( RootDatasetVarVariant &  variant) const
private

Casts a VarVariant which can contain <double,int,bool,float> to float.

Parameters
variant  the VarVariant to cast

Definition at line 376 of file Dataset.cc.

377 {
378 if (std::holds_alternative<double>(variant))
379 return static_cast<float>(std::get<double>(variant));
380 else if (std::holds_alternative<float>(variant))
381 return std::get<float>(variant);
382 else if (std::holds_alternative<int>(variant))
383 return static_cast<float>(std::get<int>(variant));
384 else if (std::holds_alternative<bool>(variant))
385 return static_cast<float>(std::get<bool>(variant));
386 else {
387 B2FATAL("Unsupported variable type");
388 }
389 }
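The `if`/`else if` chain above can also be written with `std::visit`. This instantiates the conversion once per alternative of `RootDatasetVarVariant`, so the compiler guarantees that every alternative is covered.

This is a sketch of an equivalent formulation, not the code used in Dataset.cc. The free function name `toFloat` is hypothetical. The only remaining failure mode is a valueless variant, for which `std::visit` throws `std::bad_variant_access`.

```cpp
#include <variant>

using RootDatasetVarVariant = std::variant<double, float, int, bool>;

// Equivalent of castVarVariantToFloat written with std::visit: the generic
// lambda is instantiated for double, float, int and bool, and each value is
// converted with static_cast<float>.
float toFloat(const RootDatasetVarVariant& variant)
{
  return std::visit([](auto value) { return static_cast<float>(value); }, variant);
}
```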

◆ checkForBranch()

bool checkForBranch ( TTree *  tree,
const std::string &  branchname 
) const
private

Checks if the given branchname exists in the TTree.

Parameters
tree  TTree whose list of branches is searched
branchname  name of the branch to look for

Definition at line 507 of file Dataset.cc.

508 {
509 auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
510 return branch != nullptr;
511
512 }

◆ getFeature()

std::vector< float > getFeature ( unsigned int  iFeature)
override virtual

Returns all values of one feature in a std::vector<float>

Parameters
iFeature  the position of the feature to return

Reimplemented from Dataset.

Definition at line 429 of file Dataset.cc.

430 {
431 if (iFeature >= getNumberOfFeatures()) {
432 B2ERROR("Feature index " << iFeature << " is out of bounds of given number of features: "
434 }
436 std::string typeLabel = "features";
437 return getVectorFromTTreeVariant(typeLabel, branchName, m_input_variant[iFeature]);
438 }

◆ getFeatureIndex()

unsigned int getFeatureIndex ( const std::string &  feature)
virtual inherited

Returns the index of the feature with the given name. If the name is unknown, an error is logged and 0 is returned.

Parameters
feature  name of the feature

Definition at line 50 of file Dataset.cc.

51 {
52
53 auto it = std::find(m_general_options.m_variables.begin(), m_general_options.m_variables.end(), feature);
54 if (it == m_general_options.m_variables.end()) {
55 B2ERROR("Unknown feature named " << feature);
56 return 0;
57 }
58 return std::distance(m_general_options.m_variables.begin(), it);
59
60 }
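`getFeatureIndex()` (and likewise `getSpectatorIndex()`) logs an error and returns 0 for an unknown name. A caller therefore cannot tell a missing feature from the first one without watching the log.

A hypothetical alternative makes the "not found" case explicit. It is sketched here with `std::optional` and is not part of the Belle II interface:

```cpp
#include <algorithm>
#include <iterator>
#include <optional>
#include <string>
#include <vector>

// Hypothetical variant of getFeatureIndex(): same std::find/std::distance
// lookup, but an unknown name yields std::nullopt instead of index 0.
std::optional<unsigned int> findIndex(const std::vector<std::string>& names,
                                      const std::string& name)
{
  auto it = std::find(names.begin(), names.end(), name);
  if (it == names.end())
    return std::nullopt;
  return static_cast<unsigned int>(std::distance(names.begin(), it));
}
```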

◆ getNumberOfEvents()

virtual unsigned int getNumberOfEvents ( ) const
inline override virtual

Returns the number of events in this dataset.

Implements Dataset.

Definition at line 371 of file Dataset.h.

372 {
374 }

◆ getNumberOfFeatures()

virtual unsigned int getNumberOfFeatures ( ) const
inline override virtual

Returns the number of features in this dataset.

Implements Dataset.

Definition at line 361 of file Dataset.h.

361{ return m_input.size(); }

◆ getNumberOfSpectators()

virtual unsigned int getNumberOfSpectators ( ) const
inline override virtual

Returns the number of spectators in this dataset.

Implements Dataset.

Definition at line 366 of file Dataset.h.

366{ return m_spectators.size(); }

◆ getSignalFraction()

float getSignalFraction ( )
virtual inherited

Returns the signal fraction of the whole sample.

Reimplemented in SPlotDataset.

Definition at line 35 of file Dataset.cc.

36 {
37
38 double signal_weight_sum = 0;
39 double weight_sum = 0;
40 for (unsigned int i = 0; i < getNumberOfEvents(); ++i) {
41 loadEvent(i);
42 weight_sum += m_weight;
43 if (m_isSignal)
44 signal_weight_sum += m_weight;
45 }
46 return signal_weight_sum / weight_sum;
47
48 }

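`getSignalFraction()` computes a weighted fraction: the sum of the weights of signal events divided by the sum of all weights.

The sketch below applies the same accumulation to weight and signal columns that have already been extracted. The function name `signalFraction` is hypothetical. An empty sample, or one whose weights sum to zero, yields 0/0 (NaN) here, as in the original.

```cpp
#include <cstddef>
#include <vector>

// Same accumulation as getSignalFraction(), on pre-extracted columns.
// Sums are kept in double, as in the original, while weights are float.
double signalFraction(const std::vector<float>& weights, const std::vector<bool>& isSignal)
{
  double signal_weight_sum = 0;
  double weight_sum = 0;
  for (std::size_t i = 0; i < weights.size() && i < isSignal.size(); ++i) {
    weight_sum += weights[i];
    if (isSignal[i])
      signal_weight_sum += weights[i];
  }
  return signal_weight_sum / weight_sum;
}
```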
◆ getSignals()

std::vector< bool > getSignals ( )
virtual inherited

Returns the signal flag (m_isSignal) of every event.

Reimplemented in ReweightingDataset.

Definition at line 122 of file Dataset.cc.

123 {
124
125 std::vector<bool> result(getNumberOfEvents());
126 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
127 loadEvent(iEvent);
128 result[iEvent] = m_isSignal;
129 }
130 return result;
131
132 }

◆ getSpectator()

std::vector< float > getSpectator ( unsigned int  iSpectator)
override virtual

Returns all values of one spectator in a std::vector<float>

Parameters
iSpectator  the position of the spectator to return

Reimplemented from Dataset.

Definition at line 440 of file Dataset.cc.

441 {
442 if (iSpectator >= getNumberOfSpectators()) {
443 B2ERROR("Spectator index " << iSpectator << " is out of bounds of given number of spectators: "
445 }
446
448 std::string typeLabel = "spectators";
449 return getVectorFromTTreeVariant(typeLabel, branchName, m_spectators_variant[iSpectator]);
450 }

◆ getSpectatorIndex()

unsigned int getSpectatorIndex ( const std::string &  spectator)
virtual inherited

Returns the index of the spectator with the given name. If the name is unknown, an error is logged and 0 is returned.

Parameters
spectator  name of the spectator

Definition at line 62 of file Dataset.cc.

63 {
64
65 auto it = std::find(m_general_options.m_spectators.begin(), m_general_options.m_spectators.end(), spectator);
66 if (it == m_general_options.m_spectators.end()) {
67 B2ERROR("Unknown spectator named " << spectator);
68 return 0;
69 }
70 return std::distance(m_general_options.m_spectators.begin(), it);
71
72 }

◆ getTargets()

std::vector< float > getTargets ( )
virtual inherited

Returns all targets.

Reimplemented in RegressionDataSet, and ReweightingDataset.

Definition at line 110 of file Dataset.cc.

111 {
112
113 std::vector<float> result(getNumberOfEvents());
114 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
115 loadEvent(iEvent);
116 result[iEvent] = m_target;
117 }
118 return result;
119
120 }

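`getTargets()` and `getSignals()` share a load-then-read pattern:

- Each event is loaded through the virtual `loadEvent()`, which refreshes the current-event members (`m_target`, `m_isSignal`).
- The member is then copied into the result.

Every call therefore loads the whole sample once. The hypothetical in-memory `MockDataset` below reproduces the pattern without ROOT. It does not derive from the real `Dataset`.

```cpp
#include <vector>

// Hypothetical in-memory stand-in for a dataset: loadEvent() copies one row into
// the current-event members, as ROOTDataset does from its TChain.
struct MockDataset {
  std::vector<float> targets;  // one target value per event
  std::vector<bool> signals;   // one signal flag per event
  float m_target = 0.0f;       // target of the currently loaded event
  bool m_isSignal = false;     // signal flag of the currently loaded event

  unsigned int getNumberOfEvents() const { return static_cast<unsigned int>(targets.size()); }

  void loadEvent(unsigned int iEvent)
  {
    m_target = targets[iEvent];
    m_isSignal = signals[iEvent];
  }

  // Same loop as Dataset::getTargets(): load each event, then copy m_target.
  std::vector<float> getTargets()
  {
    std::vector<float> result(getNumberOfEvents());
    for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
      loadEvent(iEvent);
      result[iEvent] = m_target;
    }
    return result;
  }

  // Same loop as Dataset::getSignals(): load each event, then copy m_isSignal.
  std::vector<bool> getSignals()
  {
    std::vector<bool> result(getNumberOfEvents());
    for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
      loadEvent(iEvent);
      result[iEvent] = m_isSignal;
    }
    return result;
  }
};
```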
◆ getVectorFromTTree()

std::vector< float > getVectorFromTTree ( const std::string &  variableType,
const std::string &  branchName,
T &  memberVariableTarget 
)
private

Returns all values for a specified variableType and branchName.

The values are read from a ROOT file. The value type is taken from the type of memberVariableTarget.

Template Parameters
T  type of the member variable of this class that has to be updated (double, float, int or bool)
Parameters
variableType  defines {feature, weights, spectator, target}
branchName  name of the branch to read
memberVariableTarget  variable the branch address from the ROOT file is set back to after reading
Returns
filled vector from a branch, converted to float

Definition at line 474 of file Dataset.cc.

476 {
477 int nentries = getNumberOfEvents();
478 std::vector<float> values(nentries);
479
480 // Float or Double to be filled
481 T object;
482 auto currentTreeNumber = m_tree->GetTreeNumber();
483 TBranch* branch = m_tree->GetBranch(branchName.c_str());
484 if (not branch) {
485 B2ERROR("TBranch for " + variableType + " named '" << branchName.c_str() << "' does not exist!");
486 }
487 branch->SetAddress(&object);
488 for (int i = 0; i < nentries; ++i) {
489 auto entry = m_tree->LoadTree(i);
490 if (entry < 0) {
491 B2ERROR("Error during loading root tree from chain, error code: " << entry);
492 }
493 // if current tree changed we have to update the branch
494 if (currentTreeNumber != m_tree->GetTreeNumber()) {
495 currentTreeNumber = m_tree->GetTreeNumber();
496 branch = m_tree->GetBranch(branchName.c_str());
497 branch->SetAddress(&object);
498 }
499 branch->GetEntry(entry);
500 values[i] = object;
501 }
502 // Reset branch to correct input address, just to be sure
503 m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
504 return values;
505 }

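`getVectorFromTTree()` walks the chain entry by entry:

- `LoadTree(i)` converts the global entry number into an entry of the tree that contains it.
- A `TBranch*` belongs to one tree, so the branch is fetched again and its address reset whenever `GetTreeNumber()` changes.

The sketch below mimics that control flow with a hypothetical `MiniChain` made of plain vectors. It does not use ROOT, and `loadTree` only imitates the idea behind `TChain::LoadTree`. The sketch also stops at a failed load, whereas the listed code only logs an error and continues.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a TChain: a list of "trees", each one column of values.
// loadTree() maps a global entry number to the local entry inside the tree that
// contains it and records that tree's index in currentTree.
struct MiniChain {
  std::vector<std::vector<double>> trees;
  int currentTree = -1;

  long long loadTree(long long entry)
  {
    for (std::size_t t = 0; t < trees.size(); ++t) {
      const long long n = static_cast<long long>(trees[t].size());
      if (entry < n) {
        currentTree = static_cast<int>(t);
        return entry;  // local entry within tree t
      }
      entry -= n;
    }
    return -2;  // negative value signals failure (entry beyond the chain)
  }
};

// Same control flow as getVectorFromTTree(): the "branch" pointer belongs to one
// tree, so it is re-fetched whenever the current tree number changes.
std::vector<float> readColumn(MiniChain& chain, long long nentries)
{
  std::vector<float> values(static_cast<std::size_t>(nentries));
  int boundTree = -1;
  const std::vector<double>* branch = nullptr;
  for (long long i = 0; i < nentries; ++i) {
    const long long local = chain.loadTree(i);
    if (local < 0)
      break;
    if (boundTree != chain.currentTree) {  // tree changed: rebind the branch
      boundTree = chain.currentTree;
      branch = &chain.trees[static_cast<std::size_t>(boundTree)];
    }
    values[static_cast<std::size_t>(i)] = static_cast<float>((*branch)[static_cast<std::size_t>(local)]);
  }
  return values;
}
```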
◆ getVectorFromTTreeVariant()

std::vector< float > getVectorFromTTreeVariant ( const std::string &  variableType,
const std::string &  branchName,
RootDatasetVarVariant &  memberVariableTarget 
)
private

Returns all values for a specified variableType and branchName.

The values are read from a ROOT file. The value type is taken from the alternative currently held by memberVariableTarget.

Parameters
variableType  defines {feature, weights, spectator, target}
branchName  name of the branch to read
memberVariableTarget  variable the branch address from the ROOT file is set back to after reading
Returns
filled vector from a branch, converted to float

Definition at line 458 of file Dataset.cc.

460 {
461 if (std::holds_alternative<double>(memberVariableTarget))
462 return getVectorFromTTree(variableType, branchName, std::get<double>(memberVariableTarget));
463 else if (std::holds_alternative<float>(memberVariableTarget))
464 return getVectorFromTTree(variableType, branchName, std::get<float>(memberVariableTarget));
465 else if (std::holds_alternative<int>(memberVariableTarget))
466 return getVectorFromTTree(variableType, branchName, std::get<int>(memberVariableTarget));
467 else if (std::holds_alternative<bool>(memberVariableTarget))
468 return getVectorFromTTree(variableType, branchName, std::get<bool>(memberVariableTarget));
469 else
470 B2FATAL("Input type of " << variableType << " variable " << branchName << " is not supported");
471 }

◆ getWeights()

std::vector< float > getWeights ( )
override virtual

Returns all values of the weights in a std::vector<float>

Reimplemented from Dataset.

Definition at line 408 of file Dataset.cc.

409 {
411 if (branchName.empty()) {
412 B2INFO("No TBranch name given for weights. Using 1s as default weights.");
413 int nentries = getNumberOfEvents();
414 std::vector<float> values(nentries, 1.);
415 return values;
416 }
417 if (branchName == "__weight__") {
418 if (!checkForBranch(m_tree, "__weight__")) {
419 B2INFO("No default weight branch with name __weight__ found. Using 1s as default weights.");
420 int nentries = getNumberOfEvents();
421 std::vector<float> values(nentries, 1.);
422 return values;
423 }
424 }
425 std::string typeLabel = "weights";
426 return getVectorFromTTreeVariant(typeLabel, branchName, m_weight_variant);
427 }

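The weight handling in `getWeights()` reduces to a small decision:

- An empty weight variable yields a weight of 1 for every event.
- The default name `__weight__` also yields 1s when that branch is absent.
- Any other name is passed on to the branch reader.

Here is that decision restated as a standalone function. `weightsFor` is hypothetical, and its two `std::function` parameters are placeholders for `checkForBranch()` and `getVectorFromTTreeVariant()`.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical restatement of the branch-name decision in getWeights().
std::vector<float> weightsFor(const std::string& branchName, unsigned int nEvents,
                              const std::function<bool(const std::string&)>& branchExists,
                              const std::function<std::vector<float>(const std::string&)>& readBranch)
{
  // No weight variable configured: every event gets weight 1.
  if (branchName.empty())
    return std::vector<float>(nEvents, 1.0f);
  // The default name "__weight__" is optional: fall back to 1s if it is absent.
  if (branchName == "__weight__" && !branchExists(branchName))
    return std::vector<float>(nEvents, 1.0f);
  // Any other name, or a present "__weight__", is read from the tree.
  return readBranch(branchName);
}
```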
◆ initialiseVarVariantForBranch()

void initialiseVarVariantForBranch ( const std::string  branch_name,
RootDatasetVarVariant &  varVariantTarget 
)
private

Infers the type (double,float,int,bool) from the TTree and initialises the VarVariant with the correct type.

Parameters
branch_name  branch name in the datafile
varVariantTarget  variant to initialise

Definition at line 615 of file Dataset.cc.

616 {
617 std::string compatible_branch_name = Belle2::MakeROOTCompatible::makeROOTCompatible(branch_name);
618 // try the branch as is first then fall back to root safe name.
619 if (checkForBranch(m_tree, branch_name.c_str())) {
620 TBranch* branch = m_tree->GetBranch(branch_name.c_str());
621 TLeaf* leaf = branch->GetLeaf(branch_name.c_str());
622 std::string type_name = leaf->GetTypeName();
623 initialiseVarVariantType(type_name, varVariantTarget);
624 } else if (checkForBranch(m_tree, compatible_branch_name)) {
625 TBranch* branch = m_tree->GetBranch(compatible_branch_name.c_str());
626 TLeaf* leaf = branch->GetLeaf(compatible_branch_name.c_str());
627 std::string type_name = leaf->GetTypeName();
628 initialiseVarVariantType(type_name, varVariantTarget);
629 }
630 }

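`initialiseVarVariantForBranch()` looks the branch up first under the name as configured and then under its ROOT-compatible form. If neither exists, it leaves the variant untouched without an error. `setScalarVariableAddress()`, by contrast, throws in the same situation.

The lookup order is sketched below with a set of branch names. `resolveBranch` is hypothetical. `toCompatible` is a placeholder, because the exact character mapping of `MakeROOTCompatible::makeROOTCompatible` is not shown on this page.

```cpp
#include <functional>
#include <optional>
#include <set>
#include <string>

// Hypothetical sketch of the lookup order: the name as given first, then its
// ROOT-compatible form; std::nullopt if neither is a known branch.
std::optional<std::string> resolveBranch(const std::set<std::string>& branches,
                                         const std::string& name,
                                         const std::function<std::string(const std::string&)>& toCompatible)
{
  if (branches.count(name))
    return name;
  const std::string compatible = toCompatible(name);
  if (branches.count(compatible))
    return compatible;
  return std::nullopt;
}
```

For example, with a toy converter that replaces parentheses by underscores, a configured name `p(0)` would resolve to a branch named `p_0_`.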
◆ initialiseVarVariantType()

void initialiseVarVariantType ( const std::string  type,
RootDatasetVarVariant &  varVariantTarget 
)
private

Initialises the VarVariant.

Parameters
type  defines which alternative to use for the variant {Double_t, Float_t, Int_t, Bool_t}
varVariantTarget  variant to initialise.

Definition at line 598 of file Dataset.cc.

599 {
600 if (type == "Double_t")
601 varVariantTarget = 0.0;
602 else if (type == "Float_t")
603 varVariantTarget = 0.0f;
604 else if (type == "Int_t")
605 varVariantTarget = 0;
606 else if (type == "Bool_t")
607 varVariantTarget = false;
608 else {
609 B2FATAL("Unknown root input type: " << type);
610 throw std::runtime_error("Unknown root input type: " + type);
611 }
612 }
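The mapping relies on the literal's type selecting the alternative: `0.0` stores a `double`, `0.0f` a `float`, `0` an `int` and `false` a `bool`. Any other leaf type name, for example `Long64_t` or `UInt_t`, is rejected.

The standalone sketch below follows the same mapping, with two differences: it returns the variant instead of writing through a reference, and it throws instead of calling `B2FATAL`. The function name `variantForRootType` is hypothetical.

```cpp
#include <stdexcept>
#include <string>
#include <variant>

using RootDatasetVarVariant = std::variant<double, float, int, bool>;

// Same mapping as initialiseVarVariantType(): the ROOT leaf type name selects
// which alternative the variant holds; the stored value is always zero/false.
RootDatasetVarVariant variantForRootType(const std::string& type)
{
  if (type == "Double_t") return 0.0;
  if (type == "Float_t") return 0.0f;
  if (type == "Int_t") return 0;
  if (type == "Bool_t") return false;
  throw std::runtime_error("Unknown root input type: " + type);
}
```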

◆ loadEvent()

void loadEvent ( unsigned int  event)
override virtual

Loads the event with the given number from the TTree.

Parameters
event  event number to load

Implements Dataset.

Definition at line 391 of file Dataset.cc.

392 {
393 if (m_tree->GetEntry(event, 0) == 0) {
394 B2ERROR("Error during loading entry from chain");
395 }
396
397 for (unsigned int i = 0; i < m_input_variant.size(); i++) {
399 }
400 for (unsigned int i = 0; i < m_spectators_variant.size(); i++) {
402 }
406 }

◆ setBranchAddresses()

void setBranchAddresses ( )
private

Sets the branch addresses of all features, spectators, weight and target again.

Definition at line 568 of file Dataset.cc.

569 {
570 // Deactivate all branches by default
571 m_tree->SetBranchStatus("*", false);
572
574 if (m_general_options.m_weight_variable == "__weight__") {
575 if (checkForBranch(m_tree, "__weight__")) {
576 m_tree->SetBranchStatus("__weight__", true);
577 std::string typeLabel_weight = "weight";
578 std::string weight_string = "__weight__";
579 setScalarVariableAddressVariant(typeLabel_weight, weight_string, m_weight_variant);
580 } else {
581 m_weight_variant = 1.0f;
582 }
583 } else {
584 std::string typeLabel_weight = "weight";
586 }
587 }
588
589 std::string typeLabel_target = "target";
591 std::string typeLabel_feature = "feature";
593 std::string typeLabel_spectator = "spectator";
595 }

◆ setRootInputType()

void setRootInputType ( )
private

Tries to infer the data-type of the spectator and feature variables in a root file.

Definition at line 632 of file Dataset.cc.

633 {
634 // set target variable
636
637 // set feature variables
638 for (unsigned int i = 0; i < m_general_options.m_variables.size(); i++) {
639 auto variable = m_general_options.m_variables[i];
641 }
642
643 // set spectator variables
644 for (unsigned int i = 0; i < m_general_options.m_spectators.size(); i++) {
645 auto variable = m_general_options.m_spectators[i];
647 }
648
649 // set weight variable - bit more tricky as we allow it to not be set or to not be present.
651 m_weight_variant = 1.0f;
652 B2INFO("No weight variable provided. The weight will be set to 1.");
653 } else {
654 if (m_general_options.m_weight_variable == "__weight__") {
655 if (checkForBranch(m_tree, "__weight__")) {
656 m_tree->SetBranchStatus("__weight__", true);
658 } else {
659 B2INFO("Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
660 "weight variable to an empty string if you don't need it.");
661 m_weight_variant = 1.0f;
662 }
663 } else {
665 }
666 }
667 }

◆ setScalarVariableAddress()

void setScalarVariableAddress ( const std::string &  variableType,
const std::string &  variableName,
T &  variableTarget 
)
private

sets the branch address for a scalar variable to a given target

Template Parameters
T  target type (double, float, int or bool)
Parameters
variableType  defines {feature, weights, spectator, target}
variableName  name of the variable, usually defined in general_options
variableTarget  variable the address is set to

Definition at line 515 of file Dataset.cc.

517 {
518 if (not variableName.empty()) {
519 if (checkForBranch(m_tree, variableName)) {
520 m_tree->SetBranchStatus(variableName.c_str(), true);
521 m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
522 } else {
524 m_tree->SetBranchStatus(Belle2::MakeROOTCompatible::makeROOTCompatible(variableName).c_str(), true);
525 m_tree->SetBranchAddress(Belle2::MakeROOTCompatible::makeROOTCompatible(variableName).c_str(), &variableTarget);
526 } else {
527 B2ERROR("Couldn't find given " << variableType << " variable named " << variableName <<
528 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
529 throw std::runtime_error("Couldn't find given " + variableType + " variable named " + variableName +
530 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
531 }
532 }
533 }
534 }

◆ setScalarVariableAddressVariant()

void setScalarVariableAddressVariant ( const std::string &  variableType,
const std::string &  variableName,
RootDatasetVarVariant &  variableTarget 
)
private

Sets the branch address for a scalar variable to a given target.

Parameters
variableType  defines {feature, weights, spectator, target}
variableName  name of the variable, usually defined in general_options
variableTarget  variable the address is set to

Definition at line 536 of file Dataset.cc.

538 {
539 if (std::holds_alternative<double>(varVariantTarget))
540 setScalarVariableAddress(variableType, variableName, std::get<double>(varVariantTarget));
541 else if (std::holds_alternative<float>(varVariantTarget))
542 setScalarVariableAddress(variableType, variableName, std::get<float>(varVariantTarget));
543 else if (std::holds_alternative<int>(varVariantTarget))
544 setScalarVariableAddress(variableType, variableName, std::get<int>(varVariantTarget));
545 else if (std::holds_alternative<bool>(varVariantTarget))
546 setScalarVariableAddress(variableType, variableName, std::get<bool>(varVariantTarget));
547 else
548 B2FATAL("Variable type for branch " << variableName << " not supported!");
549 }

◆ setVectorVariableAddress()

void setVectorVariableAddress ( const std::string &  variableType,
const std::vector< std::string > &  variableName,
T &  variableTargets 
)
private

sets the branch address for a vector variable to a given target

Template Parameters
T  target type (std::vector<float>, std::vector<double>)
Parameters
variableType  defines {feature, weights, spectator, target}
variableName  names of the variables, usually defined in general_options
variableTargets  variables the addresses are set to

Definition at line 552 of file Dataset.cc.

554 {
555 for (unsigned int i = 0; i < variableNames.size(); ++i)
556 ROOTDataset::setScalarVariableAddress(variableType, variableNames[i], variableTargets[i]);
557 }

◆ setVectorVariableAddressVariant()

void setVectorVariableAddressVariant ( const std::string &  variableType,
const std::vector< std::string > &  variableName,
std::vector< RootDatasetVarVariant > &  varVariantTargets 
)
private

sets the branch address for a vector of VarVariant to a given target

Parameters
variableType  defines {feature, weights, spectator, target}
variableName  names of the variables, usually defined in general_options
varVariantTargets  variables the addresses are set to

Definition at line 560 of file Dataset.cc.

562 {
563 for (unsigned int i = 0; i < variableNames.size(); ++i) {
564 ROOTDataset::setScalarVariableAddressVariant(variableType, variableNames[i], varVariantTargets[i]);
565 }
566 }

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
inherited

GeneralOptions passed to this dataset.

Definition at line 122 of file Dataset.h.

◆ m_input

std::vector<float> m_input
inherited

Contains all feature values of the currently loaded event.

Definition at line 123 of file Dataset.h.

◆ m_input_variant

std::vector<RootDatasetVarVariant> m_input_variant
protected

Contains all feature values of the currently loaded event.

Definition at line 409 of file Dataset.h.

◆ m_isSignal

bool m_isSignal
inherited

Defines if the currently loaded event is signal or background.

Definition at line 127 of file Dataset.h.

◆ m_spectators

std::vector<float> m_spectators
inherited

Contains all spectators values of the currently loaded event.

Definition at line 124 of file Dataset.h.

◆ m_spectators_variant

std::vector<RootDatasetVarVariant> m_spectators_variant
protected

Contains all spectators values of the currently loaded event.

Definition at line 411 of file Dataset.h.

◆ m_target

float m_target
inherited

Contains the target value of the currently loaded event.

Definition at line 126 of file Dataset.h.

◆ m_target_variant

RootDatasetVarVariant m_target_variant
protected

Contains the target value of the currently loaded event.

Definition at line 413 of file Dataset.h.

◆ m_tree

TChain* m_tree = nullptr
protected

Pointer to the TChain containing the data.

Definition at line 408 of file Dataset.h.

◆ m_weight

float m_weight
inherited

Contains the weight of the currently loaded event.

Definition at line 125 of file Dataset.h.

◆ m_weight_variant

RootDatasetVarVariant m_weight_variant
protected

Contains the weight of the currently loaded event.

Definition at line 412 of file Dataset.h.


The documentation for this class was generated from the following files: