Belle II Software  light-2212-foldex
Dataset.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/interface/Dataset.h>
10 
11 #include <framework/utilities/MakeROOTCompatible.h>
12 #include <framework/logging/Logger.h>
13 #include <framework/io/RootIOUtilities.h>
14 
15 #include <TLeaf.h>
16 
17 #include <boost/filesystem/operations.hpp>
18 
19 namespace Belle2 {
24  namespace MVA {
25 
26  Dataset::Dataset(const GeneralOptions& general_options) : m_general_options(general_options)
27  {
28  m_input.resize(m_general_options.m_variables.size(), 0);
30  m_target = 0.0;
31  m_weight = 1.0;
32  m_isSignal = false;
33  }
34 
36  {
37 
38  double signal_weight_sum = 0;
39  double weight_sum = 0;
40  for (unsigned int i = 0; i < getNumberOfEvents(); ++i) {
41  loadEvent(i);
42  weight_sum += m_weight;
43  if (m_isSignal)
44  signal_weight_sum += m_weight;
45  }
46  return signal_weight_sum / weight_sum;
47 
48  }
49 
50  unsigned int Dataset::getFeatureIndex(const std::string& feature)
51  {
52 
53  auto it = std::find(m_general_options.m_variables.begin(), m_general_options.m_variables.end(), feature);
54  if (it == m_general_options.m_variables.end()) {
55  B2ERROR("Unknown feature named " << feature);
56  return 0;
57  }
58  return std::distance(m_general_options.m_variables.begin(), it);
59 
60  }
61 
62  unsigned int Dataset::getSpectatorIndex(const std::string& spectator)
63  {
64 
65  auto it = std::find(m_general_options.m_spectators.begin(), m_general_options.m_spectators.end(), spectator);
66  if (it == m_general_options.m_spectators.end()) {
67  B2ERROR("Unknown spectator named " << spectator);
68  return 0;
69  }
70  return std::distance(m_general_options.m_spectators.begin(), it);
71 
72  }
73 
74  std::vector<float> Dataset::getFeature(unsigned int iFeature)
75  {
76 
77  std::vector<float> result(getNumberOfEvents());
78  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
79  loadEvent(iEvent);
80  result[iEvent] = m_input[iFeature];
81  }
82  return result;
83 
84  }
85 
86  std::vector<float> Dataset::getSpectator(unsigned int iSpectator)
87  {
88 
89  std::vector<float> result(getNumberOfEvents());
90  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
91  loadEvent(iEvent);
92  result[iEvent] = m_spectators[iSpectator];
93  }
94  return result;
95 
96  }
97 
98  std::vector<float> Dataset::getWeights()
99  {
100 
101  std::vector<float> result(getNumberOfEvents());
102  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
103  loadEvent(iEvent);
104  result[iEvent] = m_weight;
105  }
106  return result;
107 
108  }
109 
110  std::vector<float> Dataset::getTargets()
111  {
112 
113  std::vector<float> result(getNumberOfEvents());
114  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
115  loadEvent(iEvent);
116  result[iEvent] = m_target;
117  }
118  return result;
119 
120  }
121 
122  std::vector<bool> Dataset::getSignals()
123  {
124 
125  std::vector<bool> result(getNumberOfEvents());
126  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
127  loadEvent(iEvent);
128  result[iEvent] = m_isSignal;
129  }
130  return result;
131 
132  }
133 
134 
135  SingleDataset::SingleDataset(const GeneralOptions& general_options, const std::vector<float>& input, float target,
136  const std::vector<float>& spectators) : Dataset(general_options)
137  {
138  m_input = input;
139  m_spectators = spectators;
140  m_target = target;
141  m_weight = 1.0;
142  m_isSignal = std::lround(target) == m_general_options.m_signal_class;
143  }
144 
145  MultiDataset::MultiDataset(const GeneralOptions& general_options, const std::vector<std::vector<float>>& input,
146  const std::vector<std::vector<float>>& spectators,
147  const std::vector<float>& targets, const std::vector<float>& weights) : Dataset(general_options), m_matrix(input),
148  m_spectator_matrix(spectators),
149  m_targets(targets), m_weights(weights)
150  {
151 
152  if (m_targets.size() > 0 and m_matrix.size() != m_targets.size()) {
153  B2ERROR("Feature matrix and target vector need same number of elements in MultiDataset, got " << m_targets.size() << " and " <<
154  m_matrix.size());
155  }
156  if (m_weights.size() > 0 and m_matrix.size() != m_weights.size()) {
157  B2ERROR("Feature matrix and weight vector need same number of elements in MultiDataset, got " << m_weights.size() << " and " <<
158  m_matrix.size());
159  }
160  if (m_spectator_matrix.size() > 0 and m_matrix.size() != m_spectator_matrix.size()) {
161  B2ERROR("Feature matrix and spectator matrix need same number of elements in MultiDataset, got " << m_spectator_matrix.size() <<
162  " and " <<
163  m_matrix.size());
164  }
165  }
166 
167 
168  void MultiDataset::loadEvent(unsigned int iEvent)
169  {
170  m_input = m_matrix[iEvent];
171 
172  if (m_spectator_matrix.size() > 0) {
174  }
175 
176  if (m_targets.size() > 0) {
177  m_target = m_targets[iEvent];
179  }
180 
181  if (m_weights.size() > 0)
182  m_weight = m_weights[iEvent];
183 
184  }
185 
186  SubDataset::SubDataset(const GeneralOptions& general_options, const std::vector<bool>& events,
187  Dataset& dataset) : Dataset(general_options), m_dataset(dataset)
188  {
189 
190  for (auto& v : m_general_options.m_variables) {
191  auto it = std::find(m_dataset.m_general_options.m_variables.begin(), m_dataset.m_general_options.m_variables.end(), v);
192  if (it == m_dataset.m_general_options.m_variables.end()) {
193  B2ERROR("Couldn't find variable " << v << " in GeneralOptions");
194  throw std::runtime_error("Couldn't find variable " + v + " in GeneralOptions");
195  }
197  }
198 
199  for (auto& v : m_general_options.m_spectators) {
200  auto it = std::find(m_dataset.m_general_options.m_spectators.begin(), m_dataset.m_general_options.m_spectators.end(), v);
201  if (it == m_dataset.m_general_options.m_spectators.end()) {
202  B2ERROR("Couldn't find spectator " << v << " in GeneralOptions");
203  throw std::runtime_error("Couldn't find spectator " + v + " in GeneralOptions");
204  }
206  }
207 
208  if (events.size() > 0)
209  m_use_event_indices = true;
210 
211  if (m_use_event_indices) {
212  m_event_indices.resize(dataset.getNumberOfEvents());
213  unsigned int n_events = 0;
214  for (unsigned int iEvent = 0; iEvent < dataset.getNumberOfEvents(); ++iEvent) {
215  if (events.size() == 0 or events[iEvent]) {
216  m_event_indices[n_events] = iEvent;
217  n_events++;
218  }
219  }
220  m_event_indices.resize(n_events);
221  }
222 
223  }
224 
225  void SubDataset::loadEvent(unsigned int iEvent)
226  {
227  unsigned int index = iEvent;
229  index = m_event_indices[iEvent];
230  m_dataset.loadEvent(index);
234 
235  for (unsigned int iFeature = 0; iFeature < m_input.size(); ++iFeature) {
236  m_input[iFeature] = m_dataset.m_input[m_feature_indices[iFeature]];
237  }
238 
239  for (unsigned int iSpectator = 0; iSpectator < m_spectators.size(); ++iSpectator) {
240  m_spectators[iSpectator] = m_dataset.m_spectators[m_spectator_indices[iSpectator]];
241  }
242 
243  }
244 
245  std::vector<float> SubDataset::getFeature(unsigned int iFeature)
246  {
247 
248  auto v = m_dataset.getFeature(m_feature_indices[iFeature]);
249  if (not m_use_event_indices)
250  return v;
251  std::vector<float> result(m_event_indices.size());
252  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
253  result[iEvent] = v[m_event_indices[iEvent]];
254  }
255  return result;
256 
257  }
258 
259  std::vector<float> SubDataset::getSpectator(unsigned int iSpectator)
260  {
261 
262  auto v = m_dataset.getSpectator(m_spectator_indices[iSpectator]);
263  if (not m_use_event_indices)
264  return v;
265  std::vector<float> result(m_event_indices.size());
266  for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
267  result[iEvent] = v[m_event_indices[iEvent]];
268  }
269  return result;
270 
271  }
272 
273  CombinedDataset::CombinedDataset(const GeneralOptions& general_options, Dataset& signal_dataset,
274  Dataset& background_dataset) : Dataset(general_options), m_signal_dataset(signal_dataset),
275  m_background_dataset(background_dataset) { }
276 
277  void CombinedDataset::loadEvent(unsigned int iEvent)
278  {
279  if (iEvent < m_signal_dataset.getNumberOfEvents()) {
280  m_signal_dataset.loadEvent(iEvent);
281  m_target = 1.0;
282  m_isSignal = true;
286  } else {
288  m_target = 0.0;
289  m_isSignal = false;
293  }
294  }
295 
296  std::vector<float> CombinedDataset::getFeature(unsigned int iFeature)
297  {
298 
299  auto s = m_signal_dataset.getFeature(iFeature);
300  auto b = m_background_dataset.getFeature(iFeature);
301  s.insert(s.end(), b.begin(), b.end());
302  return s;
303 
304  }
305 
306  std::vector<float> CombinedDataset::getSpectator(unsigned int iSpectator)
307  {
308 
309  auto s = m_signal_dataset.getSpectator(iSpectator);
310  auto b = m_background_dataset.getSpectator(iSpectator);
311  s.insert(s.end(), b.begin(), b.end());
312  return s;
313 
314  }
315 
316  ROOTDataset::ROOTDataset(const GeneralOptions& general_options) : Dataset(general_options)
317  {
320  m_weight_double = 1.0;
321  m_target_double = 0.0;
322  m_target_int = 0;
323  m_target_bool = 0;
324 
325  for (const auto& variable : general_options.m_variables)
326  for (const auto& spectator : general_options.m_spectators)
327  if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
328  B2ERROR("Interface doesn't support variable more then one time in either spectators, variables or target variable");
329  throw std::runtime_error("Interface doesn't support variable more then one time in either spectators, variables or target variable");
330  }
331 
332  std::vector<std::string> filenames;
333  for (const auto& filename : m_general_options.m_datafiles) {
334  if (boost::filesystem::exists(filename)) {
335  filenames.push_back(filename);
336  } else {
338  filenames.insert(filenames.end(), temp.begin(), temp.end());
339  }
340  }
341  if (filenames.empty()) {
342  B2ERROR("Found no valid filenames in GeneralOptions");
343  throw std::runtime_error("Found no valid filenames in GeneralOptions");
344  }
345 
346  //Open TFile
347  TDirectory* dir = gDirectory;
348  for (const auto& filename : filenames) {
349  if (not boost::filesystem::exists(filename)) {
350  B2ERROR("Error given ROOT file does not exist " << filename);
351  throw std::runtime_error("Error during open of ROOT file named " + filename);
352  }
353 
354  TFile* f = TFile::Open(filename.c_str(), "READ");
355  if (!f or f->IsZombie() or not f->IsOpen()) {
356  B2ERROR("Error during open of ROOT file named " << filename);
357  throw std::runtime_error("Error during open of ROOT file named " + filename);
358  }
359  delete f;
360  }
361  dir->cd();
362 
363  m_tree = new TChain(m_general_options.m_treename.c_str());
364  for (const auto& filename : filenames) {
365  //nentries = -1 forces AddFile() to read headers
366  if (!m_tree->AddFile(filename.c_str(), -1)) {
367  B2ERROR("Error during open of ROOT file named " << filename << " cannot retrieve tree named " <<
369  throw std::runtime_error("Error during open of ROOT file named " + filename + " cannot retrieve tree named " +
371  }
372  }
376  }
377 
378  void ROOTDataset::loadEvent(unsigned int event)
379  {
380  if (m_tree->GetEntry(event, 0) == 0) {
381  B2ERROR("Error during loading entry from chain");
382  }
383 
384  if (!m_isFloatInputType) {
385  m_weight = (float) m_weight_double;
386 
387  for (unsigned int i = 0; i < m_input_variant.size(); i++) {
388  if (std::holds_alternative<double>(m_input_variant[i]))
389  m_input[i] = (float) std::get<double>(m_input_variant[i]);
390  else if (std::holds_alternative<int>(m_input_variant[i]))
391  m_input[i] = (float) std::get<int>(m_input_variant[i]);
392  else if (std::holds_alternative<bool>(m_input_variant[i]))
393  m_input[i] = (float) std::get<bool>(m_input_variant[i]);
394  }
395  for (unsigned int i = 0; i < m_spectators_variant.size(); i++) {
396  if (std::holds_alternative<double>(m_spectators_variant[i]))
397  m_spectators[i] = (float) std::get<double>(m_spectators_variant[i]);
398  else if (std::holds_alternative<int>(m_spectators_variant[i]))
399  m_spectators[i] = (float) std::get<int>(m_spectators_variant[i]);
400  else if (std::holds_alternative<bool>(m_spectators_variant[i]))
401  m_spectators[i] = (float) std::get<bool>(m_spectators_variant[i]);
402  }
403  }
404 
405  if (m_target_data_type == Variable::Manager::VariableDataType::c_double)
406  m_target = (float) m_target_double;
407  else if (m_target_data_type == Variable::Manager::VariableDataType::c_int)
408  m_target = (float) m_target_int;
409  else if (m_target_data_type == Variable::Manager::VariableDataType::c_bool)
410  m_target = (float) m_target_bool;
411 
413  }
414 
415  std::vector<float> ROOTDataset::getWeights()
416  {
418  if (branchName.empty()) {
419  B2INFO("No TBranch name given for weights. Using 1s as default weights.");
420  int nentries = getNumberOfEvents();
421  std::vector<float> values(nentries, 1.);
422  return values;
423  }
424  if (branchName == "__weight__") {
425  if (!checkForBranch(m_tree, "__weight__")) {
426  B2INFO("No default weight branch with name __weight__ found. Using 1s as default weights.");
427  int nentries = getNumberOfEvents();
428  std::vector<float> values(nentries, 1.);
429  return values;
430  }
431  }
432 
433  std::string typeName = "weights";
434 
435  if (m_isFloatInputType)
436  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_weight);
437  else
438  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_weight_double);
439 
440 
441  }
442 
443  std::vector<float> ROOTDataset::getFeature(unsigned int iFeature)
444  {
445  if (iFeature >= getNumberOfFeatures()) {
446  B2ERROR("Feature index " << iFeature << " is out of bounds of given number of features: "
447  << getNumberOfFeatures());
448  }
449 
451  std::string typeName = "features";
452 
453  if (!m_isFloatInputType) {
454  if (std::holds_alternative<double>(m_input_variant[iFeature]))
455  return ROOTDataset::getVectorFromTTree(typeName, branchName, std::get<double>(m_input_variant[iFeature]));
456  else if (std::holds_alternative<int>(m_input_variant[iFeature]))
457  return ROOTDataset::getVectorFromTTree(typeName, branchName, std::get<int>(m_input_variant[iFeature]));
458  else if (std::holds_alternative<bool>(m_input_variant[iFeature]))
459  return ROOTDataset::getVectorFromTTree(typeName, branchName, std::get<bool>(m_input_variant[iFeature]));
460  }
461 
462  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_input[iFeature]);
463  }
464 
465  std::vector<float> ROOTDataset::getSpectator(unsigned int iSpectator)
466  {
467  if (iSpectator >= getNumberOfSpectators()) {
468  B2ERROR("Spectator index " << iSpectator << " is out of bounds of given number of spectators: "
469  << getNumberOfSpectators());
470  }
471 
473  std::string typeName = "spectators";
474 
475  if (!m_isFloatInputType) {
476  if (std::holds_alternative<double>(m_spectators_variant[iSpectator]))
477  return ROOTDataset::getVectorFromTTree(typeName, branchName, std::get<double>(m_spectators_variant[iSpectator]));
478  else if (std::holds_alternative<int>(m_spectators_variant[iSpectator]))
479  return ROOTDataset::getVectorFromTTree(typeName, branchName, std::get<int>(m_spectators_variant[iSpectator]));
480  else if (std::holds_alternative<bool>(m_spectators_variant[iSpectator]))
481  return ROOTDataset::getVectorFromTTree(typeName, branchName, std::get<bool>(m_spectators_variant[iSpectator]));
482  }
483 
484  return ROOTDataset::getVectorFromTTree(typeName, branchName, m_spectators[iSpectator]);
485  }
486 
488  {
489  delete m_tree;
490  m_tree = nullptr;
491  }
492 
493  template<class T>
494  std::vector<float> ROOTDataset::getVectorFromTTree(std::string& variableType, std::string& branchName,
495  T& memberVariableTarget)
496  {
497  int nentries = getNumberOfEvents();
498  std::vector<float> values(nentries);
499 
500  // Float or Double to be filled
501  T object;
502  auto currentTreeNumber = m_tree->GetTreeNumber();
503  TBranch* branch = m_tree->GetBranch(branchName.c_str());
504  if (not branch) {
505  B2ERROR("TBranch for " + variableType + " named '" << branchName.c_str() << "' does not exist!");
506  }
507  branch->SetAddress(&object);
508  for (int i = 0; i < nentries; ++i) {
509  auto entry = m_tree->LoadTree(i);
510  if (entry < 0) {
511  B2ERROR("Error during loading root tree from chain, error code: " << entry);
512  }
513  // if current tree changed we have to update the branch
514  if (currentTreeNumber != m_tree->GetTreeNumber()) {
515  currentTreeNumber = m_tree->GetTreeNumber();
516  branch = m_tree->GetBranch(branchName.c_str());
517  branch->SetAddress(&object);
518  }
519  branch->GetEntry(entry);
520  values[i] = object;
521  }
522  // Reset branch to correct input address, just to be sure
523  m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
524  return values;
525  }
526 
527  bool ROOTDataset::checkForBranch(TTree* tree, const std::string& branchname) const
528  {
529  auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
530  return branch != nullptr;
531 
532  }
533 
534  template<class T>
535  void ROOTDataset::setScalarVariableAddress(std::string& variableType, std::string& variableName,
536  T& variableTarget)
537  {
538  if (not variableName.empty()) {
539  if (checkForBranch(m_tree, variableName)) {
540  m_tree->SetBranchStatus(variableName.c_str(), true);
541  m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
542  } else {
544  m_tree->SetBranchStatus(Belle2::MakeROOTCompatible::makeROOTCompatible(variableName).c_str(), true);
545  m_tree->SetBranchAddress(Belle2::MakeROOTCompatible::makeROOTCompatible(variableName).c_str(), &variableTarget);
546  } else {
547  B2ERROR("Couldn't find given " << variableType << " variable named " << variableName <<
548  " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
549  throw std::runtime_error("Couldn't find given " + variableType + " variable named " + variableName +
550  " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
551  }
552  }
553  }
554  }
555 
556  template<class T>
557  void ROOTDataset::setVectorVariableAddress(std::string& variableType, std::vector<std::string>& variableNames,
558  T& variableTargets)
559  {
560  for (unsigned int i = 0; i < variableNames.size(); ++i)
561  ROOTDataset::setScalarVariableAddress(variableType, variableNames[i], variableTargets[i]);
562  }
563 
564  void ROOTDataset::setVectorVariableAddress(std::string& variableType, std::vector<std::string>& variableNames,
565  std::vector<Variable::Manager::VarVariant>& varVariantTargets)
566  {
567  for (unsigned int i = 0; i < variableNames.size(); ++i) {
568  if (std::holds_alternative<double>(varVariantTargets[i]))
569  ROOTDataset::setScalarVariableAddress(variableType, variableNames[i], std::get<double>(varVariantTargets[i]));
570  else if (std::holds_alternative<int>(varVariantTargets[i]))
571  ROOTDataset::setScalarVariableAddress(variableType, variableNames[i], std::get<int>(varVariantTargets[i]));
572  else if (std::holds_alternative<bool>(varVariantTargets[i]))
573  ROOTDataset::setScalarVariableAddress(variableType, variableNames[i], std::get<bool>(varVariantTargets[i]));
574  }
575  }
576 
578  {
579  // Deactivate all branches by default
580  m_tree->SetBranchStatus("*", false);
581  std::string typeName;
582 
583  if (m_general_options.m_weight_variable.empty()) {
584  m_weight = 1;
585  m_weight_double = 1;
586  B2INFO("No weight variable provided. The weight will be set to 1.");
587  }
588 
589  if (m_general_options.m_weight_variable == "__weight__") {
590  if (checkForBranch(m_tree, "__weight__")) {
591  m_tree->SetBranchStatus("__weight__", true);
592  if (m_isFloatInputType)
593  m_tree->SetBranchAddress("__weight__", &m_weight);
594  else
595  m_tree->SetBranchAddress("__weight__", &m_weight_double);
596  } else {
597  B2INFO("Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
598  "weight variable to an empty string if you don't need it.");
599  m_weight = 1;
600  m_weight_double = 1;
601  }
602  } else if (m_isFloatInputType) {
603  typeName = "weight";
605  } else {
606  typeName = "weight";
608  }
609 
610  if (m_target_data_type == Variable::Manager::VariableDataType::c_double) {
611  typeName = "target";
613  } else if (m_target_data_type == Variable::Manager::VariableDataType::c_int) {
614  typeName = "target";
616  } else if (m_target_data_type == Variable::Manager::VariableDataType::c_bool) {
617  typeName = "target";
619  }
620 
621  if (m_isFloatInputType) {
622  typeName = "feature";
624  typeName = "spectator";
626  } else {
627  typeName = "feature";
629  typeName = "spectator";
631  }
632  }
633 
634 
636  {
637  for (unsigned int i = 0; i < m_general_options.m_variables.size(); i++) {
638  auto variable = m_general_options.m_variables[i];
639  std::string branchName = Belle2::MakeROOTCompatible::makeROOTCompatible(variable);
640 
641  if (checkForBranch(m_tree, branchName)) {
642  TBranch* branch = m_tree->GetBranch(branchName.c_str());
643  TLeaf* leaf = branch->GetLeaf(branchName.c_str());
644  std::string type_name = leaf->GetTypeName();
645 
646  // m_isFloatInputType is decided from the first input variable.
647  if (i == 0) {
648  if (type_name == "Float_t")
649  m_isFloatInputType = true;
650  else
651  m_isFloatInputType = false;
652  }
653 
654  if (type_name == "Float_t") {
655  if (m_isFloatInputType)
656  continue;
657  else
658  B2ERROR("There is a mix of float and basf2 variable types (double, int, bool)");
659  } else if (type_name == "Double_t" or type_name == "Int_t" or type_name == "Bool_t") {
660  if (m_isFloatInputType)
661  B2ERROR("There is a mix of float and basf2 variable types (double, int, bool)");
662  else {
663  if (type_name == "Double_t")
664  m_input_variant[i] = 0.0;
665  else if (type_name == "Int_t")
666  m_input_variant[i] = 0;
667  else if (type_name == "Bool_t")
668  m_input_variant[i] = false;
669  }
670  } else {
671  B2FATAL("Unknown root input type: " << type_name);
672  throw std::runtime_error("Unknown root input type: " + type_name);
673  }
674  }
675  }
676 
677  for (unsigned int i = 0; i < m_general_options.m_spectators.size(); i++) {
678  auto variable = m_general_options.m_spectators[i];
679  std::string branchName = Belle2::MakeROOTCompatible::makeROOTCompatible(variable);
680 
681  if (checkForBranch(m_tree, branchName)) {
682  TBranch* branch = m_tree->GetBranch(branchName.c_str());
683  TLeaf* leaf = branch->GetLeaf(branchName.c_str());
684  std::string type_name = leaf->GetTypeName();
685  if (type_name == "Float_t") {
686  if (m_isFloatInputType)
687  continue;
688  else
689  B2ERROR("There is a mix of float and basf2 variable types (double, int, bool)");
690  } else if (type_name == "Double_t" or type_name == "Int_t" or type_name == "Bool_t") {
691  if (m_isFloatInputType)
692  B2ERROR("There is a mix of float and basf2 variable types (double, int, bool)");
693  else {
694  if (type_name == "Double_t")
695  m_spectators_variant[i] = 0.0;
696  else if (type_name == "Int_t")
697  m_spectators_variant[i] = 0;
698  else if (type_name == "Bool_t")
699  m_spectators_variant[i] = false;
700  }
701  } else {
702  B2FATAL("Unknown root input type: " << type_name);
703  throw std::runtime_error("Unknown root input type: " + type_name);
704  }
705  }
706  }
707 
708  }
709 
711  {
713 
714  TBranch* branch = m_tree->GetBranch(branchName.c_str());
715  TLeaf* leaf = branch->GetLeaf(branchName.c_str());
716  std::string target_type_name = leaf->GetTypeName();
717  if (target_type_name == "Double_t")
718  m_target_data_type = Variable::Manager::VariableDataType::c_double;
719  else if (target_type_name == "Int_t")
720  m_target_data_type = Variable::Manager::VariableDataType::c_int;
721  else if (target_type_name == "Bool_t")
722  m_target_data_type = Variable::Manager::VariableDataType::c_bool;
723  else {
724  B2FATAL("Input type " << target_type_name << " for target variable is not supported");
725  throw std::runtime_error("Unsupported target input type: " + target_type_name);
726  }
727  }
728  }
730 }
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Definition: Dataset.cc:273
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
Definition: Dataset.h:340
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:306
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:296
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Definition: Dataset.cc:277
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
Definition: Dataset.h:339
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual std::vector< bool > getSignals()
Returns all is Signals.
Definition: Dataset.cc:122
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
Definition: Dataset.cc:50
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
Definition: Dataset.cc:86
std::vector< float > m_spectators
Contains all spectator values of the currently loaded event.
Definition: Dataset.h:124
virtual std::vector< float > getTargets()
Returns all targets.
Definition: Dataset.cc:110
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
Definition: Dataset.h:122
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Definition: Dataset.h:123
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
Definition: Dataset.cc:26
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual std::vector< float > getWeights()
Returns all weights.
Definition: Dataset.cc:98
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
bool m_isSignal
Defines if the currently loaded event is signal or background.
Definition: Dataset.h:127
float m_weight
Contains the weight of the currently loaded event.
Definition: Dataset.h:125
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
Definition: Dataset.cc:62
float m_target
Contains the target value of the currently loaded event.
Definition: Dataset.h:126
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
Definition: Options.h:84
int m_signal_class
Signal class which is used as signal in a classification problem.
Definition: Options.h:88
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:91
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_treename
Name of the TTree inside the datafile containing the training data.
Definition: Options.h:85
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:90
std::vector< float > m_weights
weight vector
Definition: Dataset.h:226
std::vector< std::vector< float > > m_matrix
Feature matrix.
Definition: Dataset.h:223
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
Definition: Dataset.h:224
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float >> &input, const std::vector< std::vector< float >> &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
Definition: Dataset.cc:145
std::vector< float > m_targets
target vector
Definition: Dataset.h:225
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the stored feature matrix and, if given, the spectator matrix, targets and weights.
Definition: Dataset.cc:168
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
Definition: Dataset.cc:577
void setVectorVariableAddress(std::string &variableType, std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
Definition: Dataset.cc:557
void setTargetRootInputType()
Determines the data type of the target variable and sets it to m_target_data_type.
Definition: Dataset.cc:710
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
Definition: Dataset.h:371
Variable::Manager::VariableDataType m_target_data_type
Data type of target variable.
Definition: Dataset.h:478
TChain * m_tree
Pointer to the TChain containing the data.
Definition: Dataset.h:472
double m_target_double
Contains the target value of the currently loaded event.
Definition: Dataset.h:480
virtual void loadEvent(unsigned int event) override
Load the event number iEvent from the TTree.
Definition: Dataset.cc:378
int m_target_int
Contains the target value of the currently loaded event.
Definition: Dataset.h:481
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
Definition: Dataset.cc:465
double m_weight_double
Contains the weight of the currently loaded event.
Definition: Dataset.h:477
bool m_isFloatInputType
Defines the expected datatype in the ROOT file.
Definition: Dataset.h:473
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:443
void setScalarVariableAddress(std::string &variableType, std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
Definition: Dataset.cc:535
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
Definition: Dataset.cc:415
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
Definition: Dataset.cc:316
void setRootInputType()
Tries to infer the data type of the input variables in the ROOT file and sets m_isFloatInputType.
Definition: Dataset.cc:635
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
Definition: Dataset.h:366
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
Definition: Dataset.cc:527
std::vector< float > getVectorFromTTree(std::string &variableType, std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
Definition: Dataset.cc:494
virtual ~ROOTDataset()
Virtual destructor.
Definition: Dataset.cc:487
std::vector< Variable::Manager::VarVariant > m_spectators_variant
Contains all spectator values of the currently loaded event.
Definition: Dataset.h:476
bool m_target_bool
Contains the target value of the currently loaded event.
Definition: Dataset.h:482
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
Definition: Dataset.h:361
std::vector< Variable::Manager::VarVariant > m_input_variant
Contains all feature values of the currently loaded event.
Definition: Dataset.h:474
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Definition: Dataset.cc:135
Dataset & m_dataset
Reference to the wrapped dataset.
Definition: Dataset.h:286
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
Definition: Dataset.cc:186
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
Definition: Dataset.h:258
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:259
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:281
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:245
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:283
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Definition: Dataset.cc:225
std::vector< unsigned int > m_event_indices
Mapping from the position of an event in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:285
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
Definition: Dataset.h:279
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:23