Belle II Software development
ROOTDataset Class Reference

Provides a dataset from a ROOT file. This is the dataset usually used to provide training data to the MVA methods.

#include <Dataset.h>

Base class: Dataset

Public Member Functions

 ROOTDataset (const GeneralOptions &_general_options)
 Creates a new ROOTDataset.
 
virtual unsigned int getNumberOfFeatures () const override
 Returns the number of features in this dataset.
 
virtual unsigned int getNumberOfSpectators () const override
 Returns the number of spectators in this dataset.
 
virtual unsigned int getNumberOfEvents () const override
 Returns the number of events in this dataset.
 
virtual void loadEvent (unsigned int event) override
 Loads the given event from the TTree.
 
virtual std::vector< float > getFeature (unsigned int iFeature) override
 Returns all values of one feature in a std::vector<float>
 
virtual std::vector< float > getWeights () override
 Returns all values of the weights in a std::vector<float>
 
virtual std::vector< float > getSpectator (unsigned int iSpectator) override
 Returns all values of one spectator in a std::vector<float>
 
virtual ~ROOTDataset ()
 Virtual destructor.
 
virtual float getSignalFraction ()
 Returns the signal fraction of the whole sample.
 
virtual unsigned int getFeatureIndex (const std::string &feature)
 Return index of feature with the given name.
 
virtual unsigned int getSpectatorIndex (const std::string &spectator)
 Return index of spectator with the given name.
 
virtual std::vector< float > getTargets ()
 Returns all targets.
 
virtual std::vector< bool > getSignals ()
 Returns the isSignal flag of all events.
 

Public Attributes

GeneralOptions m_general_options
 GeneralOptions passed to this dataset.
 
std::vector< float > m_input
 Contains all feature values of the currently loaded event.
 
std::vector< float > m_spectators
 Contains all spectators values of the currently loaded event.
 
float m_weight
 Contains the weight of the currently loaded event.
 
float m_target
 Contains the target value of the currently loaded event.
 
bool m_isSignal
 Defines if the currently loaded event is signal or background.
 

Protected Types

typedef std::variant< double, float, int, bool > RootDatasetVarVariant
 Typedef for the variable types supported by the MVA ROOTDataset: a std::variant of double, float, int or bool.
 

Protected Attributes

TChain * m_tree = nullptr
 Pointer to the TChain containing the data.
 
std::vector< RootDatasetVarVariant > m_input_variant
 Contains all feature values of the currently loaded event.
 
std::vector< RootDatasetVarVariant > m_spectators_variant
 Contains all spectators values of the currently loaded event.
 
RootDatasetVarVariant m_weight_variant
 Contains the weight of the currently loaded event.
 
RootDatasetVarVariant m_target_variant
 Contains the target value of the currently loaded event.
 

Private Member Functions

template<class T >
std::vector< float > getVectorFromTTree (const std::string &variableType, const std::string &branchName, T &memberVariableTarget)
 Returns all values for a specified variableType and branchName.
 
std::vector< float > getVectorFromTTreeVariant (const std::string &variableType, const std::string &branchName, RootDatasetVarVariant &memberVariableTarget)
 Returns all values for a specified variableType and branchName.
 
void setRootInputType ()
 Tries to infer the data types of the target, feature, spectator and weight variables in a ROOT file.
 
template<class T >
void setScalarVariableAddress (const std::string &variableType, const std::string &variableName, T &variableTarget)
 Sets the branch address for a scalar variable to a given target.
 
void setScalarVariableAddressVariant (const std::string &variableType, const std::string &variableName, RootDatasetVarVariant &variableTarget)
 Sets the branch address for a scalar variable to a given target.
 
template<class T >
void setVectorVariableAddress (const std::string &variableType, const std::vector< std::string > &variableName, T &variableTargets)
 Sets the branch address for a vector variable to a given target.
 
void setVectorVariableAddressVariant (const std::string &variableType, const std::vector< std::string > &variableName, std::vector< RootDatasetVarVariant > &varVariantTargets)
 Sets the branch address for a vector of VarVariant to a given target.
 
void setTargetRootInputType ()
 Determines the data type of the target variable and sets it to m_target_data_type.
 
void setBranchAddresses ()
 Sets the branch addresses of all features, spectators, the weight and the target again.
 
bool checkForBranch (TTree *, const std::string &) const
 Checks if the given branchname exists in the TTree.
 
float castVarVariantToFloat (RootDatasetVarVariant &) const
 Casts a VarVariant which can contain <double,int,bool,float> to float.
 
void initialiseVarVariantType (const std::string, RootDatasetVarVariant &)
 Initialises the VarVariant.
 
void initialiseVarVariantForBranch (const std::string, RootDatasetVarVariant &)
 Infers the type (double,float,int,bool) from the TTree and initialises the VarVariant with the correct type.
 

Detailed Description

Provides a dataset from a ROOT file. This is the dataset usually used to provide training data to the MVA methods.

Definition at line 349 of file Dataset.h.

Member Typedef Documentation

◆ RootDatasetVarVariant

typedef std::variant<double, float, int, bool> RootDatasetVarVariant
protected

Typedef for the variable types supported by the MVA ROOTDataset: a std::variant of double, float, int or bool.

Definition at line 406 of file Dataset.h.

Constructor & Destructor Documentation

◆ ROOTDataset()

ROOTDataset ( const GeneralOptions & _general_options)
explicit

Creates a new ROOTDataset.

Parameters
_general_options: defines the ROOT file, tree name, branches, ...

Definition at line 316 of file Dataset.cc.

316 : Dataset(general_options)
317 {
320 m_weight_variant = 1.0f;
321 m_target_variant = 0.0f;
322
323 for (const auto& variable : general_options.m_variables)
324 for (const auto& spectator : general_options.m_spectators)
325 if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
326 B2ERROR("Interface doesn't support variable more then one time in either spectators, variables or target variable");
327 throw std::runtime_error("Interface doesn't support variable more then one time in either spectators, variables or target variable");
328 }
329
330 std::vector<std::string> filenames;
331 for (const auto& filename : m_general_options.m_datafiles) {
332 if (std::filesystem::exists(filename)) {
333 filenames.push_back(filename);
334 } else {
336 filenames.insert(filenames.end(), temp.begin(), temp.end());
337 }
338 }
339 if (filenames.empty()) {
340 B2ERROR("Found no valid filenames in GeneralOptions");
341 throw std::runtime_error("Found no valid filenames in GeneralOptions");
342 }
343
344 //Open TFile
345 TDirectory* dir = gDirectory;
346 for (const auto& filename : filenames) {
347 TFile* f = TFile::Open(filename.c_str(), "READ");
348 if (!f or f->IsZombie() or not f->IsOpen()) {
349 B2ERROR("Error during open of ROOT file named " << filename);
350 throw std::runtime_error("Error during open of ROOT file named " + filename);
351 }
352 delete f;
353 }
354 dir->cd();
355
356 m_tree = new TChain(m_general_options.m_treename.c_str());
357 for (const auto& filename : filenames) {
358 //nentries = -1 forces AddFile() to read headers
359 if (!m_tree->AddFile(filename.c_str(), -1)) {
360 B2ERROR("Error during open of ROOT file named " << filename << " cannot retrieve tree named " <<
362 throw std::runtime_error("Error during open of ROOT file named " + filename + " cannot retrieve tree named " +
364 }
365 }
368 }
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
Definition: Dataset.cc:26
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
Definition: Options.h:84
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_treename
Name of the TTree inside the datafile containing the training data.
Definition: Options.h:85
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
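In the listing above, the duplicate check sits inside two nested loops, so it only runs when both m_variables and m_spectators are non-empty: a feature that coincides with the target variable is not detected when no spectators are configured, and a name repeated within one list is never detected. The following is a minimal sketch of a stricter check, not the Belle II code; it assumes plain std::vector<std::string> inputs in place of GeneralOptions, and the function name hasDuplicateNames is hypothetical.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical helper: returns true if any name occurs more than once across
// the feature list, the spectator list and the target variable.
bool hasDuplicateNames(const std::vector<std::string>& variables,
                       const std::vector<std::string>& spectators,
                       const std::string& target)
{
  std::set<std::string> seen;
  // std::set::insert(...).second is false if the name was already present.
  for (const auto& name : variables)
    if (!seen.insert(name).second) return true;
  for (const auto& name : spectators)
    if (!seen.insert(name).second) return true;
  return !seen.insert(target).second;
}
```

Collecting all names in one std::set makes the result independent of which lists happen to be empty. Whether repeats within a single list should be rejected as well is a design choice left open here.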

◆ ~ROOTDataset()

~ROOTDataset ( )
virtual

Virtual destructor.

Definition at line 447 of file Dataset.cc.

448 {
449 delete m_tree;
450 m_tree = nullptr;
451 }

Member Function Documentation

◆ castVarVariantToFloat()

float castVarVariantToFloat ( RootDatasetVarVariant & variant) const
private

Casts a VarVariant which can contain <double,int,bool,float> to float.

Parameters
variant: the VarVariant to cast

Definition at line 371 of file Dataset.cc.

372 {
373 if (std::holds_alternative<double>(variant))
374 return static_cast<float>(std::get<double>(variant));
375 else if (std::holds_alternative<float>(variant))
376 return std::get<float>(variant);
377 else if (std::holds_alternative<int>(variant))
378 return static_cast<float>(std::get<int>(variant));
379 else if (std::holds_alternative<bool>(variant))
380 return static_cast<float>(std::get<bool>(variant));
381 else {
382 B2FATAL("Unsupported variable type");
383 }
384 }
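Since every alternative of RootDatasetVarVariant is an arithmetic type, the if/else chain in castVarVariantToFloat can also be expressed with std::visit and a single generic lambda. This is a sketch of that alternative, not the Belle II implementation; the typedef is repeated so the snippet compiles on its own, and toFloat is a hypothetical name.

```cpp
#include <variant>

// Local copy of the typedef documented above.
typedef std::variant<double, float, int, bool> RootDatasetVarVariant;

// Converts whichever alternative is active to float:
// double is narrowed, int is converted, bool becomes 0.0f or 1.0f.
float toFloat(const RootDatasetVarVariant& variant)
{
  return std::visit([](auto value) { return static_cast<float>(value); }, variant);
}
```

With std::visit, adding an alternative that cannot be converted to float would fail at compile time instead of reaching the B2FATAL branch at run time.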

◆ checkForBranch()

bool checkForBranch ( TTree *  tree,
const std::string &  branchname 
) const
private

Checks if the given branchname exists in the TTree.

Parameters
tree: the TTree to search
branchname: name of the branch to look for

Definition at line 502 of file Dataset.cc.

503 {
504 auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
505 return branch != nullptr;
506
507 }

◆ getFeature()

std::vector< float > getFeature ( unsigned int  iFeature)
override, virtual

Returns all values of one feature in a std::vector<float>

Parameters
iFeature: the position of the feature to return

Reimplemented from Dataset.

Definition at line 424 of file Dataset.cc.

425 {
426 if (iFeature >= getNumberOfFeatures()) {
427 B2ERROR("Feature index " << iFeature << " is out of bounds of given number of features: "
429 }
431 std::string typeLabel = "features";
432 return getVectorFromTTreeVariant(typeLabel, branchName, m_input_variant[iFeature]);
433 }
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names.

◆ getFeatureIndex()

unsigned int getFeatureIndex ( const std::string &  feature)
virtual, inherited

Return index of feature with the given name.

Parameters
feature: name of the feature

Definition at line 50 of file Dataset.cc.

51 {
52
53 auto it = std::find(m_general_options.m_variables.begin(), m_general_options.m_variables.end(), feature);
54 if (it == m_general_options.m_variables.end()) {
55 B2ERROR("Unknown feature named " << feature);
56 return 0;
57 }
58 return std::distance(m_general_options.m_variables.begin(), it);
59
60 }
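getFeatureIndex returns 0 both for the first feature and, after logging an error, for an unknown name, so a caller cannot tell the two apart from the return value alone (getSpectatorIndex behaves the same way). A minimal sketch of the same std::find and std::distance lookup returning std::optional instead; findIndex is a hypothetical name, and a plain std::vector<std::string> stands in for m_general_options.m_variables.

```cpp
#include <algorithm>
#include <iterator>
#include <optional>
#include <string>
#include <vector>

// Hypothetical variant of getFeatureIndex/getSpectatorIndex: same lookup,
// but an unknown name yields std::nullopt instead of index 0.
std::optional<unsigned int> findIndex(const std::vector<std::string>& names,
                                      const std::string& name)
{
  auto it = std::find(names.begin(), names.end(), name);
  if (it == names.end())
    return std::nullopt;
  return static_cast<unsigned int>(std::distance(names.begin(), it));
}
```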

◆ getNumberOfEvents()

virtual unsigned int getNumberOfEvents ( ) const
inline, override, virtual

Returns the number of events in this dataset.

Implements Dataset.

Definition at line 371 of file Dataset.h.

372 {
374 }
unsigned int m_max_events
Maximum number of events to process, 0 means all.
Definition: Options.h:92

◆ getNumberOfFeatures()

virtual unsigned int getNumberOfFeatures ( ) const
inline, override, virtual

Returns the number of features in this dataset.

Implements Dataset.

Definition at line 361 of file Dataset.h.

361{ return m_input.size(); }

◆ getNumberOfSpectators()

virtual unsigned int getNumberOfSpectators ( ) const
inline, override, virtual

Returns the number of spectators in this dataset.

Implements Dataset.

Definition at line 366 of file Dataset.h.

366{ return m_spectators.size(); }

◆ getSignalFraction()

float getSignalFraction ( )
virtual, inherited

Returns the signal fraction of the whole sample.

Reimplemented in SPlotDataset.

Definition at line 35 of file Dataset.cc.

36 {
37
38 double signal_weight_sum = 0;
39 double weight_sum = 0;
40 for (unsigned int i = 0; i < getNumberOfEvents(); ++i) {
41 loadEvent(i);
42 weight_sum += m_weight;
43 if (m_isSignal)
44 signal_weight_sum += m_weight;
45 }
46 return signal_weight_sum / weight_sum;
47
48 }
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
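getSignalFraction uses only the pure virtual getNumberOfEvents() and loadEvent() plus the m_weight and m_isSignal members that loadEvent() fills, which is why the same base-class code serves any subclass that does not override it (getTargets and getSignals follow the same pattern). A self-contained sketch of that pattern, with hypothetical in-memory classes MiniDataset and VectorDataset that are not part of Belle II:

```cpp
#include <utility>
#include <vector>

// Hypothetical, stripped-down analogue of the Dataset base class.
class MiniDataset {
public:
  virtual ~MiniDataset() = default;
  virtual unsigned int getNumberOfEvents() const = 0;
  virtual void loadEvent(unsigned int iEvent) = 0;

  // Same computation as Dataset::getSignalFraction:
  // sum of signal weights divided by the sum of all weights.
  float getSignalFraction()
  {
    double signal_weight_sum = 0;
    double weight_sum = 0;
    for (unsigned int i = 0; i < getNumberOfEvents(); ++i) {
      loadEvent(i);
      weight_sum += m_weight;
      if (m_isSignal)
        signal_weight_sum += m_weight;
    }
    return static_cast<float>(signal_weight_sum / weight_sum);
  }

protected:
  float m_weight = 1.0f;   // weight of the currently loaded event
  bool m_isSignal = false; // signal flag of the currently loaded event
};

// Hypothetical subclass keeping weights and signal flags in memory.
class VectorDataset : public MiniDataset {
public:
  VectorDataset(std::vector<float> weights, std::vector<bool> isSignal)
    : m_weights(std::move(weights)), m_signals(std::move(isSignal)) {}

  unsigned int getNumberOfEvents() const override
  {
    return static_cast<unsigned int>(m_weights.size());
  }

  void loadEvent(unsigned int iEvent) override
  {
    m_weight = m_weights[iEvent];
    m_isSignal = m_signals[iEvent];
  }

private:
  std::vector<float> m_weights;
  std::vector<bool> m_signals;
};
```

As in the original, the result is NaN or infinite when the weights sum to zero.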

◆ getSignals()

std::vector< bool > getSignals ( )
virtual, inherited

Returns the isSignal flag of all events.

Reimplemented in ReweightingDataset.

Definition at line 122 of file Dataset.cc.

123 {
124
125 std::vector<bool> result(getNumberOfEvents());
126 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
127 loadEvent(iEvent);
128 result[iEvent] = m_isSignal;
129 }
130 return result;
131
132 }

◆ getSpectator()

std::vector< float > getSpectator ( unsigned int  iSpectator)
override, virtual

Returns all values of one spectator in a std::vector<float>

Parameters
iSpectator: the position of the spectator to return

Reimplemented from Dataset.

Definition at line 435 of file Dataset.cc.

436 {
437 if (iSpectator >= getNumberOfSpectators()) {
438 B2ERROR("Spectator index " << iSpectator << " is out of bounds of given number of spectators: "
440 }
441
443 std::string typeLabel = "spectators";
444 return getVectorFromTTreeVariant(typeLabel, branchName, m_spectators_variant[iSpectator]);
445 }

◆ getSpectatorIndex()

unsigned int getSpectatorIndex ( const std::string &  spectator)
virtual, inherited

Return index of spectator with the given name.

Parameters
spectator: name of the spectator

Definition at line 62 of file Dataset.cc.

63 {
64
65 auto it = std::find(m_general_options.m_spectators.begin(), m_general_options.m_spectators.end(), spectator);
66 if (it == m_general_options.m_spectators.end()) {
67 B2ERROR("Unknown spectator named " << spectator);
68 return 0;
69 }
70 return std::distance(m_general_options.m_spectators.begin(), it);
71
72 }

◆ getTargets()

std::vector< float > getTargets ( )
virtual, inherited

Returns all targets.

Reimplemented in RegressionDataSet and ReweightingDataset.

Definition at line 110 of file Dataset.cc.

111 {
112
113 std::vector<float> result(getNumberOfEvents());
114 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
115 loadEvent(iEvent);
116 result[iEvent] = m_target;
117 }
118 return result;
119
120 }

◆ getVectorFromTTree()

std::vector< float > getVectorFromTTree ( const std::string &  variableType,
const std::string &  branchName,
T &  memberVariableTarget 
)
private

Returns all values for a specified variableType and branchName.

The values are read from the ROOT file. The type is deduced from the type of the given memberVariableTarget.

Template Parameters
T: type of the member variable of this class that has to be updated (double, float, int or bool)
Parameters
variableType: label of the variable kind {feature, weights, spectator, target}, used in error messages
branchName: name of the branch to read
memberVariableTarget: variable the branch address is reset to after reading
Returns
filled vector from a branch, converted to float

Definition at line 469 of file Dataset.cc.

471 {
472 int nentries = getNumberOfEvents();
473 std::vector<float> values(nentries);
474
475 // Float or Double to be filled
476 T object;
477 auto currentTreeNumber = m_tree->GetTreeNumber();
478 TBranch* branch = m_tree->GetBranch(branchName.c_str());
479 if (not branch) {
480 B2ERROR("TBranch for " + variableType + " named '" << branchName.c_str() << "' does not exist!");
481 }
482 branch->SetAddress(&object);
483 for (int i = 0; i < nentries; ++i) {
484 auto entry = m_tree->LoadTree(i);
485 if (entry < 0) {
486 B2ERROR("Error during loading root tree from chain, error code: " << entry);
487 }
488 // if current tree changed we have to update the branch
489 if (currentTreeNumber != m_tree->GetTreeNumber()) {
490 currentTreeNumber = m_tree->GetTreeNumber();
491 branch = m_tree->GetBranch(branchName.c_str());
492 branch->SetAddress(&object);
493 }
494 branch->GetEntry(entry);
495 values[i] = object;
496 }
497 // Reset branch to correct input address, just to be sure
498 m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
499 return values;
500 }
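getVectorFromTTree walks the whole TChain and calls GetBranch() again whenever LoadTree() moves on to the next file, because a TBranch pointer belongs to one particular TTree of the chain. The sketch below models only that bookkeeping without ROOT: MockChain and readAllAsFloat are hypothetical, and each "tree" is just a std::vector holding one branch.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for a TChain: several "trees", each holding one
// branch as a plain vector. Requires at least one tree.
template <class T>
class MockChain {
public:
  explicit MockChain(std::vector<std::vector<T>> trees) : m_trees(std::move(trees)) {}

  // Like TChain::LoadTree: switches to the tree containing the global entry
  // and returns the entry number local to that tree (-1 if out of range).
  long long loadTree(long long globalEntry)
  {
    long long offset = 0;
    for (std::size_t t = 0; t < m_trees.size(); ++t) {
      long long n = static_cast<long long>(m_trees[t].size());
      if (globalEntry < offset + n) {
        m_current = static_cast<int>(t);
        return globalEntry - offset;
      }
      offset += n;
    }
    return -1;
  }

  int getTreeNumber() const { return m_current; }

  // The "branch" of the current tree; it refers to that tree only.
  const std::vector<T>* getBranch() const { return &m_trees[m_current]; }

private:
  std::vector<std::vector<T>> m_trees;
  int m_current = 0;
};

// Reads all entries as float, re-fetching the branch whenever the tree
// number changes, mirroring the loop in getVectorFromTTree.
template <class T>
std::vector<float> readAllAsFloat(MockChain<T>& chain, long long nentries)
{
  std::vector<float> values(nentries);
  int currentTreeNumber = chain.getTreeNumber();
  const std::vector<T>* branch = chain.getBranch();
  for (long long i = 0; i < nentries; ++i) {
    long long entry = chain.loadTree(i);
    if (entry < 0) break;  // the original only logs an error here
    if (currentTreeNumber != chain.getTreeNumber()) {
      currentTreeNumber = chain.getTreeNumber();
      branch = chain.getBranch();
    }
    values[i] = static_cast<float>((*branch)[entry]);
  }
  return values;
}
```

For a chain of three trees holding {1, 2}, {3} and {4, 5}, readAllAsFloat returns {1, 2, 3, 4, 5}; without the re-fetch inside the loop, the entries from the second and third trees would be read from the first tree instead.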

◆ getVectorFromTTreeVariant()

std::vector< float > getVectorFromTTreeVariant ( const std::string &  variableType,
const std::string &  branchName,
RootDatasetVarVariant & memberVariableTarget 
)
private

Returns all values for a specified variableType and branchName.

The values are read from the ROOT file. The type is taken from the alternative currently held by memberVariableTarget.

Parameters
variableType: label of the variable kind {feature, weights, spectator, target}, used in error messages
branchName: name of the branch to read
memberVariableTarget: variant the branch address is reset to after reading
Returns
filled vector from a branch, converted to float

Definition at line 453 of file Dataset.cc.

455 {
456 if (std::holds_alternative<double>(memberVariableTarget))
457 return getVectorFromTTree(variableType, branchName, std::get<double>(memberVariableTarget));
458 else if (std::holds_alternative<float>(memberVariableTarget))
459 return getVectorFromTTree(variableType, branchName, std::get<float>(memberVariableTarget));
460 else if (std::holds_alternative<int>(memberVariableTarget))
461 return getVectorFromTTree(variableType, branchName, std::get<int>(memberVariableTarget));
462 else if (std::holds_alternative<bool>(memberVariableTarget))
463 return getVectorFromTTree(variableType, branchName, std::get<bool>(memberVariableTarget));
464 else
465 B2FATAL("Input type of " << variableType << " variable " << branchName << " is not supported");
466 }
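getVectorFromTTreeVariant and setScalarVariableAddressVariant both pick a template instantiation from the alternative the variant currently holds, using a chain of std::holds_alternative checks. The same dispatch can be sketched with std::visit; here the hypothetical function describe stands in for a member template such as getVectorFromTTree<T> and simply reports the deduced type.

```cpp
#include <string>
#include <type_traits>
#include <variant>

typedef std::variant<double, float, int, bool> RootDatasetVarVariant;

// Hypothetical stand-in for a member template such as getVectorFromTTree<T>:
// T is deduced from the reference to the active alternative.
template <class T>
std::string describe(T&)
{
  if constexpr (std::is_same_v<T, double>) return "double";
  else if constexpr (std::is_same_v<T, float>) return "float";
  else if constexpr (std::is_same_v<T, int>) return "int";
  else return "bool";
}

std::string dispatch(RootDatasetVarVariant& variant)
{
  // The generic lambda is instantiated once per alternative; auto& binds to
  // the value stored inside the variant, like std::get<T>(variant) does.
  return std::visit([](auto& value) { return describe(value); }, variant);
}
```

Because the lambda takes auto&, the callee receives a reference to the value stored inside the variant, which matters when, as in setScalarVariableAddress, that address is handed to ROOT.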

◆ getWeights()

std::vector< float > getWeights ( )
override, virtual

Returns all values of the weights in a std::vector<float>

Reimplemented from Dataset.

Definition at line 403 of file Dataset.cc.

404 {
406 if (branchName.empty()) {
407 B2INFO("No TBranch name given for weights. Using 1s as default weights.");
408 int nentries = getNumberOfEvents();
409 std::vector<float> values(nentries, 1.);
410 return values;
411 }
412 if (branchName == "__weight__") {
413 if (!checkForBranch(m_tree, "__weight__")) {
414 B2INFO("No default weight branch with name __weight__ found. Using 1s as default weights.");
415 int nentries = getNumberOfEvents();
416 std::vector<float> values(nentries, 1.);
417 return values;
418 }
419 }
420 std::string typeLabel = "weights";
421 return getVectorFromTTreeVariant(typeLabel, branchName, m_weight_variant);
422 }
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:91
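getWeights falls back to unit weights in two cases: when no weight branch name is configured, and when the default name __weight__ is configured but that branch is missing from the tree. Any other name is read from the tree. A minimal sketch of that decision, in which the branch set and the reader callback are hypothetical replacements for checkForBranch and getVectorFromTTreeVariant, not Belle II APIs:

```cpp
#include <functional>
#include <set>
#include <string>
#include <vector>

// Hypothetical sketch of the decision made in ROOTDataset::getWeights.
std::vector<float> selectWeights(const std::string& branchName,
                                 const std::set<std::string>& existingBranches,
                                 unsigned int nEvents,
                                 const std::function<std::vector<float>(const std::string&)>& readBranch)
{
  // No weight branch configured: every event gets weight 1.
  if (branchName.empty())
    return std::vector<float>(nEvents, 1.0f);
  // Default name configured but absent from the file: also weight 1.
  if (branchName == "__weight__" && existingBranches.count(branchName) == 0)
    return std::vector<float>(nEvents, 1.0f);
  // Otherwise the named branch is read.
  return readBranch(branchName);
}
```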

◆ initialiseVarVariantForBranch()

void initialiseVarVariantForBranch ( const std::string  branch_name,
RootDatasetVarVariant & varVariantTarget 
)
private

Infers the type (double, float, int, bool) from the TTree and initialises the VarVariant with the correct type.

Parameters
branch_name: branch name in the data file
varVariantTarget: variant to initialise

Definition at line 610 of file Dataset.cc.

611 {
612 std::string compatible_branch_name = Belle2::MakeROOTCompatible::makeROOTCompatible(branch_name);
613 // try the branch as is first then fall back to root safe name.
614 if (checkForBranch(m_tree, branch_name.c_str())) {
615 TBranch* branch = m_tree->GetBranch(branch_name.c_str());
616 TLeaf* leaf = branch->GetLeaf(branch_name.c_str());
617 std::string type_name = leaf->GetTypeName();
618 initialiseVarVariantType(type_name, varVariantTarget);
619 } else if (checkForBranch(m_tree, compatible_branch_name)) {
620 TBranch* branch = m_tree->GetBranch(compatible_branch_name.c_str());
621 TLeaf* leaf = branch->GetLeaf(compatible_branch_name.c_str());
622 std::string type_name = leaf->GetTypeName();
623 initialiseVarVariantType(type_name, varVariantTarget);
624 }
625 }

◆ initialiseVarVariantType()

void initialiseVarVariantType ( const std::string  type,
RootDatasetVarVariant & varVariantTarget 
)
private

Initialises the VarVariant.

Parameters
type: ROOT type name selecting the alternative of the variant {Double_t, Float_t, Int_t, Bool_t}
varVariantTarget: variant to initialise

Definition at line 593 of file Dataset.cc.

594 {
595 if (type == "Double_t")
596 varVariantTarget = 0.0;
597 else if (type == "Float_t")
598 varVariantTarget = 0.0f;
599 else if (type == "Int_t")
600 varVariantTarget = 0;
601 else if (type == "Bool_t")
602 varVariantTarget = false;
603 else {
604 B2FATAL("Unknown root input type: " << type);
605 throw std::runtime_error("Unknown root input type: " + type);
606 }
607 }
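initialiseVarVariantType maps the ROOT type name reported by the branch's leaf to a zero-initialised alternative of the variant. A minimal free-function sketch of that mapping follows; makeVariantForType is a hypothetical name, and it throws std::runtime_error where the listing calls B2FATAL.

```cpp
#include <stdexcept>
#include <string>
#include <variant>

typedef std::variant<double, float, int, bool> RootDatasetVarVariant;

// Hypothetical free-function version of initialiseVarVariantType: maps the
// ROOT leaf type name to a zero-initialised alternative of the variant.
RootDatasetVarVariant makeVariantForType(const std::string& type)
{
  if (type == "Double_t") return 0.0;
  if (type == "Float_t") return 0.0f;
  if (type == "Int_t") return 0;
  if (type == "Bool_t") return false;
  throw std::runtime_error("Unknown root input type: " + type);
}
```

Any other leaf type, for example Long64_t or UInt_t, is rejected by this mapping, as it is in the listing above.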

◆ loadEvent()

void loadEvent ( unsigned int  event)
override, virtual

Loads the given event from the TTree.

Parameters
event: event number to load

Implements Dataset.

Definition at line 386 of file Dataset.cc.

387 {
388 if (m_tree->GetEntry(event, 0) == 0) {
389 B2ERROR("Error during loading entry from chain");
390 }
391
392 for (unsigned int i = 0; i < m_input_variant.size(); i++) {
394 }
395 for (unsigned int i = 0; i < m_spectators_variant.size(); i++) {
397 }
401 }
int m_signal_class
Signal class which is used as signal in a classification problem.
Definition: Options.h:88

◆ setBranchAddresses()

void setBranchAddresses ( )
private

Sets the branch addresses of all features, spectators, the weight and the target again.

Definition at line 563 of file Dataset.cc.

564 {
565 // Deactivate all branches by default
566 m_tree->SetBranchStatus("*", false);
567
569 if (m_general_options.m_weight_variable == "__weight__") {
570 if (checkForBranch(m_tree, "__weight__")) {
571 m_tree->SetBranchStatus("__weight__", true);
572 std::string typeLabel_weight = "weight";
573 std::string weight_string = "__weight__";
574 setScalarVariableAddressVariant(typeLabel_weight, weight_string, m_weight_variant);
575 } else {
576 m_weight_variant = 1.0f;
577 }
578 } else {
579 std::string typeLabel_weight = "weight";
581 }
582 }
583
584 std::string typeLabel_target = "target";
586 std::string typeLabel_feature = "feature";
588 std::string typeLabel_spectator = "spectator";
590 }
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:90

◆ setRootInputType()

void setRootInputType ( )
private

Tries to infer the data types of the target, feature, spectator and weight variables in a ROOT file.

Definition at line 627 of file Dataset.cc.

628 {
629 // set target variable
631
632 // set feature variables
633 for (unsigned int i = 0; i < m_general_options.m_variables.size(); i++) {
634 auto variable = m_general_options.m_variables[i];
636 }
637
638 // set spectator variables
639 for (unsigned int i = 0; i < m_general_options.m_spectators.size(); i++) {
640 auto variable = m_general_options.m_spectators[i];
642 }
643
644 // set weight variable - bit more tricky as we allow it to not be set or to not be present.
646 m_weight_variant = 1.0f;
647 B2INFO("No weight variable provided. The weight will be set to 1.");
648 } else {
649 if (m_general_options.m_weight_variable == "__weight__") {
650 if (checkForBranch(m_tree, "__weight__")) {
651 m_tree->SetBranchStatus("__weight__", true);
653 } else {
654 B2INFO("Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
655 "weight variable to an empty string if you don't need it.");
656 m_weight_variant = 1.0f;
657 }
658 } else {
660 }
661 }
662 }

◆ setScalarVariableAddress()

void setScalarVariableAddress ( const std::string &  variableType,
const std::string &  variableName,
T &  variableTarget 
)
private

Sets the branch address for a scalar variable to a given target.

Template Parameters
T: target type (double, float, int or bool)
Parameters
variableType: label of the variable kind {feature, weights, spectator, target}, used in error messages
variableName: name of the variable, usually defined in general_options
variableTarget: variable the address is set to

Definition at line 510 of file Dataset.cc.

512 {
513 if (not variableName.empty()) {
514 if (checkForBranch(m_tree, variableName)) {
515 m_tree->SetBranchStatus(variableName.c_str(), true);
516 m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
517 } else {
519 m_tree->SetBranchStatus(Belle2::MakeROOTCompatible::makeROOTCompatible(variableName).c_str(), true);
520 m_tree->SetBranchAddress(Belle2::MakeROOTCompatible::makeROOTCompatible(variableName).c_str(), &variableTarget);
521 } else {
522 B2ERROR("Couldn't find given " << variableType << " variable named " << variableName <<
523 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
524 throw std::runtime_error("Couldn't find given " + variableType + " variable named " + variableName +
525 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
526 }
527 }
528 }
529 }
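Both setScalarVariableAddress and initialiseVarVariantForBranch look a branch up under the name as given first and then under its ROOT-compatible form. They differ when neither exists: setScalarVariableAddress throws, while initialiseVarVariantForBranch leaves the variant unchanged. A sketch of that lookup order, where resolveBranchName is hypothetical and the caller supplies makeCompatible in place of MakeROOTCompatible::makeROOTCompatible, whose actual character mapping is not reproduced here:

```cpp
#include <functional>
#include <optional>
#include <set>
#include <string>

// Hypothetical sketch: try the name as given, then its ROOT-compatible form.
// makeCompatible stands in for MakeROOTCompatible::makeROOTCompatible.
std::optional<std::string> resolveBranchName(const std::string& name,
                                             const std::set<std::string>& branches,
                                             const std::function<std::string(const std::string&)>& makeCompatible)
{
  if (branches.count(name) != 0)
    return name;
  std::string compatible = makeCompatible(name);
  if (branches.count(compatible) != 0)
    return compatible;
  return std::nullopt;  // setScalarVariableAddress throws in this case
}
```

Passing the sanitizer as a parameter keeps the sketch independent of the real makeROOTCompatible rules.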

◆ setScalarVariableAddressVariant()

void setScalarVariableAddressVariant ( const std::string &  variableType,
const std::string &  variableName,
RootDatasetVarVariant & variableTarget 
)
private

Sets the branch address for a scalar variable to a given target.

Parameters
variableType: label of the variable kind {feature, weights, spectator, target}, used in error messages
variableName: name of the variable, usually defined in general_options
variableTarget: variant whose active alternative receives the branch address

Definition at line 531 of file Dataset.cc.

533 {
534 if (std::holds_alternative<double>(varVariantTarget))
535 setScalarVariableAddress(variableType, variableName, std::get<double>(varVariantTarget));
536 else if (std::holds_alternative<float>(varVariantTarget))
537 setScalarVariableAddress(variableType, variableName, std::get<float>(varVariantTarget));
538 else if (std::holds_alternative<int>(varVariantTarget))
539 setScalarVariableAddress(variableType, variableName, std::get<int>(varVariantTarget));
540 else if (std::holds_alternative<bool>(varVariantTarget))
541 setScalarVariableAddress(variableType, variableName, std::get<bool>(varVariantTarget));
542 else
543 B2FATAL("Variable type for branch " << variableName << " not supported!");
544 }

◆ setVectorVariableAddress()

void setVectorVariableAddress ( const std::string &  variableType,
const std::vector< std::string > &  variableName,
T &  variableTargets 
)
private

Sets the branch address for a vector variable to a given target.

Template Parameters
T: target type (std::vector<float>, std::vector<double>)
Parameters
variableType: label of the variable kind {feature, weights, spectator, target}, used in error messages
variableName: names of the variables, usually defined in general_options
variableTargets: variables the addresses are set to

Definition at line 547 of file Dataset.cc.

549 {
550 for (unsigned int i = 0; i < variableNames.size(); ++i)
551 ROOTDataset::setScalarVariableAddress(variableType, variableNames[i], variableTargets[i]);
552 }

◆ setVectorVariableAddressVariant()

void setVectorVariableAddressVariant ( const std::string &  variableType,
const std::vector< std::string > &  variableName,
std::vector< RootDatasetVarVariant > &  varVariantTargets 
)
private

Sets the branch address for a vector of VarVariant to a given target.

Parameters
variableType: label of the variable kind {feature, weights, spectator, target}, used in error messages
variableName: names of the variables, usually defined in general_options
varVariantTargets: variants the addresses are set to

Definition at line 555 of file Dataset.cc.

557 {
558 for (unsigned int i = 0; i < variableNames.size(); ++i) {
559 ROOTDataset::setScalarVariableAddressVariant(variableType, variableNames[i], varVariantTargets[i]);
560 }
561 }

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
inherited

GeneralOptions passed to this dataset.

Definition at line 122 of file Dataset.h.

◆ m_input

std::vector<float> m_input
inherited

Contains all feature values of the currently loaded event.

Definition at line 123 of file Dataset.h.

◆ m_input_variant

std::vector<RootDatasetVarVariant> m_input_variant
protected

Contains all feature values of the currently loaded event.

Definition at line 409 of file Dataset.h.

◆ m_isSignal

bool m_isSignal
inherited

Defines if the currently loaded event is signal or background.

Definition at line 127 of file Dataset.h.

◆ m_spectators

std::vector<float> m_spectators
inherited

Contains all spectators values of the currently loaded event.

Definition at line 124 of file Dataset.h.

◆ m_spectators_variant

std::vector<RootDatasetVarVariant> m_spectators_variant
protected

Contains all spectators values of the currently loaded event.

Definition at line 411 of file Dataset.h.

◆ m_target

float m_target
inherited

Contains the target value of the currently loaded event.

Definition at line 126 of file Dataset.h.

◆ m_target_variant

RootDatasetVarVariant m_target_variant
protected

Contains the target value of the currently loaded event.

Definition at line 413 of file Dataset.h.

◆ m_tree

TChain* m_tree = nullptr
protected

Pointer to the TChain containing the data.

Definition at line 408 of file Dataset.h.

◆ m_weight

float m_weight
inherited

Contains the weight of the currently loaded event.

Definition at line 125 of file Dataset.h.

◆ m_weight_variant

RootDatasetVarVariant m_weight_variant
protected

Contains the weight of the currently loaded event.

Definition at line 412 of file Dataset.h.


The documentation for this class was generated from the following files: