 |
Belle II Software
release-05-01-25
|
12 #include <mva/interface/Dataset.h>
14 #include <framework/utilities/MakeROOTCompatible.h>
15 #include <framework/logging/Logger.h>
16 #include <framework/io/RootIOUtilities.h>
20 #include <boost/filesystem/operations.hpp>
29 Dataset::Dataset(
const GeneralOptions& general_options) : m_general_options(general_options)
31 m_input.resize(m_general_options.m_variables.size(), 0);
32 m_spectators.resize(m_general_options.m_spectators.size(), 0);
41 double signal_weight_sum = 0;
42 double weight_sum = 0;
49 return signal_weight_sum / weight_sum;
58 B2ERROR(
"Unknown feature named " << feature);
70 B2ERROR(
"Unknown spectator named " << spectator);
83 result[iEvent] =
m_input[iFeature];
139 const std::vector<float>& spectators) :
Dataset(general_options)
149 const std::vector<std::vector<float>>& spectators,
150 const std::vector<float>& targets,
const std::vector<float>& weights) :
Dataset(general_options), m_matrix(input),
151 m_spectator_matrix(spectators),
152 m_targets(targets), m_weights(weights)
155 if (m_targets.size() > 0 and m_matrix.size() != m_targets.size()) {
156 B2ERROR(
"Feature matrix and target vector need same number of elements in MultiDataset, got " << m_targets.size() <<
" and " <<
159 if (m_weights.size() > 0 and m_matrix.size() != m_weights.size()) {
160 B2ERROR(
"Feature matrix and weight vector need same number of elements in MultiDataset, got " << m_weights.size() <<
" and " <<
163 if (m_spectator_matrix.size() > 0 and m_matrix.size() != m_spectator_matrix.size()) {
164 B2ERROR(
"Feature matrix and spectator matrix need same number of elements in MultiDataset, got " << m_spectator_matrix.size() <<
190 Dataset& dataset) : Dataset(general_options), m_dataset(dataset)
194 auto it = std::find(m_dataset.m_general_options.m_variables.begin(), m_dataset.m_general_options.m_variables.end(), v);
195 if (it == m_dataset.m_general_options.m_variables.end()) {
196 B2ERROR(
"Couldn't find variable " << v <<
" in GeneralOptions");
197 throw std::runtime_error(
"Couldn't find variable " + v +
" in GeneralOptions");
199 m_feature_indices.push_back(it - m_dataset.m_general_options.m_variables.begin());
203 auto it = std::find(m_dataset.m_general_options.m_spectators.begin(), m_dataset.m_general_options.m_spectators.end(), v);
204 if (it == m_dataset.m_general_options.m_spectators.end()) {
205 B2ERROR(
"Couldn't find spectator " << v <<
" in GeneralOptions");
206 throw std::runtime_error(
"Couldn't find spectator " + v +
" in GeneralOptions");
208 m_spectator_indices.push_back(it - m_dataset.m_general_options.m_spectators.begin());
211 if (events.size() > 0)
212 m_use_event_indices =
true;
214 if (m_use_event_indices) {
215 m_event_indices.resize(dataset.getNumberOfEvents());
216 unsigned int n_events = 0;
217 for (
unsigned int iEvent = 0; iEvent < dataset.getNumberOfEvents(); ++iEvent) {
218 if (events.size() == 0 or events[iEvent]) {
219 m_event_indices[n_events] = iEvent;
223 m_event_indices.resize(n_events);
230 unsigned int index = iEvent;
238 for (
unsigned int iFeature = 0; iFeature <
m_input.size(); ++iFeature) {
242 for (
unsigned int iSpectator = 0; iSpectator <
m_spectators.size(); ++iSpectator) {
277 Dataset& background_dataset) : Dataset(general_options), m_signal_dataset(signal_dataset),
278 m_background_dataset(background_dataset) { }
304 s.insert(s.end(), b.begin(), b.end());
314 s.insert(s.end(), b.begin(), b.end());
323 m_target_double = 0.0;
324 m_weight_double = 1.0;
326 for (
const auto& variable : general_options.m_variables)
327 for (
const auto& spectator : general_options.m_spectators)
328 if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
329 B2ERROR(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
330 throw std::runtime_error(
"Interface doesn't support variable more then one time in either spectators, variables or target variable");
333 std::vector<std::string> filenames;
335 if (boost::filesystem::exists(filename)) {
336 filenames.push_back(filename);
339 filenames.insert(filenames.end(), temp.begin(), temp.end());
342 if (filenames.empty()) {
343 B2ERROR(
"Found no valid filenames in GeneralOptions");
344 throw std::runtime_error(
"Found no valid filenames in GeneralOptions");
348 TDirectory* dir = gDirectory;
349 for (
const auto& filename : filenames) {
350 if (not boost::filesystem::exists(filename)) {
351 B2ERROR(
"Error given ROOT file dies not exists " << filename);
352 throw std::runtime_error(
"Error during open of ROOT file named " + filename);
355 TFile* f = TFile::Open(
filename.c_str(),
"READ");
356 if (!f or f->IsZombie() or not f->IsOpen()) {
357 B2ERROR(
"Error during open of ROOT file named " << filename);
358 throw std::runtime_error(
"Error during open of ROOT file named " + filename);
365 for (
const auto& filename : filenames) {
367 if (!m_tree->AddFile(
filename.c_str(), -1)) {
368 B2ERROR(
"Error during open of ROOT file named " << filename <<
" cannot retreive tree named " <<
370 throw std::runtime_error(
"Error during open of ROOT file named " + filename +
" cannot retreive tree named " +
375 setBranchAddresses();
380 if (
m_tree->GetEntry(event, 0) == 0) {
381 B2ERROR(
"Error during loading entry from chain");
398 if (branchName.empty()) {
399 B2INFO(
"No TBranch name given for weights. Using 1s as default weights.");
401 std::vector<float> values(nentries, 1.);
404 if (branchName ==
"__weight__") {
406 B2INFO(
"No default weight branch with name __weight__ found. Using 1s as default weights.");
408 std::vector<float> values(nentries, 1.);
413 std::string typeName =
"weights";
424 B2ERROR(
"Feature index " << iFeature <<
" is out of bounds of given number of features: "
429 std::string typeName =
"features";
440 B2ERROR(
"Spectator index " << iSpectator <<
" is out of bounds of given number of spectators: "
445 std::string typeName =
"spectators";
461 T& memberVariableTarget)
464 std::vector<float> values(nentries);
468 auto currentTreeNumber =
m_tree->GetTreeNumber();
469 TBranch* branch =
m_tree->GetBranch(branchName.c_str());
471 B2ERROR(
"TBranch for " + variableType +
" named '" << branchName.c_str() <<
"' does not exist!");
473 branch->SetAddress(&
object);
474 for (
int i = 0; i < nentries; ++i) {
475 auto entry =
m_tree->LoadTree(i);
477 B2ERROR(
"Error during loading root tree from chain, error code: " << entry);
480 if (currentTreeNumber !=
m_tree->GetTreeNumber()) {
481 currentTreeNumber =
m_tree->GetTreeNumber();
482 branch =
m_tree->GetBranch(branchName.c_str());
483 branch->SetAddress(&
object);
485 branch->GetEntry(entry);
489 m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
495 auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
496 return branch !=
nullptr;
504 if (not variableName.empty()) {
506 m_tree->SetBranchStatus(variableName.c_str(),
true);
507 m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
513 B2ERROR(
"Couldn't find given " << variableType <<
" variable named " << variableName <<
514 " (I tried also using makeROOTCompatible)");
515 throw std::runtime_error(
"Couldn't find given " + variableType +
" variable named " + variableName +
516 " (I tried also using makeROOTCompatible)");
526 for (
unsigned int i = 0; i < variableNames.size(); ++i)
533 m_tree->SetBranchStatus(
"*",
false);
534 std::string typeName;
539 B2INFO(
"No weight variable provided. The weight will be set to 1.");
544 m_tree->SetBranchStatus(
"__weight__",
true);
550 B2INFO(
"Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
551 "weight variable to an empty string if you don't need it.");
566 typeName =
"feature";
568 typeName =
"spectator";
573 typeName =
"feature";
575 typeName =
"spectator";
583 std::string control_variable;
586 control_variable = variable;
589 if (not control_variable.empty()) {
590 TBranch* branch =
m_tree->GetBranch(control_variable.c_str());
591 TLeaf* leaf = branch->GetLeaf(control_variable.c_str());
592 std::string type_name = leaf->GetTypeName();
593 if (type_name ==
"Double_t")
595 else if (type_name ==
"Float_t")
598 B2FATAL(
"Unknown root input type: " << type_name);
599 throw std::runtime_error(
"Unknown root input type: " + type_name);
604 B2FATAL(
"No valid feature was found. Check your input features.");
605 throw std::runtime_error(
"No valid feature was found. Check your input features.");
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
TChain * m_tree
Pointer to the TChain containing the data.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
bool m_isDoubleInputType
Defines the expected datatype in the ROOT file.
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
std::vector< double > m_spectators_double
Contains all spectators values of the currently loaded event.
virtual std::vector< float > getWeights()
Returns all weights.
void setVectorVariableAddress(std::string &variableType, std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
std::string m_weight_variable
Weight variable (branch name) defining the weights.
virtual void loadEvent(unsigned int event) override
Load the event with number event from the TTree.
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
std::vector< unsigned int > m_event_indices
Mapping from the position of an event in the given subset to its position in the wrapped dataset.
std::vector< float > m_targets
target vector
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
virtual std::vector< bool > getSignals()
Returns the isSignal flags of all events.
void setScalarVariableAddress(std::string &variableType, std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
std::string m_treename
Name of the TTree inside the datafile containing the training data.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
double m_weight_double
Contains the weight of the currently loaded event.
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float >> &input, const std::vector< std::vector< float >> &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
Abstract base class for different kinds of events.
std::vector< float > m_input
Contains all feature values of the currently loaded event.
std::string m_target_variable
Target variable (branch name) defining the target.
int m_signal_class
Signal class which is used as signal in a classification problem.
std::vector< float > m_spectators
Contains all spectators values of the currently loaded event.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Dataset & m_dataset
Reference to the wrapped dataset.
General options which are shared by all MVA trainings.
virtual void loadEvent(unsigned int iEvent) override
Does nothing in the case of a single dataset, because the only event is already loaded.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
virtual ~ROOTDataset()
Virtual destructor.
std::vector< float > m_weights
weight vector
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
virtual std::vector< float > getTargets()
Returns all targets.
float m_weight
Contains the weight of the currently loaded event.
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
bool m_isSignal
Defines if the currently loaded event is signal or background.
std::vector< std::vector< float > > m_matrix
Feature matrix.
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
std::vector< float > getVectorFromTTree(std::string &variableType, std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
double m_target_double
Contains the target value of the currently loaded event.
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
float m_target
Contains the target value of the currently loaded event.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
std::vector< double > m_input_double
Contains all feature values of the currently loaded event.
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
void setRootInputType()
Tries to infer the data-type of a root file and sets m_isDoubleInputType.
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.