Belle II Software development
Dataset.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/interface/Dataset.h>
10
11#include <framework/utilities/MakeROOTCompatible.h>
12#include <framework/logging/Logger.h>
13#include <framework/io/RootIOUtilities.h>
14
15#include <TLeaf.h>
16
17#include <filesystem>
18
19namespace Belle2 {
24 namespace MVA {
25
// Base-class constructor: stores a copy of the general options, sizes the per-event
// feature buffer m_input to the number of configured variables (filled with 0) and sets
// the per-event defaults target = 0, weight = 1, not signal.
// NOTE(review): original line 29 is absent from this listing; presumably it sizes
// m_spectators to the number of configured spectators — confirm.
26 Dataset::Dataset(const GeneralOptions& general_options) : m_general_options(general_options)
27 {
28 m_input.resize(m_general_options.m_variables.size(), 0);
30 m_target = 0.0;
31 m_weight = 1.0;
32 m_isSignal = false;
33 }
34
36 {
// Weighted signal fraction: the sum of the weights of all signal events divided by the sum
// of the weights of all events. Every event is loaded in turn, so the last event of the
// dataset is the currently loaded one afterwards.
// NOTE(review): for an empty dataset both sums stay 0 and the division yields NaN.
37
38 double signal_weight_sum = 0;
39 double weight_sum = 0;
40 for (unsigned int i = 0; i < getNumberOfEvents(); ++i) {
41 loadEvent(i);
42 weight_sum += m_weight;
43 if (m_isSignal)
44 signal_weight_sum += m_weight;
45 }
46 return signal_weight_sum / weight_sum;
47
48 }
49
50 unsigned int Dataset::getFeatureIndex(const std::string& feature)
51 {
52
53 auto it = std::find(m_general_options.m_variables.begin(), m_general_options.m_variables.end(), feature);
54 if (it == m_general_options.m_variables.end()) {
55 B2ERROR("Unknown feature named " << feature);
56 return 0;
57 }
58 return std::distance(m_general_options.m_variables.begin(), it);
59
60 }
61
62 unsigned int Dataset::getSpectatorIndex(const std::string& spectator)
63 {
64
65 auto it = std::find(m_general_options.m_spectators.begin(), m_general_options.m_spectators.end(), spectator);
66 if (it == m_general_options.m_spectators.end()) {
67 B2ERROR("Unknown spectator named " << spectator);
68 return 0;
69 }
70 return std::distance(m_general_options.m_spectators.begin(), it);
71
72 }
73
74 std::vector<float> Dataset::getFeature(unsigned int iFeature)
75 {
76
77 std::vector<float> result(getNumberOfEvents());
78 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
79 loadEvent(iEvent);
80 result[iEvent] = m_input[iFeature];
81 }
82 return result;
83
84 }
85
86 std::vector<float> Dataset::getSpectator(unsigned int iSpectator)
87 {
88
89 std::vector<float> result(getNumberOfEvents());
90 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
91 loadEvent(iEvent);
92 result[iEvent] = m_spectators[iSpectator];
93 }
94 return result;
95
96 }
97
98 std::vector<float> Dataset::getWeights()
99 {
100
101 std::vector<float> result(getNumberOfEvents());
102 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
103 loadEvent(iEvent);
104 result[iEvent] = m_weight;
105 }
106 return result;
107
108 }
109
110 std::vector<float> Dataset::getTargets()
111 {
112
113 std::vector<float> result(getNumberOfEvents());
114 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
115 loadEvent(iEvent);
116 result[iEvent] = m_target;
117 }
118 return result;
119
120 }
121
122 std::vector<bool> Dataset::getSignals()
123 {
124
125 std::vector<bool> result(getNumberOfEvents());
126 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
127 loadEvent(iEvent);
128 result[iEvent] = m_isSignal;
129 }
130 return result;
131
132 }
133
134
135 SingleDataset::SingleDataset(const GeneralOptions& general_options, const std::vector<float>& input, float target,
136 const std::vector<float>& spectators) : Dataset(general_options)
137 {
138 m_input = input;
139 m_spectators = spectators;
140 m_target = target;
141 m_weight = 1.0;
142 m_isSignal = std::lround(target) == m_general_options.m_signal_class;
143 }
144
145 MultiDataset::MultiDataset(const GeneralOptions& general_options, const std::vector<std::vector<float>>& input,
146 const std::vector<std::vector<float>>& spectators,
147 const std::vector<float>& targets, const std::vector<float>& weights) : Dataset(general_options), m_matrix(input),
148 m_spectator_matrix(spectators),
149 m_targets(targets), m_weights(weights)
150 {
151
152 if (m_targets.size() > 0 and m_matrix.size() != m_targets.size()) {
153 B2ERROR("Feature matrix and target vector need same number of elements in MultiDataset, got " << m_targets.size() << " and " <<
154 m_matrix.size());
155 }
156 if (m_weights.size() > 0 and m_matrix.size() != m_weights.size()) {
157 B2ERROR("Feature matrix and weight vector need same number of elements in MultiDataset, got " << m_weights.size() << " and " <<
158 m_matrix.size());
159 }
160 if (m_spectator_matrix.size() > 0 and m_matrix.size() != m_spectator_matrix.size()) {
161 B2ERROR("Feature matrix and spectator matrix need same number of elements in MultiDataset, got " << m_spectator_matrix.size() <<
162 " and " <<
163 m_matrix.size());
164 }
165 }
166
167
// Makes event iEvent the current event: copies row iEvent of the feature matrix into m_input
// and, for every optional per-event container that was provided (non-empty), its entry.
// When no weights were given, m_weight keeps its previous value.
// NOTE(review): iEvent is not bounds-checked. The body of the spectator branch (original
// line 173) and original line 178 are absent from this listing; presumably they copy
// m_spectator_matrix[iEvent] and derive m_isSignal from the target — confirm.
168 void MultiDataset::loadEvent(unsigned int iEvent)
169 {
170 m_input = m_matrix[iEvent];
171
172 if (m_spectator_matrix.size() > 0) {
174 }
175
176 if (m_targets.size() > 0) {
177 m_target = m_targets[iEvent];
179 }
180
181 if (m_weights.size() > 0)
182 m_weight = m_weights[iEvent];
183
184 }
185
// Wraps `dataset` and exposes only the variables and spectators listed in general_options.
// If `events` is non-empty, only the events whose flag is true are exposed.
// Throws std::runtime_error if a requested variable or spectator is not part of the
// wrapped dataset's GeneralOptions.
// NOTE(review): original lines 196, 200, 205 and 211 are absent from this listing
// (presumably the feature/spectator index bookkeeping and an `if (m_use_event_indices)`
// guard — confirm).
186 SubDataset::SubDataset(const GeneralOptions& general_options, const std::vector<bool>& events,
187 Dataset& dataset) : Dataset(general_options), m_dataset(dataset)
188 {
189
190 for (auto& v : m_general_options.m_variables) {
191 auto it = std::find(m_dataset.m_general_options.m_variables.begin(), m_dataset.m_general_options.m_variables.end(), v);
192 if (it == m_dataset.m_general_options.m_variables.end()) {
193 B2ERROR("Couldn't find variable " << v << " in GeneralOptions");
194 throw std::runtime_error("Couldn't find variable " + v + " in GeneralOptions");
195 }
197 }
198
199 for (auto& v : m_general_options.m_spectators) {
201 if (it == m_dataset.m_general_options.m_spectators.end()) {
202 B2ERROR("Couldn't find spectator " << v << " in GeneralOptions");
203 throw std::runtime_error("Couldn't find spectator " + v + " in GeneralOptions");
204 }
206 }
207
208 if (events.size() > 0)
209 m_use_event_indices = true;
210
// Build the subset -> wrapped-dataset event mapping: over-allocate to the full size, fill
// with the indices of selected events, then shrink to the number actually selected.
// NOTE(review): events[iEvent] is read for every event of the wrapped dataset; a non-empty
// `events` shorter than dataset.getNumberOfEvents() is read out of range.
212 m_event_indices.resize(dataset.getNumberOfEvents());
213 unsigned int n_events = 0;
214 for (unsigned int iEvent = 0; iEvent < dataset.getNumberOfEvents(); ++iEvent) {
215 if (events.size() == 0 or events[iEvent]) {
216 m_event_indices[n_events] = iEvent;
217 n_events++;
218 }
219 }
220 m_event_indices.resize(n_events);
221 }
222
223 }
224
// Loads event iEvent of this subset. The index is mapped through m_event_indices, that event
// is loaded in the wrapped dataset, and the selected features and spectators are copied
// across via m_feature_indices / m_spectator_indices.
// NOTE(review): the guard before the remap (original line 228) and original lines 231-233
// are absent from this listing; presumably they copy target, weight and signal flag — confirm.
225 void SubDataset::loadEvent(unsigned int iEvent)
226 {
227 unsigned int index = iEvent;
229 index = m_event_indices[iEvent];
230 m_dataset.loadEvent(index);
234
235 for (unsigned int iFeature = 0; iFeature < m_input.size(); ++iFeature) {
236 m_input[iFeature] = m_dataset.m_input[m_feature_indices[iFeature]];
237 }
238
239 for (unsigned int iSpectator = 0; iSpectator < m_spectators.size(); ++iSpectator) {
240 m_spectators[iSpectator] = m_dataset.m_spectators[m_spectator_indices[iSpectator]];
241 }
242
243 }
244
245 std::vector<float> SubDataset::getFeature(unsigned int iFeature)
246 {
247
248 auto v = m_dataset.getFeature(m_feature_indices[iFeature]);
249 if (not m_use_event_indices)
250 return v;
251 std::vector<float> result(m_event_indices.size());
252 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
253 result[iEvent] = v[m_event_indices[iEvent]];
254 }
255 return result;
256
257 }
258
259 std::vector<float> SubDataset::getSpectator(unsigned int iSpectator)
260 {
261
262 auto v = m_dataset.getSpectator(m_spectator_indices[iSpectator]);
263 if (not m_use_event_indices)
264 return v;
265 std::vector<float> result(m_event_indices.size());
266 for (unsigned int iEvent = 0; iEvent < getNumberOfEvents(); ++iEvent) {
267 result[iEvent] = v[m_event_indices[iEvent]];
268 }
269 return result;
270
271 }
272
273 CombinedDataset::CombinedDataset(const GeneralOptions& general_options, Dataset& signal_dataset,
274 Dataset& background_dataset) : Dataset(general_options), m_signal_dataset(signal_dataset),
275 m_background_dataset(background_dataset) { }
276
// Indices below the signal dataset's event count address the signal dataset; all other
// indices address the background dataset. Target and signal flag are set according to the
// part the event comes from (1 / true for signal, 0 / false for background).
// NOTE(review): the lines that load the event in the wrapped dataset and copy its values
// (original 280, 283-285, 287, 290-292) are absent from this listing; presumably the
// background index is offset by the signal event count — confirm.
277 void CombinedDataset::loadEvent(unsigned int iEvent)
278 {
279 if (iEvent < m_signal_dataset.getNumberOfEvents()) {
281 m_target = 1.0;
282 m_isSignal = true;
286 } else {
288 m_target = 0.0;
289 m_isSignal = false;
293 }
294 }
295
296 std::vector<float> CombinedDataset::getFeature(unsigned int iFeature)
297 {
298
299 auto s = m_signal_dataset.getFeature(iFeature);
300 auto b = m_background_dataset.getFeature(iFeature);
301 s.insert(s.end(), b.begin(), b.end());
302 return s;
303
304 }
305
306 std::vector<float> CombinedDataset::getSpectator(unsigned int iSpectator)
307 {
308
309 auto s = m_signal_dataset.getSpectator(iSpectator);
310 auto b = m_background_dataset.getSpectator(iSpectator);
311 s.insert(s.end(), b.begin(), b.end());
312 return s;
313
314 }
315
// Opens the configured ROOT files as a TChain after validating the options.
// Throws std::runtime_error in these cases:
//  - a name is shared between the variables, spectators and target;
//  - no input file is found;
//  - a file cannot be opened;
//  - the tree cannot be retrieved from a file.
// NOTE(review): original lines 318-319, 335, 361, 363 and 366-367 are absent from this listing.
316 ROOTDataset::ROOTDataset(const GeneralOptions& general_options) : Dataset(general_options)
317 {
320 m_weight_variant = 1.0f;
321 m_target_variant = 0.0f;
322
// NOTE(review): this check only runs when both m_variables and m_spectators are non-empty,
// and it does not detect a name repeated within a single list — confirm this is intended.
323 for (const auto& variable : general_options.m_variables)
324 for (const auto& spectator : general_options.m_spectators)
325 if (variable == spectator or variable == general_options.m_target_variable or spectator == general_options.m_target_variable) {
326 B2ERROR("Interface doesn't support variable more then one time in either spectators, variables or target variable");
327 throw std::runtime_error("Interface doesn't support variable more then one time in either spectators, variables or target variable");
328 }
329
// Paths that exist are used directly; any other entry is expanded (the expansion call on
// line 335 is not shown here — presumably RootIOUtilities::expandWordExpansions, confirm).
330 std::vector<std::string> filenames;
331 for (const auto& filename : m_general_options.m_datafiles) {
332 if (std::filesystem::exists(filename)) {
333 filenames.push_back(filename);
334 } else {
336 filenames.insert(filenames.end(), temp.begin(), temp.end());
337 }
338 }
339 if (filenames.empty()) {
340 B2ERROR("Found no valid filenames in GeneralOptions");
341 throw std::runtime_error("Found no valid filenames in GeneralOptions");
342 }
343
// Probe every file once so that unreadable files fail early; gDirectory is saved before and
// restored after the probing.
// NOTE(review): a non-null zombie TFile is not deleted before the throw, and gDirectory is
// not restored on that path.
344 //Open TFile
345 TDirectory* dir = gDirectory;
346 for (const auto& filename : filenames) {
347 TFile* f = TFile::Open(filename.c_str(), "READ");
348 if (!f or f->IsZombie() or not f->IsOpen()) {
349 B2ERROR("Error during open of ROOT file named " << filename);
350 throw std::runtime_error("Error during open of ROOT file named " + filename);
351 }
352 delete f;
353 }
354 dir->cd();
355
// The chain is owned through the raw pointer m_tree and deleted in the destructor.
// NOTE(review): if AddFile fails below, the exception leaves the constructor, the destructor
// does not run, and the TChain is leaked.
356 m_tree = new TChain(m_general_options.m_treename.c_str());
357 for (const auto& filename : filenames) {
358 //nentries = -1 forces AddFile() to read headers
359 if (!m_tree->AddFile(filename.c_str(), -1)) {
360 B2ERROR("Error during open of ROOT file named " << filename << " cannot retrieve tree named " <<
362 throw std::runtime_error("Error during open of ROOT file named " + filename + " cannot retrieve tree named " +
364 }
365 }
368 }
369
370
372 {
// Converts whichever alternative the variant currently holds (double, float, int or bool)
// to float; any other state is reported with B2FATAL.
// NOTE(review): there is no return statement after B2FATAL; this relies on B2FATAL not
// returning — confirm.
373 if (std::holds_alternative<double>(variant))
374 return static_cast<float>(std::get<double>(variant));
375 else if (std::holds_alternative<float>(variant))
376 return std::get<float>(variant);
377 else if (std::holds_alternative<int>(variant))
378 return static_cast<float>(std::get<int>(variant));
379 else if (std::holds_alternative<bool>(variant))
380 return static_cast<float>(std::get<bool>(variant));
381 else {
382 B2FATAL("Unsupported variable type");
383 }
384 }
385
// Reads entry `event` of the chain into the bound branch buffers (GetEntry with getall = 0).
// A return value of 0 is only logged, not treated as fatal.
// NOTE(review): the loop bodies (original lines 393, 396) and lines 398-400 are absent from
// this listing; presumably they convert the variant buffers with castVarVariantToFloat into
// m_input, m_spectators, m_target and m_weight — confirm.
386 void ROOTDataset::loadEvent(unsigned int event)
387 {
388 if (m_tree->GetEntry(event, 0) == 0) {
389 B2ERROR("Error during loading entry from chain");
390 }
391
392 for (unsigned int i = 0; i < m_input_variant.size(); i++) {
394 }
395 for (unsigned int i = 0; i < m_spectators_variant.size(); i++) {
397 }
401 }
402
// Returns the weight of every event, read directly from the chain.
// Falls back to all-ones when no weight branch name is configured, or when the default
// "__weight__" branch is not present in the tree.
// NOTE(review): the declaration of branchName (original line 405) is absent from this
// listing; presumably it is m_general_options.m_weight_variable — confirm.
403 std::vector<float> ROOTDataset::getWeights()
404 {
406 if (branchName.empty()) {
407 B2INFO("No TBranch name given for weights. Using 1s as default weights.");
408 int nentries = getNumberOfEvents();
409 std::vector<float> values(nentries, 1.);
410 return values;
411 }
412 if (branchName == "__weight__") {
413 if (!checkForBranch(m_tree, "__weight__")) {
414 B2INFO("No default weight branch with name __weight__ found. Using 1s as default weights.");
415 int nentries = getNumberOfEvents();
416 std::vector<float> values(nentries, 1.);
417 return values;
418 }
419 }
420 std::string typeLabel = "weights";
421 return getVectorFromTTreeVariant(typeLabel, branchName, m_weight_variant);
422 }
423
// Returns the values of feature iFeature for every event, read directly from the chain
// (bypassing loadEvent).
// NOTE(review): the out-of-range branch only logs; it neither returns nor throws, so
// m_input_variant[iFeature] is still indexed out of bounds afterwards. The declaration of
// branchName (original line 430) is absent from this listing.
424 std::vector<float> ROOTDataset::getFeature(unsigned int iFeature)
425 {
426 if (iFeature >= getNumberOfFeatures()) {
427 B2ERROR("Feature index " << iFeature << " is out of bounds of given number of features: "
429 }
431 std::string typeLabel = "features";
432 return getVectorFromTTreeVariant(typeLabel, branchName, m_input_variant[iFeature]);
433 }
434
// Returns the values of spectator iSpectator for every event, read directly from the chain
// (bypassing loadEvent).
// NOTE(review): the out-of-range branch only logs; it neither returns nor throws, so
// m_spectators_variant[iSpectator] is still indexed out of bounds afterwards. The declaration
// of branchName (original line 442) is absent from this listing.
435 std::vector<float> ROOTDataset::getSpectator(unsigned int iSpectator)
436 {
437 if (iSpectator >= getNumberOfSpectators()) {
438 B2ERROR("Spectator index " << iSpectator << " is out of bounds of given number of spectators: "
440 }
441
443 std::string typeLabel = "spectators";
444 return getVectorFromTTreeVariant(typeLabel, branchName, m_spectators_variant[iSpectator]);
445 }
446
448 {
// Releases the TChain allocated in the constructor.
// NOTE(review): m_tree is a raw owning pointer; if ROOTDataset can be copied (not visible
// here), two copies would delete the same chain — confirm that copying is disabled.
449 delete m_tree;
450 m_tree = nullptr;
451 }
452
453 std::vector<float> ROOTDataset::getVectorFromTTreeVariant(const std::string& variableType, const std::string& branchName,
454 RootDatasetVarVariant& memberVariableTarget)
455 {
456 if (std::holds_alternative<double>(memberVariableTarget))
457 return getVectorFromTTree(variableType, branchName, std::get<double>(memberVariableTarget));
458 else if (std::holds_alternative<float>(memberVariableTarget))
459 return getVectorFromTTree(variableType, branchName, std::get<float>(memberVariableTarget));
460 else if (std::holds_alternative<int>(memberVariableTarget))
461 return getVectorFromTTree(variableType, branchName, std::get<int>(memberVariableTarget));
462 else if (std::holds_alternative<bool>(memberVariableTarget))
463 return getVectorFromTTree(variableType, branchName, std::get<bool>(memberVariableTarget));
464 else
465 B2FATAL("Input type of " << variableType << " variable " << branchName << " is not supported");
466 }
467
468 template<class T>
469 std::vector<float> ROOTDataset::getVectorFromTTree(const std::string& variableType, const std::string& branchName,
470 T& memberVariableTarget)
471 {
472 int nentries = getNumberOfEvents();
473 std::vector<float> values(nentries);
474
475 // Float or Double to be filled
476 T object;
477 auto currentTreeNumber = m_tree->GetTreeNumber();
478 TBranch* branch = m_tree->GetBranch(branchName.c_str());
479 if (not branch) {
480 B2ERROR("TBranch for " + variableType + " named '" << branchName.c_str() << "' does not exist!");
481 }
482 branch->SetAddress(&object);
483 for (int i = 0; i < nentries; ++i) {
484 auto entry = m_tree->LoadTree(i);
485 if (entry < 0) {
486 B2ERROR("Error during loading root tree from chain, error code: " << entry);
487 }
488 // if current tree changed we have to update the branch
489 if (currentTreeNumber != m_tree->GetTreeNumber()) {
490 currentTreeNumber = m_tree->GetTreeNumber();
491 branch = m_tree->GetBranch(branchName.c_str());
492 branch->SetAddress(&object);
493 }
494 branch->GetEntry(entry);
495 values[i] = object;
496 }
497 // Reset branch to correct input address, just to be sure
498 m_tree->SetBranchAddress(branchName.c_str(), &memberVariableTarget);
499 return values;
500 }
501
502 bool ROOTDataset::checkForBranch(TTree* tree, const std::string& branchname) const
503 {
504 auto branch = tree->GetListOfBranches()->FindObject(branchname.c_str());
505 return branch != nullptr;
506
507 }
508
// Enables the branch named variableName and binds it to variableTarget.
// If no branch with exactly that name exists, the MakeROOTCompatible spelling is used
// instead; if that is not found either, a std::runtime_error is thrown.
// An empty variableName is silently ignored.
// NOTE(review): the `if` that checks the compatible name (original line 518) is absent from
// this listing; presumably it calls checkForBranch with the compatible name — confirm.
509 template<class T>
510 void ROOTDataset::setScalarVariableAddress(const std::string& variableType, const std::string& variableName,
511 T& variableTarget)
512 {
513 if (not variableName.empty()) {
514 if (checkForBranch(m_tree, variableName)) {
515 m_tree->SetBranchStatus(variableName.c_str(), true);
516 m_tree->SetBranchAddress(variableName.c_str(), &variableTarget);
517 } else {
519 m_tree->SetBranchStatus(Belle2::MakeROOTCompatible::makeROOTCompatible(variableName).c_str(), true);
520 m_tree->SetBranchAddress(Belle2::MakeROOTCompatible::makeROOTCompatible(variableName).c_str(), &variableTarget);
521 } else {
522 B2ERROR("Couldn't find given " << variableType << " variable named " << variableName <<
523 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
524 throw std::runtime_error("Couldn't find given " + variableType + " variable named " + variableName +
525 " (I tried also using MakeROOTCompatible::makeROOTCompatible)");
526 }
527 }
528 }
529 }
530
531 void ROOTDataset::setScalarVariableAddressVariant(const std::string& variableType, const std::string& variableName,
532 RootDatasetVarVariant& varVariantTarget)
533 {
534 if (std::holds_alternative<double>(varVariantTarget))
535 setScalarVariableAddress(variableType, variableName, std::get<double>(varVariantTarget));
536 else if (std::holds_alternative<float>(varVariantTarget))
537 setScalarVariableAddress(variableType, variableName, std::get<float>(varVariantTarget));
538 else if (std::holds_alternative<int>(varVariantTarget))
539 setScalarVariableAddress(variableType, variableName, std::get<int>(varVariantTarget));
540 else if (std::holds_alternative<bool>(varVariantTarget))
541 setScalarVariableAddress(variableType, variableName, std::get<bool>(varVariantTarget));
542 else
543 B2FATAL("Variable type for branch " << variableName << " not supported!");
544 }
545
546 template<class T>
547 void ROOTDataset::setVectorVariableAddress(const std::string& variableType, const std::vector<std::string>& variableNames,
548 T& variableTargets)
549 {
550 for (unsigned int i = 0; i < variableNames.size(); ++i)
551 ROOTDataset::setScalarVariableAddress(variableType, variableNames[i], variableTargets[i]);
552 }
553
554
555 void ROOTDataset::setVectorVariableAddressVariant(const std::string& variableType, const std::vector<std::string>& variableNames,
556 std::vector<RootDatasetVarVariant>& varVariantTargets)
557 {
558 for (unsigned int i = 0; i < variableNames.size(); ++i) {
559 ROOTDataset::setScalarVariableAddressVariant(variableType, variableNames[i], varVariantTargets[i]);
560 }
561 }
562
564 {
// Re-binds all branches of the chain. Every branch is disabled first; then the weight branch
// (and, in lines not shown in this listing, presumably the target, features and spectators)
// is enabled and bound.
// NOTE(review): original lines 568, 580, 585, 587 and 589 are absent from this listing.
565 // Deactivate all branches by default
566 m_tree->SetBranchStatus("*", false);
567
// The default "__weight__" branch is optional: if it is missing, the weight is fixed to 1.
569 if (m_general_options.m_weight_variable == "__weight__") {
570 if (checkForBranch(m_tree, "__weight__")) {
571 m_tree->SetBranchStatus("__weight__", true);
572 std::string typeLabel_weight = "weight";
573 std::string weight_string = "__weight__";
574 setScalarVariableAddressVariant(typeLabel_weight, weight_string, m_weight_variant);
575 } else {
576 m_weight_variant = 1.0f;
577 }
578 } else {
579 std::string typeLabel_weight = "weight";
581 }
582 }
583
584 std::string typeLabel_target = "target";
586 std::string typeLabel_feature = "feature";
588 std::string typeLabel_spectator = "spectator";
590 }
591
592
593 void ROOTDataset::initialiseVarVariantType(const std::string type, RootDatasetVarVariant& varVariantTarget)
594 {
595 if (type == "Double_t")
596 varVariantTarget = 0.0;
597 else if (type == "Float_t")
598 varVariantTarget = 0.0f;
599 else if (type == "Int_t")
600 varVariantTarget = 0;
601 else if (type == "Bool_t")
602 varVariantTarget = false;
603 else {
604 B2FATAL("Unknown root input type: " << type);
605 throw std::runtime_error("Unknown root input type: " + type);
606 }
607 }
608
609
610 void ROOTDataset::initialiseVarVariantForBranch(const std::string branch_name, RootDatasetVarVariant& varVariantTarget)
611 {
612 std::string compatible_branch_name = Belle2::MakeROOTCompatible::makeROOTCompatible(branch_name);
613 // try the branch as is first then fall back to root safe name.
614 if (checkForBranch(m_tree, branch_name.c_str())) {
615 TBranch* branch = m_tree->GetBranch(branch_name.c_str());
616 TLeaf* leaf = branch->GetLeaf(branch_name.c_str());
617 std::string type_name = leaf->GetTypeName();
618 initialiseVarVariantType(type_name, varVariantTarget);
619 } else if (checkForBranch(m_tree, compatible_branch_name)) {
620 TBranch* branch = m_tree->GetBranch(compatible_branch_name.c_str());
621 TLeaf* leaf = branch->GetLeaf(compatible_branch_name.c_str());
622 std::string type_name = leaf->GetTypeName();
623 initialiseVarVariantType(type_name, varVariantTarget);
624 }
625 }
626
628 {
// Infers the storage type of the target, feature, spectator and weight branches and
// initialises the corresponding variants.
// NOTE(review): the type-inferring calls (original lines 630, 635, 641, 652, 659) and the
// weight condition on line 645 (presumably an empty-weight-name check) are absent from this listing.
629 // set target variable
631
632 // set feature variables
// NOTE(review): `auto variable` copies each name; `const auto&` would avoid the copy.
633 for (unsigned int i = 0; i < m_general_options.m_variables.size(); i++) {
634 auto variable = m_general_options.m_variables[i];
636 }
637
638 // set spectator variables
639 for (unsigned int i = 0; i < m_general_options.m_spectators.size(); i++) {
640 auto variable = m_general_options.m_spectators[i];
642 }
643
644 // set weight variable - bit more tricky as we allow it to not be set or to not be present.
646 m_weight_variant = 1.0f;
647 B2INFO("No weight variable provided. The weight will be set to 1.");
648 } else {
649 if (m_general_options.m_weight_variable == "__weight__") {
650 if (checkForBranch(m_tree, "__weight__")) {
651 m_tree->SetBranchStatus("__weight__", true);
653 } else {
654 B2INFO("Couldn't find default weight feature named __weight__, all weights will be 1. Consider setting the "
655 "weight variable to an empty string if you don't need it.");
656 m_weight_variant = 1.0f;
657 }
658 } else {
660 }
661 }
662 }
663
664 }
666}
CombinedDataset(const GeneralOptions &general_options, Dataset &signal_dataset, Dataset &background_dataset)
Constructs a new CombinedDataset holding a reference to the wrapped Datasets.
Definition: Dataset.cc:273
Dataset & m_background_dataset
Reference to the wrapped dataset containing background events.
Definition: Dataset.h:340
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:306
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:296
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Definition: Dataset.cc:277
Dataset & m_signal_dataset
Reference to the wrapped dataset containing signal events.
Definition: Dataset.h:339
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual std::vector< bool > getSignals()
Returns all is Signals.
Definition: Dataset.cc:122
virtual unsigned int getFeatureIndex(const std::string &feature)
Return index of feature with the given name.
Definition: Dataset.cc:50
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
Definition: Dataset.cc:86
std::vector< float > m_spectators
Contains all spectator values of the currently loaded event.
Definition: Dataset.h:124
virtual std::vector< float > getTargets()
Returns all targets.
Definition: Dataset.cc:110
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
GeneralOptions m_general_options
GeneralOptions passed to this dataset.
Definition: Dataset.h:122
std::vector< float > m_input
Contains all feature values of the currently loaded event.
Definition: Dataset.h:123
Dataset(const GeneralOptions &general_options)
Constructs a new dataset given the general options.
Definition: Dataset.cc:26
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual std::vector< float > getWeights()
Returns all weights.
Definition: Dataset.cc:98
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
bool m_isSignal
Defines if the currently loaded event is signal or background.
Definition: Dataset.h:127
float m_weight
Contains the weight of the currently loaded event.
Definition: Dataset.h:125
virtual unsigned int getSpectatorIndex(const std::string &spectator)
Return index of spectator with the given name.
Definition: Dataset.cc:62
float m_target
Contains the target value of the currently loaded event.
Definition: Dataset.h:126
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
Definition: Options.h:84
int m_signal_class
Signal class which is used as signal in a classification problem.
Definition: Options.h:88
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:91
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_treename
Name of the TTree inside the datafile containing the training data.
Definition: Options.h:85
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:90
std::vector< float > m_weights
weight vector
Definition: Dataset.h:226
std::vector< std::vector< float > > m_matrix
Feature matrix.
Definition: Dataset.h:223
std::vector< std::vector< float > > m_spectator_matrix
Spectator matrix.
Definition: Dataset.h:224
std::vector< float > m_targets
target vector
Definition: Dataset.h:225
virtual void loadEvent(unsigned int iEvent) override
Copies row iEvent of the feature matrix into the current event, together with the matching spectator, target and weight entries when those containers were provided.
Definition: Dataset.cc:168
MultiDataset(const GeneralOptions &general_options, const std::vector< std::vector< float > > &input, const std::vector< std::vector< float > > &spectators, const std::vector< float > &targets={}, const std::vector< float > &weights={})
Constructs a new MultiDataset.
Definition: Dataset.cc:145
void setScalarVariableAddress(const std::string &variableType, const std::string &variableName, T &variableTarget)
sets the branch address for a scalar variable to a given target
Definition: Dataset.cc:510
void setBranchAddresses()
Sets the branch addresses of all features, weight and target again.
Definition: Dataset.cc:563
void setScalarVariableAddressVariant(const std::string &variableType, const std::string &variableName, RootDatasetVarVariant &variableTarget)
sets the branch address for a scalar variable to a given target
Definition: Dataset.cc:531
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
Definition: Dataset.h:371
void initialiseVarVariantForBranch(const std::string, RootDatasetVarVariant &)
Infers the type (double, float, int, bool) from the TTree and initialises the VarVariant with the correct type.
Definition: Dataset.cc:610
TChain * m_tree
Pointer to the TChain containing the data.
Definition: Dataset.h:408
virtual void loadEvent(unsigned int event) override
Load the event number iEvent from the TTree.
Definition: Dataset.cc:386
void setVectorVariableAddressVariant(const std::string &variableType, const std::vector< std::string > &variableName, std::vector< RootDatasetVarVariant > &varVariantTargets)
sets the branch address for a vector of VarVariant to a given target
Definition: Dataset.cc:555
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float>
Definition: Dataset.cc:435
std::vector< float > getVectorFromTTree(const std::string &variableType, const std::string &branchName, T &memberVariableTarget)
Returns all values for a specified variableType and branchName.
Definition: Dataset.cc:469
void initialiseVarVariantType(const std::string, RootDatasetVarVariant &)
Initialises the VarVariant.
Definition: Dataset.cc:593
std::vector< RootDatasetVarVariant > m_spectators_variant
Contains all spectator values of the currently loaded event.
Definition: Dataset.h:411
RootDatasetVarVariant m_target_variant
Contains the target value of the currently loaded event.
Definition: Dataset.h:413
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:424
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
Definition: Dataset.cc:403
std::vector< RootDatasetVarVariant > m_input_variant
Contains all feature values of the currently loaded event.
Definition: Dataset.h:409
ROOTDataset(const GeneralOptions &_general_options)
Creates a new ROOTDataset.
Definition: Dataset.cc:316
void setRootInputType()
Tries to infer the data-type of the spectator and feature variables in a root file.
Definition: Dataset.cc:627
virtual unsigned int getNumberOfSpectators() const override
Returns the number of spectators in this dataset.
Definition: Dataset.h:366
bool checkForBranch(TTree *, const std::string &) const
Checks if the given branchname exists in the TTree.
Definition: Dataset.cc:502
void setVectorVariableAddress(const std::string &variableType, const std::vector< std::string > &variableName, T &variableTargets)
sets the branch address for a vector variable to a given target
Definition: Dataset.cc:547
virtual ~ROOTDataset()
Virtual destructor.
Definition: Dataset.cc:447
float castVarVariantToFloat(RootDatasetVarVariant &) const
Casts a VarVariant which can contain <double,int,bool,float> to float.
Definition: Dataset.cc:371
virtual unsigned int getNumberOfFeatures() const override
Returns the number of features in this dataset.
Definition: Dataset.h:361
std::variant< double, float, int, bool > RootDatasetVarVariant
Typedef for variable types supported by the MVA ROOTDataset; can be one of double, float, int or bool.
Definition: Dataset.h:406
std::vector< float > getVectorFromTTreeVariant(const std::string &variableType, const std::string &branchName, RootDatasetVarVariant &memberVariableTarget)
Returns all values for a specified variableType and branchName.
Definition: Dataset.cc:453
RootDatasetVarVariant m_weight_variant
Contains the weight of the currently loaded event.
Definition: Dataset.h:412
SingleDataset(const GeneralOptions &general_options, const std::vector< float > &input, float target=1.0, const std::vector< float > &spectators=std::vector< float >())
Constructs a new SingleDataset.
Definition: Dataset.cc:135
Dataset & m_dataset
Reference to the wrapped dataset.
Definition: Dataset.h:286
SubDataset(const GeneralOptions &general_options, const std::vector< bool > &events, Dataset &dataset)
Constructs a new SubDataset holding a reference to the wrapped Dataset.
Definition: Dataset.cc:186
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in the wrapped dataset.
Definition: Dataset.h:258
virtual std::vector< float > getSpectator(unsigned int iSpectator) override
Returns all values of one spectator in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:259
std::vector< unsigned int > m_feature_indices
Mapping from the position of a feature in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:281
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float> of the wrapped dataset.
Definition: Dataset.cc:245
std::vector< unsigned int > m_spectator_indices
Mapping from the position of a spectator in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:283
virtual void loadEvent(unsigned int iEvent) override
Load the event number iEvent from the wrapped dataset.
Definition: Dataset.cc:225
std::vector< unsigned int > m_event_indices
Mapping from the position of a event in the given subset to its position in the wrapped dataset.
Definition: Dataset.h:285
bool m_use_event_indices
Use only a subset of the wrapped dataset events.
Definition: Dataset.h:279
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
std::vector< std::string > expandWordExpansions(const std::vector< std::string > &filenames)
Performs wildcard expansion using wordexp(), returns matches.
Abstract base class for different kinds of events.