Belle II Software  release-05-01-25
Dataset.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck *
7  * Jochen Gemmler *
8  * *
9  * This software is provided "as is" without any warranty. *
10  **************************************************************************/
11 
12 #include <mva/interface/Dataset.h>
13 
14 #include <framework/utilities/MakeROOTCompatible.h>
15 #include <framework/logging/Logger.h>
16 #include <framework/io/RootIOUtilities.h>
17 
18 #include <TLeaf.h>
19 
20 #include <boost/filesystem/operations.hpp>
21 
22 namespace Belle2 {
27  namespace MVA {
28 
29  Dataset::Dataset(const GeneralOptions& general_options) : m_general_options(general_options)
30  {
31  m_input.resize(m_general_options.m_variables.size(), 0);
32  m_spectators.resize(m_general_options.m_spectators.size(), 0);
33  m_target = 0.0;
34  m_weight = 1.0;
35  m_isSignal = false;
36  }
37 
39  {
40 
41  double signal_weight_sum = 0;
42  double weight_sum = 0;
43  for (unsigned int i = 0; i < getNumberOfEvents(); ++i) {
44  loadEvent(i);
45  weight_sum += m_weight;
46  if (m_isSignal)
47  signal_weight_sum += m_weight;
48  }
49  return signal_weight_sum / weight_sum;
50 
51  }
52 
53  unsigned int Dataset::getFeatureIndex(const std::string& feature)
54  {
55 
56  auto it = std::find(m_general_options.m_variables.begin(), m_general_options.m_variables.end(), feature);
57  if (it == m_general_options.m_variables.end()) {
58  B2ERROR("Unknown feature named " << feature);
59  return 0;
60  }
61  return std::distance(m_general_options.m_variables.begin(), it);
62 
63  }
64 
65  unsigned int Dataset::getSpectatorIndex(const std::string& spectator)
66  {
67 
68  auto it = std::find(m_general_options.m_spectators.begin(), m_general_options.m_spectators.end(), spectator);
69  if (it == m_general_options.m_spectators.end()) {
70  B2ERROR("Unknown spectator named " << spectator);
71  return 0;
72  }
73  return std::distance(m_general_options.m_spectators.begin(), it);
74 
75  }
76 
77  std::vector<float> Dataset::getFeature(unsigned int iFeature)
78  {
79 
80  std::vector<float> result(getNumberOfEvents());
81  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
82  loadEvent(iEvent);
83  result[iEvent] = m_input[iFeature];
84  }
85  return result;
86 
87  }
88 
89  std::vector<float> Dataset::getSpectator(unsigned int iSpectator)
90  {
91 
92  std::vector<float> result(getNumberOfEvents());
93  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
94  loadEvent(iEvent);
95  result[iEvent] = m_spectators[iSpectator];
96  }
97  return result;
98 
99  }
100 
101  std::vector<float> Dataset::getWeights()
102  {
103 
104  std::vector<float> result(getNumberOfEvents());
105  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
106  loadEvent(iEvent);
107  result[iEvent] = m_weight;
108  }
109  return result;
110 
111  }
112 
113  std::vector<float> Dataset::getTargets()
114  {
115 
116  std::vector<float> result(getNumberOfEvents());
117  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
118  loadEvent(iEvent);
119  result[iEvent] = m_target;
120  }
121  return result;
122 
123  }
124 
125  std::vector<bool> Dataset::getSignals()
126  {
127 
128  std::vector<bool> result(getNumberOfEvents());
129  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
130  loadEvent(iEvent);
131  result[iEvent] = m_isSignal;
132  }
133  return result;
134 
135  }
136 
137 
138  SingleDataset::SingleDataset(const GeneralOptions& general_options, const std::vector<float>& input, float target,
139  const std::vector<float>& spectators) : Dataset(general_options)
140  {
141  m_input = input;
142  m_spectators = spectators;
143  m_target = target;
144  m_weight = 1.0;
145  m_isSignal = std::lround(target) == m_general_options.m_signal_class;
146  }
147 
148  MultiDataset::MultiDataset(const GeneralOptions& general_options, const std::vector<std::vector<float>>& input,
149  const std::vector<std::vector<float>>& spectators,
150  const std::vector<float>& targets, const std::vector<float>& weights) : Dataset(general_options), m_matrix(input),
151  m_spectator_matrix(spectators),
152  m_targets(targets), m_weights(weights)
153  {
154 
155  if (m_targets.size() > 0 and m_matrix.size() != m_targets.size()) {
156  B2ERROR("Feature matrix and target vector need same number of elements in MultiDataset, got " << m_targets.size() << " and " <<
157  m_matrix.size());
158  }
159  if (m_weights.size() > 0 and m_matrix.size() != m_weights.size()) {
160  B2ERROR("Feature matrix and weight vector need same number of elements in MultiDataset, got " << m_weights.size() << " and " <<
161  m_matrix.size());
162  }
163  if (m_spectator_matrix.size() > 0 and m_matrix.size() != m_spectator_matrix.size()) {
164  B2ERROR("Feature matrix and spectator matrix need same number of elements in MultiDataset, got " << m_spectator_matrix.size() <<
165  " and " <<
166  m_matrix.size());
167  }
168  }
169 
170 
171  void MultiDataset::loadEvent(unsigned int iEvent)
172  {
173  m_input = m_matrix[iEvent];
174 
175  if (m_spectator_matrix.size() > 0) {
177  }
178 
179  if (m_targets.size() > 0) {
180  m_target = m_targets[iEvent];
182  }
183 
184  if (m_weights.size() > 0)
185  m_weight = m_weights[iEvent];
186 
187  }
188 
189  SubDataset::SubDataset(const GeneralOptions& general_options, const std::vector<bool>& events,
190  Dataset& dataset) : Dataset(general_options), m_dataset(dataset)
191  {
192 
193  for (auto& v : m_general_options.m_variables) {
194  auto it = std::find(m_dataset.m_general_options.m_variables.begin(), m_dataset.m_general_options.m_variables.end(), v);
195  if (it == m_dataset.m_general_options.m_variables.end()) {
196  B2ERROR("Couldn't find variable " << v << " in GeneralOptions");
197  throw std::runtime_error("Couldn't find variable " + v + " in GeneralOptions");
198  }
199  m_feature_indices.push_back(it - m_dataset.m_general_options.m_variables.begin());
200  }
201 
202  for (auto& v : m_general_options.m_spectators) {
203  auto it = std::find(m_dataset.m_general_options.m_spectators.begin(), m_dataset.m_general_options.m_spectators.end(), v);
204  if (it == m_dataset.m_general_options.m_spectators.end()) {
205  B2ERROR("Couldn't find spectator " << v << " in GeneralOptions");
206  throw std::runtime_error("Couldn't find spectator " + v + " in GeneralOptions");
207  }
208  m_spectator_indices.push_back(it - m_dataset.m_general_options.m_spectators.begin());
209  }
210 
211  if (events.size() > 0)
212  m_use_event_indices = true;
213 
214  if (m_use_event_indices) {
215  m_event_indices.resize(dataset.getNumberOfEvents());
216  unsigned int n_events = 0;
217  for (unsigned int iEvent = 0; iEvent < dataset.getNumberOfEvents(); ++iEvent) {
218  if (events.size() == 0 or events[iEvent]) {
219  m_event_indices[n_events] = iEvent;
220  n_events++;
221  }
222  }
223  m_event_indices.resize(n_events);
224  }
225 
226  }
227 
228  void SubDataset::loadEvent(unsigned int iEvent)
229  {
230  unsigned int index = iEvent;
232  index = m_event_indices[iEvent];
233  m_dataset.loadEvent(index);
237 
238  for (unsigned int iFeature = 0; iFeature < m_input.size(); ++iFeature) {
239  m_input[iFeature] = m_dataset.m_input[m_feature_indices[iFeature]];
240  }
241 
242  for (unsigned int iSpectator = 0; iSpectator < m_spectators.size(); ++iSpectator) {
243  m_spectators[iSpectator] = m_dataset.m_spectators[m_spectator_indices[iSpectator]];
244  }
245 
246  }
247 
248  std::vector<float> SubDataset::getFeature(unsigned int iFeature)
249  {
250 
251  auto v = m_dataset.getFeature(m_feature_indices[iFeature]);
252  if (not m_use_event_indices)
253  return v;
254  std::vector<float> result(m_event_indices.size());
255  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
256  result[iEvent] = v[m_event_indices[iEvent]];
257  }
258  return result;
259 
260  }
261 
262  std::vector<float> SubDataset::getSpectator(unsigned int iSpectator)
263  {
264 
265  auto v = m_dataset.getSpectator(m_spectator_indices[iSpectator]);
266  if (not m_use_event_indices)
267  return v;
268  std::vector<float> result(m_event_indices.size());
269  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
270  result[iEvent] = v[m_event_indices[iEvent]];
271  }
272  return result;
273 
274  }
275 
276  CombinedDataset::CombinedDataset(const GeneralOptions& general_options, Dataset& signal_dataset,
277  Dataset& background_dataset) : Dataset(general_options), m_signal_dataset(signal_dataset),
278  m_background_dataset(background_dataset) { }
279 
280  void CombinedDataset::loadEvent(unsigned int iEvent)
281  {
282  if (iEvent < m_signal_dataset.getNumberOfEvents()) {
283  m_signal_dataset.loadEvent(iEvent);
284  m_target = 1.0;
285  m_isSignal = true;
289  } else {
291  m_target = 0.0;
292  m_isSignal = false;
296  }
297  }
298 
299  std::vector<float> CombinedDataset::getFeature(unsigned int iFeature)
300  {
301 
302  auto s = m_signal_dataset.getFeature(iFeature);
303  auto b = m_background_dataset.getFeature(iFeature);
304  s.insert(s.end(), b.begin(), b.end());
305  return s;
306 
307  }
308 
309  std::vector<float> CombinedDataset::getSpectator(unsigned int iSpectator)
310  {
311 
312  auto s = m_signal_dataset.getSpectator(iSpectator);
313  auto b = m_background_dataset.getSpectator(iSpectator);
314  s.insert(s.end(), b.begin(), b.end());
315  return s;
316 
317  }
318 
319  ROOTDataset::ROOTDataset(const GeneralOptions& general_options) : Dataset(general_options)
320  {
321  m_input_double.resize(m_general_options.m_variables.size(), 0);
322  m_spectators_double.resize(m_general_options.m_spectators.size(), 0);
323  m_target_double = 0.0;
324  m_weight_double = 1.0;
325 
326  for (const auto& variable : general_options.m_variables)
327  for (const auto& spectator : general_options.m_spectators)
328  if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
329  B2ERROR("Interface doesn't support variable more then one time in either spectators, variables or target variable");
330  throw std::runtime_error("Interface doesn't support variable more then one time in either spectators, variables or target variable");
331  }
332 
333  std::vector<std::string> filenames;
334  for (const auto& filename : m_general_options.m_datafiles) {
335  if (boost::filesystem::exists(filename)) {
336  filenames.push_back(filename);
337  } else {
339  filenames.insert(filenames.end(), temp.begin(), temp.end());
340  }
341  }
342  if (filenames.empty()) {
343  B2ERROR("Found no valid filenames in GeneralOptions");
344  throw std::runtime_error("Found no valid filenames in GeneralOptions");
345  }
346 
347  //Open TFile
348  TDirectory* dir = gDirectory;
349  for (const auto& filename : filenames) {
350  if (not boost::filesystem::exists(filename)) {
351  B2ERROR("Error given ROOT file dies not exists " << filename);
352  throw std::runtime_error("Error during open of ROOT file named " + filename);
353  }
354 
355  TFile* f = TFile::Open(filename.c_str(), "READ");
356  if (!f or f->IsZombie() or not f->IsOpen()) {
357  B2ERROR("Error during open of ROOT file named " << filename);
358  throw std::runtime_error("Error during open of ROOT file named " + filename);
359  }
360  delete f;
361  }
362  dir->cd();
363 
364  m_tree = new TChain(m_general_options.m_treename.c_str());
365  for (const auto& filename : filenames) {
366  //nentries = -1 forces AddFile() to read headers
367  if (!m_tree->AddFile(filename.c_str(), -1)) {
368  B2ERROR("Error during open of ROOT file named " << filename << " cannot retreive tree named " <<
370  throw std::runtime_error("Error during open of ROOT file named " + filename + " cannot retreive tree named " +
372  }
373  }
374  setRootInputType();
375  setBranchAddresses();
376  }
377 
378  void ROOTDataset::loadEvent(unsigned int event)
379  {
380  if (m_tree->GetEntry(event, 0) == 0) {
381  B2ERROR("Error during loading entry from chain");
382  }
383  if (m_isDoubleInputType) {
384  m_weight = (float) m_weight_double;
385  m_target = (float) m_target_double;
386  for (unsigned int i = 0; i < m_input_double.size(); i++)
387  m_input[i] = (float) m_input_double[i];
388  for (unsigned int i = 0; i < m_spectators_double.size(); i++)
389  m_spectators[i] = (float) m_spectators_double[i];
390  }
391 
393  }
394 
395  std::vector<float> ROOTDataset::getWeights()
396  {
398  if (branchName.empty()) {
399  B2INFO("No TBranch name given for weights. Using 1s as default weights.");
400  int nentries = getNumberOfEvents();
401  std::vector<float> values(nentries, 1.);
402  return values;
403  }
404  if (branchName == "__weight__") {
405  if (!checkForBranch(m_tree, "__weight__")) {
406  B2INFO("No default weight branch with name __weight__ found. Using 1s as default weights.");
407  int nentries = getNumberOfEvents();
408  std::vector<float> values(nentries, 1.);
409  return values;
410  }
411  }
412 
413  std::string typeName = "weights";
414 
416  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_weight_double);
417 
418  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_weight);
419  }
420 
421  std::vector<float> ROOTDataset::getFeature(unsigned int iFeature)
422  {
423  if (iFeature >= getNumberOfFeatures()) {
424  B2ERROR("Feature index " << iFeature << " is out of bounds of given number of features: "
425  << getNumberOfFeatures());
426  }
427 
428  std::string branchName = Belle2::makeROOTCompatible(m_general_options.m_variables[iFeature]);
429  std::string typeName = "features";
430 
432  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_input_double[iFeature]);
433 
434  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_input[iFeature]);
435  }
436 
437  std::vector<float> ROOTDataset::getSpectator(unsigned int iSpectator)
438  {
439  if (iSpectator >= getNumberOfSpectators()) {
440  B2ERROR("Spectator index " << iSpectator << " is out of bounds of given number of spectators: "
441  << getNumberOfSpectators());
442  }
443 
444  std::string branchName = Belle2::makeROOTCompatible(m_general_options.m_spectators[iSpectator]);
445  std::string typeName = "spectators";
446 
448  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_spectators_double[iSpectator]);
449 
450  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_spectators[iSpectator]);
451  }
452 
454  {
455  delete m_tree;
456  m_tree = nullptr;
457  }
458 
459  template<class T>
460  std::vector<float> ROOTDataset::getVectorFromTTree(std::string& variableType, std::string& branchName,
461  T& memberVariableTarget)
462  {
463  int nentries = getNumberOfEvents();
464  std::vector<float> values(nentries);
465 
466  // Float or Double to be filled
467  T object;
468  auto currentTreeNumber = m_tree->GetTreeNumber();
469  TBranch* branch = m_tree->GetBranch(branchName.c_str());
470  if (not branch) {
471  B2ERROR("TBranch for " + variableType + " named '" << branchName.c_str() << "' does not exist!");
472  }
473  branch->SetAddress(&object);
474  for (int i = 0; i < nentries; ++i) {
475  auto entry = m_tree->LoadTree(i);
476  if (entry < 0) {
477  B2ERROR("Error during loading root tree from chain, error code: " << entry);
478  }
479  // if current tree changed we have to update the branch
480  if (currentTreeNumber != m_tree->GetTreeNumber()) {
481  currentTreeNumber = m_tree->GetTreeNumber();
482  branch = m_tree->GetBranch(branchName.c_str());
483  branch->SetAddress(&object);
484  }
485  branch->GetEntry(entry);
486  values[i] = object;
487  }
488  // Reset branch to correct input address, just to be sure
489  m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
490  return values;
491  }
492 
493  bool ROOTDataset::checkForBranch(TTree* tree, const std::string& branchname) const
494  {
495  auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
496  return branch != nullptr;
497 
498  }
499 
500  template<class T>
501  void ROOTDataset::setScalarVariableAddress(std::string& variableType, std::string& variableName,
502  T& variableTarget)
503  {
504  if (not variableName.empty()) {
505  if (checkForBranch(m_tree, variableName)) {
506  m_tree->SetBranchStatus(variableName.c_str(), true);
507  m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
508  } else {
509  if (checkForBranch(m_tree, Belle2::makeROOTCompatible(variableName))) {
510  m_tree->SetBranchStatus(Belle2::makeROOTCompatible(variableName).c_str(), true);
511  m_tree->SetBranchAddress(Belle2::makeROOTCompatible(variableName).c_str(), &variableTarget);
512  } else {
513  B2ERROR("Couldn't find given " << variableType << " variable named " << variableName <<
514  " (I tried also using makeROOTCompatible)");
515  throw std::runtime_error("Couldn't find given " + variableType + " variable named " + variableName +
516  " (I tried also using makeROOTCompatible)");
517  }
518  }
519  }
520  }
521 
522  template<class T>
523  void ROOTDataset::setVectorVariableAddress(std::string& variableType, std::vector<std::string>& variableNames,
524  T& variableTargets)
525  {
526  for (unsigned int i = 0; i < variableNames.size(); ++i)
527  ROOTDataset::setScalarVariableAddress(variableType, variableNames[i], variableTargets[i]);
528  }
529 
531  {
532  // Deactivate all branches by default
533  m_tree->SetBranchStatus("*", false);
534  std::string typeName;
535 
536  if (m_general_options.m_weight_variable.empty()) {
537  m_weight = 1;
538  m_weight_double = 1;
539  B2INFO("No weight variable provided. The weight will be set to 1.");
540  }
541 
542  if (m_general_options.m_weight_variable == "__weight__") {
543  if (checkForBranch(m_tree, "__weight__")) {
544  m_tree->SetBranchStatus("__weight__", true);
546  m_tree->SetBranchAddress("__weight__", &m_weight_double);
547  else
548  m_tree->SetBranchAddress("__weight__", &m_weight);
549  } else {
550  B2INFO("Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
551  "weight variable to an empty string if you don't need it.");
552  m_weight = 1;
553  m_weight_double = 1;
554  }
555  } else if (m_isDoubleInputType) {
556  typeName = "weight";
558  } else {
559  typeName = "weight";
561  }
562 
563  if (m_isDoubleInputType) {
564  typeName = "target";
566  typeName = "feature";
568  typeName = "spectator";
570  } else {
571  typeName = "target";
573  typeName = "feature";
575  typeName = "spectator";
577  }
578  }
579 
580 
582  {
583  std::string control_variable;
584  for (auto& variable : m_general_options.m_variables) {
585  if (checkForBranch(m_tree, variable))
586  control_variable = variable;
587  else if (checkForBranch(m_tree, Belle2::makeROOTCompatible(variable)))
588  control_variable = Belle2::makeROOTCompatible(variable);
589  if (not control_variable.empty()) {
590  TBranch* branch = m_tree->GetBranch(control_variable.c_str());
591  TLeaf* leaf = branch->GetLeaf(control_variable.c_str());
592  std::string type_name = leaf->GetTypeName();
593  if (type_name == "Double_t")
594  m_isDoubleInputType = true;
595  else if (type_name == "Float_t")
596  m_isDoubleInputType = false;
597  else {
598  B2FATAL("Unknown root input type: " << type_name);
599  throw std::runtime_error("Unknown root input type: " + type_name);
600  }
601  return;
602  }
603  }
604  B2FATAL("No valid feature was found. Check your input features.");
605  throw std::runtime_error("No valid feature was found. Check your input features.");
606  }
607 
608 
609  }
611 }
Belle2::MVA::CombinedDataset::loadEvent
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Definition: Dataset.cc:289
Belle2::MVA::ROOTDataset::m_tree
TChain * m_tree
Pointer to the TChain containing the data.
Definition: Dataset.h:457
Belle2::MVA::ROOTDataset::getFeature
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:430
Belle2::MVA::SingleDataset::SingleDataset
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Definition: Dataset.cc:147
Belle2::MVA::Dataset::m_general_options
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
Definition: Dataset.h:123
Belle2::MVA::SubDataset::m_use_event_indices
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
Definition: Dataset.h:279
Belle2::MVA::Dataset
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:34
Belle2::MVA::SubDataset::getFeature
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:257
Belle2::MVA::Dataset::loadEvent
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
Belle2::MVA::CombinedDataset::m_background_dataset
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
Definition: Dataset.h:340
Belle2::MVA::Dataset::getFeature
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:86
Belle2::MVA::SubDataset::getNumberOfEvents
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
Definition: Dataset.h:258
Belle2::MVA::ROOTDataset::m_isDoubleInputType
bool m_isDoubleInputType
Defines the expected datatype in the ROOT file.
Definition: Dataset.h:458
Belle2::MVA::Dataset::getSpectator
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
Definition: Dataset.cc:98
Belle2::MVA::ROOTDataset::ROOTDataset
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
Definition: Dataset.cc:328
Belle2::MVA::SubDataset::m_spectator_indices
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:283
Belle2::MVA::ROOTDataset::m_spectators_double
std::vector< double > m_spectators_double
Contains all spectators values of the currently loaded event.
Definition: Dataset.h:460
Belle2::MVA::Dataset::getWeights
virtual std::vector< float > getWeights()
Returns all weights.
Definition: Dataset.cc:110
Belle2::MVA::ROOTDataset::setVectorVariableAddress
void setVectorVariableAddress(std::string &variableType, std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
Definition: Dataset.cc:532
Belle2::MVA::GeneralOptions::m_weight_variable
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:92
Belle2::MVA::ROOTDataset::loadEvent
virtual void loadEvent(unsigned int event) override
Load the event number iEvent from the TTree.
Definition: Dataset.cc:387
Belle2::MVA::MultiDataset::m_spectator_matrix
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
Definition: Dataset.h:224
Belle2::MVA::ROOTDataset::getWeights
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
Definition: Dataset.cc:404
Belle2::MVA::SubDataset::SubDataset
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
Definition: Dataset.cc:198
Belle2::MVA::SubDataset::m_event_indices
std::vector< unsigned int > m_event_indices
Mapping from the position of a event in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:285
Belle2::MVA::MultiDataset::m_targets
std::vector< float > m_targets
target vector
Definition: Dataset.h:225
Belle2::MVA::SubDataset::m_feature_indices
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:281
Belle2::MVA::Dataset::getSignals
virtual std::vector< bool > getSignals()
Returns all is Signals.
Definition: Dataset.cc:134
Belle2::MVA::ROOTDataset::setScalarVariableAddress
void setScalarVariableAddress(std::string &variableType, std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
Definition: Dataset.cc:510
Belle2::MVA::CombinedDataset::CombinedDataset
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Definition: Dataset.cc:285
Belle2::MVA::ROOTDataset::getNumberOfEvents
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
Definition: Dataset.h:371
Belle2::MVA::GeneralOptions::m_treename
std::string m_treename
Name of the TTree inside the datafile containing the training data.
Definition: Options.h:87
Belle2::MVA::GeneralOptions::m_spectators
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:89
Belle2::MVA::ROOTDataset::m_weight_double
double m_weight_double
Contains the weight of the currently loaded event.
Definition: Dataset.h:461
Belle2::makeROOTCompatible
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
Definition: MakeROOTCompatible.cc:74
Belle2::MVA::ROOTDataset::getNumberOfFeatures
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
Definition: Dataset.h:361
Belle2::MVA::MultiDataset::MultiDataset
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float >> &input, const std::vector< std::vector< float >> &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
Definition: Dataset.cc:157
Belle2::RootIOUtilities::expandWordExpansions
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
Definition: RootIOUtilities.cc:107
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::Dataset::m_input
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Definition: Dataset.h:124
Belle2::MVA::GeneralOptions::m_target_variable
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:91
Belle2::MVA::GeneralOptions::m_signal_class
int m_signal_class
Signal class which is used as signal in a classification problem.
Definition: Options.h:90
Belle2::MVA::Dataset::m_spectators
std::vector< float > m_spectators
Contains all spectators values of the currently loaded event.
Definition: Dataset.h:125
Belle2::MVA::GeneralOptions::m_variables
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:88
Belle2::MVA::SubDataset::m_dataset
Dataset & m_dataset
Reference to the wrapped dataset.
Definition: Dataset.h:286
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::MVA::MultiDataset::loadEvent
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the stored feature matrix and, when they were provided, from the spectator matrix, target vector and weight vector.
Definition: Dataset.cc:180
Belle2::MVA::CombinedDataset::getSpectator
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:318
Belle2::MVA::CombinedDataset::m_signal_dataset
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
Definition: Dataset.h:339
Belle2::MVA::ROOTDataset::checkForBranch
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
Definition: Dataset.cc:502
Belle2::MVA::GeneralOptions::m_datafiles
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
Definition: Options.h:86
alignment.constraints_generator.filename
filename
File name.
Definition: constraints_generator.py:224
Belle2::MVA::ROOTDataset::~ROOTDataset
virtual ~ROOTDataset()
Virtual destructor.
Definition: Dataset.cc:462
Belle2::MVA::MultiDataset::m_weights
std::vector< float > m_weights
weight vector
Definition: Dataset.h:226
Belle2::MVA::ROOTDataset::getSpectator
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
Definition: Dataset.cc:446
Belle2::MVA::Dataset::getTargets
virtual std::vector< float > getTargets()
Returns all targets.
Definition: Dataset.cc:122
Belle2::MVA::Dataset::m_weight
float m_weight
Contains the weight of the currently loaded event.
Definition: Dataset.h:126
Belle2::MVA::Dataset::getSpectatorIndex
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
Definition: Dataset.cc:74
Belle2::MVA::Dataset::m_isSignal
bool m_isSignal
Defines if the currently loaded event is signal or background.
Definition: Dataset.h:128
Belle2::MVA::MultiDataset::m_matrix
std::vector< std::vector< float > > m_matrix
Feature matrix.
Definition: Dataset.h:223
Belle2::MVA::Dataset::Dataset
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
Definition: Dataset.cc:38
Belle2::MVA::ROOTDataset::setBranchAddresses
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
Definition: Dataset.cc:539
Belle2::MVA::SubDataset::getSpectator
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:271
Belle2::MVA::ROOTDataset::getNumberOfSpectators
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
Definition: Dataset.h:366
Belle2::MVA::ROOTDataset::getVectorFromTTree
std::vector< float > getVectorFromTTree(std::string &variableType, std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
Definition: Dataset.cc:469
Belle2::MVA::SubDataset::loadEvent
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Definition: Dataset.cc:237
Belle2::MVA::ROOTDataset::m_target_double
double m_target_double
Contains the target value of the currently loaded event.
Definition: Dataset.h:462
Belle2::MVA::Dataset::getNumberOfEvents
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
Belle2::MVA::Dataset::m_target
float m_target
Contains the target value of the currently loaded event.
Definition: Dataset.h:127
Belle2::MVA::CombinedDataset::getFeature
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:308
Belle2::MVA::ROOTDataset::m_input_double
std::vector< double > m_input_double
Contains all feature values of the currently loaded event.
Definition: Dataset.h:459
Belle2::MVA::Dataset::getFeatureIndex
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
Definition: Dataset.cc:62
Belle2::MVA::ROOTDataset::setRootInputType
void setRootInputType()
Tries to infer the data-type of a root file and sets m_isDoubleInputType.
Definition: Dataset.cc:590
Belle2::MVA::Dataset::getSignalFraction
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:47