Belle II Software  release-08-01-10
test_Dataset.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/interface/Dataset.h>
10 #include <framework/utilities/TestHelpers.h>
11 
12 #include <gtest/gtest.h>
13 
14 #include <filesystem>
15 #include <fstream>
16 #include <numeric>
17 
18 using namespace Belle2;
19 
20 namespace {
21 
22  class TestDataset : public MVA::Dataset {
23  public:
24  explicit TestDataset(MVA::GeneralOptions& general_options) : MVA::Dataset(general_options)
25  {
26  m_input = {1.0, 2.0, 3.0, 4.0, 5.0};
27  m_target = 3.0;
28  m_isSignal = true;
29  m_weight = -3.0;
30  }
31 
32  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 5; }
33  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 2; }
34  [[nodiscard]] unsigned int getNumberOfEvents() const override { return 20; }
35  void loadEvent(unsigned int iEvent) override
36  {
37  auto f = static_cast<float>(iEvent);
38  m_input = {f + 1, f + 2, f + 3, f + 4, f + 5};
39  m_spectators = {f + 6, f + 7};
40  };
41  float getSignalFraction() override { return 0.1; };
42  std::vector<float> getFeature(unsigned int iFeature) override
43  {
44  std::vector<float> a(20, 0.0);
45  std::iota(a.begin(), a.end(), iFeature + 1);
46  return a;
47  }
48  std::vector<float> getSpectator(unsigned int iSpectator) override
49  {
50  std::vector<float> a(20, 0.0);
51  std::iota(a.begin(), a.end(), iSpectator + 6);
52  return a;
53  }
54 
55  };
56 
57  TEST(DatasetTest, SingleDataset)
58  {
59 
60  MVA::GeneralOptions general_options;
61  general_options.m_variables = {"a", "b", "c"};
62  general_options.m_spectators = {"e", "f"};
63  general_options.m_signal_class = 2;
64  std::vector<float> input = {1.0, 2.0, 3.0};
65  std::vector<float> spectators = {4.0, 5.0};
66 
67  MVA::SingleDataset x(general_options, input, 2.0, spectators);
68 
69  EXPECT_EQ(x.getNumberOfFeatures(), 3);
70  EXPECT_EQ(x.getNumberOfSpectators(), 2);
71  EXPECT_EQ(x.getNumberOfEvents(), 1);
72 
73  EXPECT_EQ(x.getFeatureIndex("a"), 0);
74  EXPECT_EQ(x.getFeatureIndex("b"), 1);
75  EXPECT_EQ(x.getFeatureIndex("c"), 2);
76  EXPECT_B2ERROR(x.getFeatureIndex("bla"));
77 
78  EXPECT_EQ(x.getSpectatorIndex("e"), 0);
79  EXPECT_EQ(x.getSpectatorIndex("f"), 1);
80  EXPECT_B2ERROR(x.getSpectatorIndex("bla"));
81 
82  // Should just work
83  x.loadEvent(0);
84 
85  EXPECT_EQ(x.m_input.size(), 3);
86  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
87  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
88  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
89 
90  EXPECT_EQ(x.m_spectators.size(), 2);
91  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.0);
92  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.0);
93 
94  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
95  EXPECT_FLOAT_EQ(x.m_target, 2.0);
96  EXPECT_EQ(x.m_isSignal, true);
97 
98  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1.0);
99 
100  auto feature = x.getFeature(1);
101  EXPECT_EQ(feature.size(), 1);
102  EXPECT_FLOAT_EQ(feature[0], 2.0);
103 
104  auto spectator = x.getSpectator(1);
105  EXPECT_EQ(spectator.size(), 1);
106  EXPECT_FLOAT_EQ(spectator[0], 5.0);
107 
108  // Same result for mother class implementation
109  feature = x.Dataset::getFeature(1);
110  EXPECT_EQ(feature.size(), 1);
111  EXPECT_FLOAT_EQ(feature[0], 2.0);
112 
113  }
114 
115  TEST(DatasetTest, MultiDataset)
116  {
117 
118  MVA::GeneralOptions general_options;
119  general_options.m_variables = {"a", "b", "c"};
120  general_options.m_spectators = {"e", "f"};
121  general_options.m_signal_class = 2;
122  std::vector<std::vector<float>> matrix = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}};
123  std::vector<std::vector<float>> spectator_matrix = {{12.0, 13.0}, {15.0, 16.0}, {18.0, 19.0}};
124  std::vector<float> targets = {2.0, 0.0, 2.0};
125  std::vector<float> weights = {1.0, 2.0, 3.0};
126 
127  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, {1.0}, weights));
128 
129  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, targets, {1.0}));
130 
131  MVA::MultiDataset x(general_options, matrix, spectator_matrix, targets, weights);
132 
133  EXPECT_EQ(x.getNumberOfFeatures(), 3);
134  EXPECT_EQ(x.getNumberOfEvents(), 3);
135 
136  // Should just work
137  x.loadEvent(0);
138 
139  EXPECT_EQ(x.m_input.size(), 3);
140  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
141  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
142  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
143 
144  EXPECT_EQ(x.m_spectators.size(), 2);
145  EXPECT_FLOAT_EQ(x.m_spectators[0], 12.0);
146  EXPECT_FLOAT_EQ(x.m_spectators[1], 13.0);
147 
148  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
149  EXPECT_FLOAT_EQ(x.m_target, 2.0);
150  EXPECT_EQ(x.m_isSignal, true);
151 
152  // Should just work
153  x.loadEvent(1);
154 
155  EXPECT_EQ(x.m_input.size(), 3);
156  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
157  EXPECT_FLOAT_EQ(x.m_input[1], 5.0);
158  EXPECT_FLOAT_EQ(x.m_input[2], 6.0);
159 
160  EXPECT_EQ(x.m_spectators.size(), 2);
161  EXPECT_FLOAT_EQ(x.m_spectators[0], 15.0);
162  EXPECT_FLOAT_EQ(x.m_spectators[1], 16.0);
163 
164  EXPECT_FLOAT_EQ(x.m_weight, 2.0);
165  EXPECT_FLOAT_EQ(x.m_target, 0.0);
166  EXPECT_EQ(x.m_isSignal, false);
167 
168  // Should just work
169  x.loadEvent(2);
170 
171  EXPECT_EQ(x.m_input.size(), 3);
172  EXPECT_FLOAT_EQ(x.m_input[0], 7.0);
173  EXPECT_FLOAT_EQ(x.m_input[1], 8.0);
174  EXPECT_FLOAT_EQ(x.m_input[2], 9.0);
175 
176  EXPECT_EQ(x.m_spectators.size(), 2);
177  EXPECT_FLOAT_EQ(x.m_spectators[0], 18.0);
178  EXPECT_FLOAT_EQ(x.m_spectators[1], 19.0);
179 
180  EXPECT_FLOAT_EQ(x.m_weight, 3.0);
181  EXPECT_FLOAT_EQ(x.m_target, 2.0);
182  EXPECT_EQ(x.m_isSignal, true);
183 
184  EXPECT_FLOAT_EQ(x.getSignalFraction(), 4.0 / 6.0);
185 
186  auto feature = x.getFeature(1);
187  EXPECT_EQ(feature.size(), 3);
188  EXPECT_FLOAT_EQ(feature[0], 2.0);
189  EXPECT_FLOAT_EQ(feature[1], 5.0);
190  EXPECT_FLOAT_EQ(feature[2], 8.0);
191 
192  auto spectator = x.getSpectator(1);
193  EXPECT_EQ(spectator.size(), 3);
194  EXPECT_FLOAT_EQ(spectator[0], 13.0);
195  EXPECT_FLOAT_EQ(spectator[1], 16.0);
196  EXPECT_FLOAT_EQ(spectator[2], 19.0);
197 
198 
199  }
200 
201  TEST(DatasetTest, SubDataset)
202  {
203 
204  MVA::GeneralOptions general_options;
205  general_options.m_signal_class = 3;
206  general_options.m_variables = {"a", "b", "c", "d", "e"};
207  general_options.m_spectators = {"f", "g"};
208  TestDataset test_dataset(general_options);
209 
210  general_options.m_variables = {"a", "d", "e"};
211  general_options.m_spectators = {"g"};
212  std::vector<bool> events = {true, false, true, false, true, false, true, false, true, false,
213  true, false, true, false, true, false, true, false, true, false
214  };
215  MVA::SubDataset x(general_options, events, test_dataset);
216 
217  EXPECT_EQ(x.getNumberOfFeatures(), 3);
218  EXPECT_EQ(x.getNumberOfEvents(), 10);
219 
220  // Should just work
221  x.loadEvent(0);
222 
223  EXPECT_EQ(x.m_input.size(), 3);
224  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
225  EXPECT_FLOAT_EQ(x.m_input[1], 4.0);
226  EXPECT_FLOAT_EQ(x.m_input[2], 5.0);
227 
228  EXPECT_EQ(x.m_spectators.size(), 1);
229  EXPECT_FLOAT_EQ(x.m_spectators[0], 7.0);
230 
231  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
232  EXPECT_FLOAT_EQ(x.m_target, 3.0);
233  EXPECT_EQ(x.m_isSignal, true);
234 
235  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1);
236 
237  auto feature = x.getFeature(1);
238  EXPECT_EQ(feature.size(), 10);
239  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
240  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
241  };
242 
243  auto spectator = x.getSpectator(0);
244  EXPECT_EQ(spectator.size(), 10);
245  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
246  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
247  };
248 
249  // Same result for mother class implementation
250  feature = x.Dataset::getFeature(1);
251  EXPECT_EQ(feature.size(), 10);
252  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
253  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
254  };
255 
256  spectator = x.Dataset::getSpectator(0);
257  EXPECT_EQ(spectator.size(), 10);
258  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
259  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
260  };
261 
262  // Test without event indices
263  MVA::SubDataset y(general_options, {}, test_dataset);
264  feature = y.getFeature(1);
265  EXPECT_EQ(feature.size(), 20);
266  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
267  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
268  };
269 
270  spectator = y.getSpectator(0);
271  EXPECT_EQ(spectator.size(), 20);
272  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
273  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
274  };
275 
276  // Same result for mother class implementation
277  feature = y.Dataset::getFeature(1);
278  EXPECT_EQ(feature.size(), 20);
279  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
280  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
281  };
282 
283  spectator = y.Dataset::getSpectator(0);
284  EXPECT_EQ(spectator.size(), 20);
285  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
286  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
287  };
288 
289  general_options.m_variables = {"a", "d", "e", "DOESNOTEXIST"};
290  try {
291  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
292  } catch (...) {
293 
294  }
295  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
296 
297  general_options.m_variables = {"a", "d", "e"};
298  general_options.m_spectators = {"DOESNOTEXIST"};
299  try {
300  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
301  } catch (...) {
302 
303  }
304  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
305 
306  }
307 
308  TEST(DatasetTest, CombinedDataset)
309  {
310 
311  MVA::GeneralOptions general_options;
312  general_options.m_signal_class = 1;
313  general_options.m_variables = {"a", "b", "c", "d", "e"};
314  general_options.m_spectators = {"f", "g"};
315  TestDataset signal_dataset(general_options);
316  TestDataset bckgrd_dataset(general_options);
317 
318  MVA::CombinedDataset x(general_options, signal_dataset, bckgrd_dataset);
319 
320  EXPECT_EQ(x.getNumberOfFeatures(), 5);
321  EXPECT_EQ(x.getNumberOfEvents(), 40);
322 
323  // Should just work
324  x.loadEvent(0);
325 
326  EXPECT_EQ(x.m_input.size(), 5);
327  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
328  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
329  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
330  EXPECT_FLOAT_EQ(x.m_input[3], 4.0);
331  EXPECT_FLOAT_EQ(x.m_input[4], 5.0);
332 
333  EXPECT_EQ(x.m_spectators.size(), 2);
334  EXPECT_FLOAT_EQ(x.m_spectators[0], 6.0);
335  EXPECT_FLOAT_EQ(x.m_spectators[1], 7.0);
336 
337  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
338  EXPECT_FLOAT_EQ(x.m_target, 1.0);
339  EXPECT_EQ(x.m_isSignal, true);
340 
341  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.5);
342 
343  auto feature = x.getFeature(1);
344  EXPECT_EQ(feature.size(), 40);
345  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
346  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
347  };
348 
349  auto spectator = x.getSpectator(0);
350  EXPECT_EQ(spectator.size(), 40);
351  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
352  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
353  };
354 
355  // Same result for mother class implementation
356  feature = x.Dataset::getFeature(1);
357  EXPECT_EQ(feature.size(), 40);
358  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
359  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
360  };
361 
362  spectator = x.Dataset::getSpectator(0);
363  EXPECT_EQ(spectator.size(), 40);
364  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
365  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
366  };
367 
368  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
369  x.loadEvent(iEvent);
370  EXPECT_EQ(x.m_isSignal, iEvent < 20);
371  EXPECT_FLOAT_EQ(x.m_target, iEvent < 20 ? 1.0 : 0.0);
372  }
373 
374  }
375 
376  TEST(DatasetTest, ROOTDataset)
377  {
378 
380  TFile file("datafile.root", "RECREATE");
381  file.cd();
382  TTree tree("tree", "TreeTitle");
383  float a, b, c, d, e, f, v, w = 0;
384  bool target;
385  tree.Branch("a", &a);
386  tree.Branch("b", &b);
387  tree.Branch("c", &c);
388  tree.Branch("d", &d);
389  tree.Branch("e__bo__bc", &e);
390  tree.Branch("f__bo__bc", &f);
391  tree.Branch("g", &target);
392  tree.Branch("__weight__", &c);
393  tree.Branch("v__bo__bc", &v);
394  tree.Branch("w", &w);
395 
396  for (unsigned int i = 0; i < 5; ++i) {
397  a = i + 1.0;
398  b = i + 1.1;
399  c = i + 1.2;
400  d = i + 1.3;
401  e = i + 1.4;
402  f = i + 1.5;
403  target = (i % 2 == 0);
404  w = i + 1.6;
405  v = i + 1.7;
406  tree.Fill();
407  }
408 
409  file.Write("tree");
410 
411  MVA::GeneralOptions general_options;
412  // Both names with and without makeROOTCompatible should work
413  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
414  general_options.m_spectators = {"w", "v()"};
415  general_options.m_signal_class = 1;
416  general_options.m_datafiles = {"datafile.root"};
417  general_options.m_treename = "tree";
418  general_options.m_target_variable = "g";
419  general_options.m_weight_variable = "c";
420  MVA::ROOTDataset x(general_options);
421 
422  EXPECT_EQ(x.getNumberOfFeatures(), 4);
423  EXPECT_EQ(x.getNumberOfSpectators(), 2);
424  EXPECT_EQ(x.getNumberOfEvents(), 5);
425 
426  // Should just work
427  x.loadEvent(0);
428  EXPECT_EQ(x.m_input.size(), 4);
429  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
430  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
431  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
432  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
433  EXPECT_EQ(x.m_spectators.size(), 2);
434  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
435  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
436  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
437  EXPECT_EQ(x.m_target, true);
438  EXPECT_EQ(x.m_isSignal, true);
439 
440  x.loadEvent(1);
441  EXPECT_EQ(x.m_input.size(), 4);
442  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
443  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
444  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
445  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
446  EXPECT_EQ(x.m_spectators.size(), 2);
447  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
448  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
449  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
450  EXPECT_EQ(x.m_target, false);
451  EXPECT_EQ(x.m_isSignal, false);
452 
453  x.loadEvent(2);
454  EXPECT_EQ(x.m_input.size(), 4);
455  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
456  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
457  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
458  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
459  EXPECT_EQ(x.m_spectators.size(), 2);
460  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
461  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
462  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
463  EXPECT_EQ(x.m_target, true);
464  EXPECT_EQ(x.m_isSignal, true);
465 
466  x.loadEvent(3);
467  EXPECT_EQ(x.m_input.size(), 4);
468  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
469  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
470  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
471  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
472  EXPECT_EQ(x.m_spectators.size(), 2);
473  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
474  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
475  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
476  EXPECT_EQ(x.m_target, false);
477  EXPECT_EQ(x.m_isSignal, false);
478 
479  x.loadEvent(4);
480  EXPECT_EQ(x.m_input.size(), 4);
481  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
482  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
483  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
484  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
485  EXPECT_EQ(x.m_spectators.size(), 2);
486  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
487  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
488  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
489  EXPECT_EQ(x.m_target, true);
490  EXPECT_EQ(x.m_isSignal, true);
491 
492  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
493 
494  auto feature = x.getFeature(1);
495  EXPECT_EQ(feature.size(), 5);
496  EXPECT_FLOAT_EQ(feature[0], 1.1);
497  EXPECT_FLOAT_EQ(feature[1], 2.1);
498  EXPECT_FLOAT_EQ(feature[2], 3.1);
499  EXPECT_FLOAT_EQ(feature[3], 4.1);
500  EXPECT_FLOAT_EQ(feature[4], 5.1);
501 
502  // Same result for mother class implementation
503  feature = x.Dataset::getFeature(1);
504  EXPECT_EQ(feature.size(), 5);
505  EXPECT_FLOAT_EQ(feature[0], 1.1);
506  EXPECT_FLOAT_EQ(feature[1], 2.1);
507  EXPECT_FLOAT_EQ(feature[2], 3.1);
508  EXPECT_FLOAT_EQ(feature[3], 4.1);
509  EXPECT_FLOAT_EQ(feature[4], 5.1);
510 
511  auto spectator = x.getSpectator(1);
512  EXPECT_EQ(spectator.size(), 5);
513  EXPECT_FLOAT_EQ(spectator[0], 1.7);
514  EXPECT_FLOAT_EQ(spectator[1], 2.7);
515  EXPECT_FLOAT_EQ(spectator[2], 3.7);
516  EXPECT_FLOAT_EQ(spectator[3], 4.7);
517  EXPECT_FLOAT_EQ(spectator[4], 5.7);
518 
519  // Same result for mother class implementation
520  spectator = x.Dataset::getSpectator(1);
521  EXPECT_EQ(spectator.size(), 5);
522  EXPECT_FLOAT_EQ(spectator[0], 1.7);
523  EXPECT_FLOAT_EQ(spectator[1], 2.7);
524  EXPECT_FLOAT_EQ(spectator[2], 3.7);
525  EXPECT_FLOAT_EQ(spectator[3], 4.7);
526  EXPECT_FLOAT_EQ(spectator[4], 5.7);
527 
528  auto weights = x.getWeights();
529  EXPECT_EQ(weights.size(), 5);
530  EXPECT_FLOAT_EQ(weights[0], 1.2);
531  EXPECT_FLOAT_EQ(weights[1], 2.2);
532  EXPECT_FLOAT_EQ(weights[2], 3.2);
533  EXPECT_FLOAT_EQ(weights[3], 4.2);
534  EXPECT_FLOAT_EQ(weights[4], 5.2);
535 
536  // Same result for mother class implementation
537  weights = x.Dataset::getWeights();
538  EXPECT_EQ(weights.size(), 5);
539  EXPECT_FLOAT_EQ(weights[0], 1.2);
540  EXPECT_FLOAT_EQ(weights[1], 2.2);
541  EXPECT_FLOAT_EQ(weights[2], 3.2);
542  EXPECT_FLOAT_EQ(weights[3], 4.2);
543  EXPECT_FLOAT_EQ(weights[4], 5.2);
544 
545  auto targets = x.getTargets();
546  EXPECT_EQ(targets.size(), 5);
547  EXPECT_EQ(targets[0], true);
548  EXPECT_EQ(targets[1], false);
549  EXPECT_EQ(targets[2], true);
550  EXPECT_EQ(targets[3], false);
551  EXPECT_EQ(targets[4], true);
552 
553  auto signals = x.getSignals();
554  EXPECT_EQ(signals.size(), 5);
555  EXPECT_EQ(signals[0], true);
556  EXPECT_EQ(signals[1], false);
557  EXPECT_EQ(signals[2], true);
558  EXPECT_EQ(signals[3], false);
559  EXPECT_EQ(signals[4], true);
560 
561  // Using __weight__ should work as well,
562  // the only difference to using _weight__ instead of c is
563  // in setBranchAddresses which avoids calling makeROOTCompatible
564  // So we have to check the behaviour using __weight__ as well
565  general_options.m_weight_variable = "__weight__";
566  MVA::ROOTDataset y(general_options);
567 
568  weights = y.getWeights();
569  EXPECT_EQ(weights.size(), 5);
570  EXPECT_FLOAT_EQ(weights[0], 1.2);
571  EXPECT_FLOAT_EQ(weights[1], 2.2);
572  EXPECT_FLOAT_EQ(weights[2], 3.2);
573  EXPECT_FLOAT_EQ(weights[3], 4.2);
574  EXPECT_FLOAT_EQ(weights[4], 5.2);
575 
576  // Check TChain expansion
577  general_options.m_datafiles = {"datafile*.root"};
578  {
579  MVA::ROOTDataset chain_test(general_options);
580  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
581  }
582  std::filesystem::copy_file("datafile.root", "datafile2.root");
583  {
584  MVA::ROOTDataset chain_test(general_options);
585  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
586  }
587  std::filesystem::copy_file("datafile.root", "datafile3.root");
588  {
589  MVA::ROOTDataset chain_test(general_options);
590  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
591  }
592  // Test m_max_events feature
593  {
594  general_options.m_max_events = 10;
595  MVA::ROOTDataset chain_test(general_options);
596  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
597  general_options.m_max_events = 0;
598  }
599 
600  // Check for missing tree
601  general_options.m_treename = "missing tree";
602  try {
603  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
604  } catch (...) {
605 
606  }
607  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
608 
609  // Check for missing branch
610  general_options.m_treename = "tree";
611  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
612  try {
613  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
614  } catch (...) {
615 
616  }
617  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
618 
619  // Check for missing branch
620  general_options.m_treename = "tree";
621  general_options.m_variables = {"a", "b", "e", "f"};
622  general_options.m_spectators = {"missing branch"};
623  try {
624  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
625  } catch (...) {
626 
627  }
628  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
629 
630  // Check for missing file
631  general_options.m_spectators = {};
632  general_options.m_datafiles = {"DOESNOTEXIST.root"};
633  general_options.m_treename = "tree";
634  try {
635  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
636  } catch (...) {
637 
638  }
639  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
640 
641  // Check for invalid file
642  general_options.m_datafiles = {"ISNotAValidROOTFile"};
643  general_options.m_treename = "tree";
644 
645  {
646  std::ofstream(general_options.m_datafiles[0]);
647  }
648  EXPECT_TRUE(std::filesystem::exists(general_options.m_datafiles[0]));
649 
650  try {
651  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
652  } catch (...) {
653 
654  }
655  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
656  }
657 
658 
659  TEST(DatasetTest, ROOTDatasetMixedTypes)
660  {
661 
663  TFile file("datafile.root", "RECREATE");
664  file.cd();
665  TTree tree("tree", "TreeTitle");
666  double a, c, d, g, v, w = 0;
667  float b = 0.0f;
668  int e = 0;
669  bool f = false;
670  tree.Branch("a", &a, "a/D");
671  tree.Branch("b", &b, "b/F");
672  tree.Branch("c", &c, "c/D");
673  tree.Branch("d", &d, "d/D");
674  tree.Branch("e__bo__bc", &e, "e__bo__bc/I");
675  tree.Branch("f__bo__bc", &f, "f__bo__bc/O");
676  tree.Branch("g", &g, "g/D");
677  tree.Branch("__weight__", &c, "__weight__/D");
678  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
679  tree.Branch("w", &w, "w/D");
680 
681  for (unsigned int i = 0; i < 5; ++i) {
682  a = i + 1.0;
683  b = i + 1.1;
684  c = i + 1.2;
685  d = i + 1.3;
686  e = i + 1;
687  f = i % 2 == 0;
688  g = float(i % 2 == 0);
689  w = i + 1.6;
690  v = i + 1.7;
691  tree.Fill();
692  }
693 
694  file.Write("tree");
695 
696  MVA::GeneralOptions general_options;
697  // Both names with and without makeROOTCompatible should work
698  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
699  general_options.m_spectators = {"w", "v()"};
700  general_options.m_signal_class = 1;
701  general_options.m_datafiles = {"datafile.root"};
702  general_options.m_treename = "tree";
703  general_options.m_target_variable = "g";
704  general_options.m_weight_variable = "c";
705  MVA::ROOTDataset x(general_options);
706 
707  EXPECT_EQ(x.getNumberOfFeatures(), 4);
708  EXPECT_EQ(x.getNumberOfSpectators(), 2);
709  EXPECT_EQ(x.getNumberOfEvents(), 5);
710 
711  // Should just work
712  x.loadEvent(0);
713  EXPECT_EQ(x.m_input.size(), 4);
714  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
715  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
716  EXPECT_EQ(x.m_input[2], 1);
717  EXPECT_EQ(x.m_input[3], true);
718  EXPECT_EQ(x.m_spectators.size(), 2);
719  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
720  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
721  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
722  EXPECT_FLOAT_EQ(x.m_target, 1.0);
723  EXPECT_EQ(x.m_isSignal, true);
724 
725  x.loadEvent(1);
726  EXPECT_EQ(x.m_input.size(), 4);
727  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
728  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
729  EXPECT_EQ(x.m_input[2], 2);
730  EXPECT_EQ(x.m_input[3], false);
731  EXPECT_EQ(x.m_spectators.size(), 2);
732  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
733  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
734  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
735  EXPECT_FLOAT_EQ(x.m_target, 0.0);
736  EXPECT_EQ(x.m_isSignal, false);
737 
738  x.loadEvent(2);
739  EXPECT_EQ(x.m_input.size(), 4);
740  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
741  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
742  EXPECT_EQ(x.m_input[2], 3);
743  EXPECT_EQ(x.m_input[3], true);
744  EXPECT_EQ(x.m_spectators.size(), 2);
745  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
746  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
747  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
748  EXPECT_FLOAT_EQ(x.m_target, 1.0);
749  EXPECT_EQ(x.m_isSignal, true);
750 
751  x.loadEvent(3);
752  EXPECT_EQ(x.m_input.size(), 4);
753  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
754  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
755  EXPECT_EQ(x.m_input[2], 4);
756  EXPECT_EQ(x.m_input[3], false);
757  EXPECT_EQ(x.m_spectators.size(), 2);
758  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
759  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
760  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
761  EXPECT_FLOAT_EQ(x.m_target, 0.0);
762  EXPECT_EQ(x.m_isSignal, false);
763 
764  x.loadEvent(4);
765  EXPECT_EQ(x.m_input.size(), 4);
766  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
767  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
768  EXPECT_EQ(x.m_input[2], 5);
769  EXPECT_EQ(x.m_input[3], true);
770  EXPECT_EQ(x.m_spectators.size(), 2);
771  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
772  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
773  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
774  EXPECT_FLOAT_EQ(x.m_target, 1.0);
775  EXPECT_EQ(x.m_isSignal, true);
776 
777  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
778 
779  auto feature = x.getFeature(1);
780  EXPECT_EQ(feature.size(), 5);
781  EXPECT_FLOAT_EQ(feature[0], 1.1);
782  EXPECT_FLOAT_EQ(feature[1], 2.1);
783  EXPECT_FLOAT_EQ(feature[2], 3.1);
784  EXPECT_FLOAT_EQ(feature[3], 4.1);
785  EXPECT_FLOAT_EQ(feature[4], 5.1);
786 
787  // Same result for mother class implementation
788  feature = x.Dataset::getFeature(1);
789  EXPECT_EQ(feature.size(), 5);
790  EXPECT_FLOAT_EQ(feature[0], 1.1);
791  EXPECT_FLOAT_EQ(feature[1], 2.1);
792  EXPECT_FLOAT_EQ(feature[2], 3.1);
793  EXPECT_FLOAT_EQ(feature[3], 4.1);
794  EXPECT_FLOAT_EQ(feature[4], 5.1);
795 
796  auto spectator = x.getSpectator(1);
797  EXPECT_EQ(spectator.size(), 5);
798  EXPECT_FLOAT_EQ(spectator[0], 1.7);
799  EXPECT_FLOAT_EQ(spectator[1], 2.7);
800  EXPECT_FLOAT_EQ(spectator[2], 3.7);
801  EXPECT_FLOAT_EQ(spectator[3], 4.7);
802  EXPECT_FLOAT_EQ(spectator[4], 5.7);
803 
804  // Same result for mother class implementation
805  spectator = x.Dataset::getSpectator(1);
806  EXPECT_EQ(spectator.size(), 5);
807  EXPECT_FLOAT_EQ(spectator[0], 1.7);
808  EXPECT_FLOAT_EQ(spectator[1], 2.7);
809  EXPECT_FLOAT_EQ(spectator[2], 3.7);
810  EXPECT_FLOAT_EQ(spectator[3], 4.7);
811  EXPECT_FLOAT_EQ(spectator[4], 5.7);
812 
813  auto weights = x.getWeights();
814  EXPECT_EQ(weights.size(), 5);
815  EXPECT_FLOAT_EQ(weights[0], 1.2);
816  EXPECT_FLOAT_EQ(weights[1], 2.2);
817  EXPECT_FLOAT_EQ(weights[2], 3.2);
818  EXPECT_FLOAT_EQ(weights[3], 4.2);
819  EXPECT_FLOAT_EQ(weights[4], 5.2);
820 
821  // Same result for mother class implementation
822  weights = x.Dataset::getWeights();
823  EXPECT_EQ(weights.size(), 5);
824  EXPECT_FLOAT_EQ(weights[0], 1.2);
825  EXPECT_FLOAT_EQ(weights[1], 2.2);
826  EXPECT_FLOAT_EQ(weights[2], 3.2);
827  EXPECT_FLOAT_EQ(weights[3], 4.2);
828  EXPECT_FLOAT_EQ(weights[4], 5.2);
829 
830  auto targets = x.getTargets();
831  EXPECT_EQ(targets.size(), 5);
832  EXPECT_FLOAT_EQ(targets[0], 1.0);
833  EXPECT_FLOAT_EQ(targets[1], 0.0);
834  EXPECT_FLOAT_EQ(targets[2], 1.0);
835  EXPECT_FLOAT_EQ(targets[3], 0.0);
836  EXPECT_FLOAT_EQ(targets[4], 1.0);
837 
838  auto signals = x.getSignals();
839  EXPECT_EQ(signals.size(), 5);
840  EXPECT_EQ(signals[0], true);
841  EXPECT_EQ(signals[1], false);
842  EXPECT_EQ(signals[2], true);
843  EXPECT_EQ(signals[3], false);
844  EXPECT_EQ(signals[4], true);
845 
846  // Using __weight__ should work as well,
847  // the only difference to using _weight__ instead of g is
848  // in setBranchAddresses which avoids calling makeROOTCompatible
849  // So we have to check the behaviour using __weight__ as well
850  general_options.m_weight_variable = "__weight__";
851  MVA::ROOTDataset y(general_options);
852 
853  weights = y.getWeights();
854  EXPECT_EQ(weights.size(), 5);
855  EXPECT_FLOAT_EQ(weights[0], 1.2);
856  EXPECT_FLOAT_EQ(weights[1], 2.2);
857  EXPECT_FLOAT_EQ(weights[2], 3.2);
858  EXPECT_FLOAT_EQ(weights[3], 4.2);
859  EXPECT_FLOAT_EQ(weights[4], 5.2);
860 
861  // Check TChain expansion
862  general_options.m_datafiles = {"datafile*.root"};
863  {
864  MVA::ROOTDataset chain_test(general_options);
865  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
866  }
867  std::filesystem::copy_file("datafile.root", "datafile2.root");
868  {
869  MVA::ROOTDataset chain_test(general_options);
870  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
871  }
872  std::filesystem::copy_file("datafile.root", "datafile3.root");
873  {
874  MVA::ROOTDataset chain_test(general_options);
875  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
876  }
877  // Test m_max_events feature
878  {
879  general_options.m_max_events = 10;
880  MVA::ROOTDataset chain_test(general_options);
881  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
882  general_options.m_max_events = 0;
883  }
884 
885  // Check for missing tree
886  general_options.m_treename = "missing tree";
887  try {
888  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
889  } catch (...) {
890 
891  }
892  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
893 
894  // Check for missing branch
895  general_options.m_treename = "tree";
896  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
897  try {
898  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
899  } catch (...) {
900 
901  }
902  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
903 
904  // Check for missing branch
905  general_options.m_treename = "tree";
906  general_options.m_variables = {"a", "b", "e", "f"};
907  general_options.m_spectators = {"missing branch"};
908  try {
909  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
910  } catch (...) {
911 
912  }
913  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
914 
915  // Check for missing file
916  general_options.m_spectators = {};
917  general_options.m_datafiles = {"DOESNOTEXIST.root"};
918  general_options.m_treename = "tree";
919  try {
920  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
921  } catch (...) {
922 
923  }
924  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
925 
926  // Check for invalid file
927  general_options.m_datafiles = {"ISNotAValidROOTFile"};
928  general_options.m_treename = "tree";
929 
930  {
931  std::ofstream(general_options.m_datafiles[0]);
932  }
933  EXPECT_TRUE(std::filesystem::exists(general_options.m_datafiles[0]));
934 
935  try {
936  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
937  } catch (...) {
938 
939  }
940  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
941  }
942 
943  TEST(DatasetTest, ROOTMultiDataset)
944  {
945 
947  TFile file("datafile.root", "RECREATE");
948  file.cd();
949  TTree tree("tree", "TreeTitle");
950  double a, b, c, d, e, f, v, w = 0;
951  bool target;
952  tree.Branch("a", &a);
953  tree.Branch("b", &b);
954  tree.Branch("c", &c);
955  tree.Branch("d", &d);
956  tree.Branch("e__bo__bc", &e);
957  tree.Branch("f__bo__bc", &f);
958  tree.Branch("g", &target);
959  tree.Branch("__weight__", &c);
960  tree.Branch("v__bo__bc", &v);
961  tree.Branch("w", &w);
962 
963  for (unsigned int i = 0; i < 5; ++i) {
964  a = i + 1.0;
965  b = i + 1.1;
966  c = i + 1.2;
967  d = i + 1.3;
968  e = i + 1.4;
969  f = i + 1.5;
970  target = (i % 2 == 0);
971  w = i + 1.6;
972  v = i + 1.7;
973  tree.Fill();
974  }
975 
976  file.Write("tree");
977 
978  TFile file2("datafile2.root", "RECREATE");
979  file2.cd();
980  TTree tree2("tree", "TreeTitle");
981  tree2.Branch("a", &a);
982  tree2.Branch("b", &b);
983  tree2.Branch("c", &c);
984  tree2.Branch("d", &d);
985  tree2.Branch("e__bo__bc", &e);
986  tree2.Branch("f__bo__bc", &f);
987  tree2.Branch("g", &target);
988  tree2.Branch("__weight__", &c);
989  tree2.Branch("v__bo__bc", &v);
990  tree2.Branch("w", &w);
991 
992  for (unsigned int i = 0; i < 5; ++i) {
993  a = i + 1.0;
994  b = i + 1.1;
995  c = i + 1.2;
996  d = i + 1.3;
997  e = i + 1.4;
998  f = i + 1.5;
999  target = (i % 2 == 0);
1000  w = i + 1.6;
1001  v = i + 1.7;
1002  tree2.Fill();
1003  }
1004 
1005  file2.Write("tree");
1006 
1007  MVA::GeneralOptions general_options;
1008  // Both names with and without makeROOTCompatible should work
1009  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1010  general_options.m_spectators = {"w", "v()"};
1011  general_options.m_signal_class = 1;
1012  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1013  general_options.m_treename = "tree";
1014  general_options.m_target_variable = "g";
1015  general_options.m_weight_variable = "c";
1016  MVA::ROOTDataset x(general_options);
1017 
1018  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1019  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1020  EXPECT_EQ(x.getNumberOfEvents(), 10);
1021 
1022  // Should just work
1023  x.loadEvent(0);
1024  EXPECT_EQ(x.m_input.size(), 4);
1025  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1026  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1027  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1028  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1029  EXPECT_EQ(x.m_spectators.size(), 2);
1030  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1031  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1032  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1033  EXPECT_EQ(x.m_target, true);
1034  EXPECT_EQ(x.m_isSignal, true);
1035 
1036  x.loadEvent(5);
1037  EXPECT_EQ(x.m_input.size(), 4);
1038  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1039  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1040  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1041  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1042  EXPECT_EQ(x.m_spectators.size(), 2);
1043  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1044  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1045  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1046  EXPECT_EQ(x.m_target, true);
1047  EXPECT_EQ(x.m_isSignal, true);
1048 
1049  x.loadEvent(1);
1050  EXPECT_EQ(x.m_input.size(), 4);
1051  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1052  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1053  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1054  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1055  EXPECT_EQ(x.m_spectators.size(), 2);
1056  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1057  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1058  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1059  EXPECT_EQ(x.m_target, false);
1060  EXPECT_EQ(x.m_isSignal, false);
1061 
1062  x.loadEvent(6);
1063  EXPECT_EQ(x.m_input.size(), 4);
1064  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1065  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1066  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1067  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1068  EXPECT_EQ(x.m_spectators.size(), 2);
1069  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1070  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1071  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1072  EXPECT_EQ(x.m_target, false);
1073  EXPECT_EQ(x.m_isSignal, false);
1074 
1075  x.loadEvent(2);
1076  EXPECT_EQ(x.m_input.size(), 4);
1077  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1078  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1079  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1080  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1081  EXPECT_EQ(x.m_spectators.size(), 2);
1082  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1083  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1084  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1085  EXPECT_EQ(x.m_target, true);
1086  EXPECT_EQ(x.m_isSignal, true);
1087 
1088  x.loadEvent(7);
1089  EXPECT_EQ(x.m_input.size(), 4);
1090  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1091  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1092  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1093  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1094  EXPECT_EQ(x.m_spectators.size(), 2);
1095  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1096  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1097  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1098  EXPECT_EQ(x.m_target, true);
1099  EXPECT_EQ(x.m_isSignal, true);
1100 
1101  x.loadEvent(3);
1102  EXPECT_EQ(x.m_input.size(), 4);
1103  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1104  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1105  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1106  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1107  EXPECT_EQ(x.m_spectators.size(), 2);
1108  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1109  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1110  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1111  EXPECT_EQ(x.m_target, false);
1112  EXPECT_EQ(x.m_isSignal, false);
1113 
1114  x.loadEvent(8);
1115  EXPECT_EQ(x.m_input.size(), 4);
1116  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1117  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1118  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1119  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1120  EXPECT_EQ(x.m_spectators.size(), 2);
1121  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1122  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1123  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1124  EXPECT_EQ(x.m_target, false);
1125  EXPECT_EQ(x.m_isSignal, false);
1126 
1127  x.loadEvent(4);
1128  EXPECT_EQ(x.m_input.size(), 4);
1129  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1130  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1131  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1132  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1133  EXPECT_EQ(x.m_spectators.size(), 2);
1134  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1135  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1136  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1137  EXPECT_EQ(x.m_target, true);
1138  EXPECT_EQ(x.m_isSignal, true);
1139 
1140  x.loadEvent(9);
1141  EXPECT_EQ(x.m_input.size(), 4);
1142  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1143  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1144  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1145  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1146  EXPECT_EQ(x.m_spectators.size(), 2);
1147  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1148  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1149  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1150  EXPECT_EQ(x.m_target, true);
1151  EXPECT_EQ(x.m_isSignal, true);
1152 
1153  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1154 
1155  auto feature = x.getFeature(1);
1156  EXPECT_EQ(feature.size(), 10);
1157  EXPECT_FLOAT_EQ(feature[0], 1.1);
1158  EXPECT_FLOAT_EQ(feature[1], 2.1);
1159  EXPECT_FLOAT_EQ(feature[2], 3.1);
1160  EXPECT_FLOAT_EQ(feature[3], 4.1);
1161  EXPECT_FLOAT_EQ(feature[4], 5.1);
1162  EXPECT_FLOAT_EQ(feature[5], 1.1);
1163  EXPECT_FLOAT_EQ(feature[6], 2.1);
1164  EXPECT_FLOAT_EQ(feature[7], 3.1);
1165  EXPECT_FLOAT_EQ(feature[8], 4.1);
1166  EXPECT_FLOAT_EQ(feature[9], 5.1);
1167 
1168  // Same result for mother class implementation
1169  feature = x.Dataset::getFeature(1);
1170  EXPECT_EQ(feature.size(), 10);
1171  EXPECT_FLOAT_EQ(feature[0], 1.1);
1172  EXPECT_FLOAT_EQ(feature[1], 2.1);
1173  EXPECT_FLOAT_EQ(feature[2], 3.1);
1174  EXPECT_FLOAT_EQ(feature[3], 4.1);
1175  EXPECT_FLOAT_EQ(feature[4], 5.1);
1176  EXPECT_FLOAT_EQ(feature[5], 1.1);
1177  EXPECT_FLOAT_EQ(feature[6], 2.1);
1178  EXPECT_FLOAT_EQ(feature[7], 3.1);
1179  EXPECT_FLOAT_EQ(feature[8], 4.1);
1180  EXPECT_FLOAT_EQ(feature[9], 5.1);
1181 
1182  auto spectator = x.getSpectator(1);
1183  EXPECT_EQ(spectator.size(), 10);
1184  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1185  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1186  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1187  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1188  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1189  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1190  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1191  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1192  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1193  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1194 
1195  // Same result for mother class implementation
1196  spectator = x.Dataset::getSpectator(1);
1197  EXPECT_EQ(spectator.size(), 10);
1198  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1199  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1200  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1201  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1202  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1203  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1204  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1205  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1206  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1207  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1208 
1209  auto weights = x.getWeights();
1210  EXPECT_EQ(weights.size(), 10);
1211  EXPECT_FLOAT_EQ(weights[0], 1.2);
1212  EXPECT_FLOAT_EQ(weights[1], 2.2);
1213  EXPECT_FLOAT_EQ(weights[2], 3.2);
1214  EXPECT_FLOAT_EQ(weights[3], 4.2);
1215  EXPECT_FLOAT_EQ(weights[4], 5.2);
1216  EXPECT_FLOAT_EQ(weights[5], 1.2);
1217  EXPECT_FLOAT_EQ(weights[6], 2.2);
1218  EXPECT_FLOAT_EQ(weights[7], 3.2);
1219  EXPECT_FLOAT_EQ(weights[8], 4.2);
1220  EXPECT_FLOAT_EQ(weights[9], 5.2);
1221 
1222  // Same result for mother class implementation
1223  weights = x.Dataset::getWeights();
1224  EXPECT_EQ(weights.size(), 10);
1225  EXPECT_FLOAT_EQ(weights[0], 1.2);
1226  EXPECT_FLOAT_EQ(weights[1], 2.2);
1227  EXPECT_FLOAT_EQ(weights[2], 3.2);
1228  EXPECT_FLOAT_EQ(weights[3], 4.2);
1229  EXPECT_FLOAT_EQ(weights[4], 5.2);
1230  EXPECT_FLOAT_EQ(weights[5], 1.2);
1231  EXPECT_FLOAT_EQ(weights[6], 2.2);
1232  EXPECT_FLOAT_EQ(weights[7], 3.2);
1233  EXPECT_FLOAT_EQ(weights[8], 4.2);
1234  EXPECT_FLOAT_EQ(weights[9], 5.2);
1235 
1236  auto targets = x.getTargets();
1237  EXPECT_EQ(targets.size(), 10);
1238  EXPECT_EQ(targets[0], true);
1239  EXPECT_EQ(targets[1], false);
1240  EXPECT_EQ(targets[2], true);
1241  EXPECT_EQ(targets[3], false);
1242  EXPECT_EQ(targets[4], true);
1243  EXPECT_EQ(targets[5], true);
1244  EXPECT_EQ(targets[6], false);
1245  EXPECT_EQ(targets[7], true);
1246  EXPECT_EQ(targets[8], false);
1247  EXPECT_EQ(targets[9], true);
1248 
1249  auto signals = x.getSignals();
1250  EXPECT_EQ(signals.size(), 10);
1251  EXPECT_EQ(signals[0], true);
1252  EXPECT_EQ(signals[1], false);
1253  EXPECT_EQ(signals[2], true);
1254  EXPECT_EQ(signals[3], false);
1255  EXPECT_EQ(signals[4], true);
1256  EXPECT_EQ(signals[5], true);
1257  EXPECT_EQ(signals[6], false);
1258  EXPECT_EQ(signals[7], true);
1259  EXPECT_EQ(signals[8], false);
1260  EXPECT_EQ(signals[9], true);
1261 
1262  // Using __weight__ should work as well,
1263  // the only difference to using _weight__ instead of c is
1264  // in setBranchAddresses which avoids calling makeROOTCompatible
1265  // So we have to check the behaviour using __weight__ as well
1266  general_options.m_weight_variable = "__weight__";
1267  MVA::ROOTDataset y(general_options);
1268 
1269  weights = y.getWeights();
1270  EXPECT_EQ(weights.size(), 10);
1271  EXPECT_FLOAT_EQ(weights[0], 1.2);
1272  EXPECT_FLOAT_EQ(weights[1], 2.2);
1273  EXPECT_FLOAT_EQ(weights[2], 3.2);
1274  EXPECT_FLOAT_EQ(weights[3], 4.2);
1275  EXPECT_FLOAT_EQ(weights[4], 5.2);
1276  EXPECT_FLOAT_EQ(weights[5], 1.2);
1277  EXPECT_FLOAT_EQ(weights[6], 2.2);
1278  EXPECT_FLOAT_EQ(weights[7], 3.2);
1279  EXPECT_FLOAT_EQ(weights[8], 4.2);
1280  EXPECT_FLOAT_EQ(weights[9], 5.2);
1281 
1282  // Check TChain expansion
1283  general_options.m_datafiles = {"datafile*.root"};
1284  {
1285  MVA::ROOTDataset chain_test(general_options);
1286  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1287  }
1288  std::filesystem::copy_file("datafile.root", "datafile3.root");
1289  {
1290  MVA::ROOTDataset chain_test(general_options);
1291  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1292  }
1293  std::filesystem::copy_file("datafile.root", "datafile4.root");
1294  {
1295  MVA::ROOTDataset chain_test(general_options);
1296  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1297  }
1298  // Test m_max_events feature
1299  {
1300  general_options.m_max_events = 10;
1301  MVA::ROOTDataset chain_test(general_options);
1302  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1303  general_options.m_max_events = 0;
1304  }
1305 
1306  // If a file exists with the specified expansion
1307  // the file takes precedence over the expansion
1308  std::filesystem::copy_file("datafile.root", "datafile*.root");
1309  {
1310  general_options.m_max_events = 0;
1311  MVA::ROOTDataset chain_test(general_options);
1312  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1313  }
1314 
1315  }
1316  TEST(DatasetTest, ROOTMultiDatasetDouble)
1317  {
1318 
1320  TFile file("datafile.root", "RECREATE");
1321  file.cd();
1322  TTree tree("tree", "TreeTitle");
1323  double a, b, c, d, e, f, g, v, w = 0;
1324  tree.Branch("a", &a, "a/D");
1325  tree.Branch("b", &b, "b/D");
1326  tree.Branch("c", &c, "c/D");
1327  tree.Branch("d", &d, "d/D");
1328  tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
1329  tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
1330  tree.Branch("g", &g, "g/D");
1331  tree.Branch("__weight__", &c, "__weight__/D");
1332  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
1333  tree.Branch("w", &w, "w/D");
1334 
1335  for (unsigned int i = 0; i < 5; ++i) {
1336  a = i + 1.0;
1337  b = i + 1.1;
1338  c = i + 1.2;
1339  d = i + 1.3;
1340  e = i + 1.4;
1341  f = i + 1.5;
1342  g = float(i % 2 == 0);
1343  w = i + 1.6;
1344  v = i + 1.7;
1345  tree.Fill();
1346  }
1347 
1348  file.Write("tree");
1349 
1350  TFile file2("datafile2.root", "RECREATE");
1351  file2.cd();
1352  TTree tree2("tree", "TreeTitle");
1353  tree2.Branch("a", &a);
1354  tree2.Branch("b", &b);
1355  tree2.Branch("c", &c);
1356  tree2.Branch("d", &d);
1357  tree2.Branch("e__bo__bc", &e);
1358  tree2.Branch("f__bo__bc", &f);
1359  tree2.Branch("g", &g);
1360  tree2.Branch("__weight__", &c);
1361  tree2.Branch("v__bo__bc", &v);
1362  tree2.Branch("w", &w);
1363 
1364  for (unsigned int i = 0; i < 5; ++i) {
1365  a = i + 1.0;
1366  b = i + 1.1;
1367  c = i + 1.2;
1368  d = i + 1.3;
1369  e = i + 1.4;
1370  f = i + 1.5;
1371  g = float(i % 2 == 0);
1372  w = i + 1.6;
1373  v = i + 1.7;
1374  tree2.Fill();
1375  }
1376 
1377  file2.Write("tree");
1378 
1379  MVA::GeneralOptions general_options;
1380  // Both names with and without makeROOTCompatible should work
1381  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1382  general_options.m_spectators = {"w", "v()"};
1383  general_options.m_signal_class = 1;
1384  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1385  general_options.m_treename = "tree";
1386  general_options.m_target_variable = "g";
1387  general_options.m_weight_variable = "c";
1388  MVA::ROOTDataset x(general_options);
1389 
1390  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1391  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1392  EXPECT_EQ(x.getNumberOfEvents(), 10);
1393 
1394  // Should just work
1395  x.loadEvent(0);
1396  EXPECT_EQ(x.m_input.size(), 4);
1397  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1398  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1399  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1400  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1401  EXPECT_EQ(x.m_spectators.size(), 2);
1402  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1403  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1404  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1405  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1406  EXPECT_EQ(x.m_isSignal, true);
1407 
1408  x.loadEvent(5);
1409  EXPECT_EQ(x.m_input.size(), 4);
1410  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1411  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1412  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1413  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1414  EXPECT_EQ(x.m_spectators.size(), 2);
1415  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1416  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1417  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1418  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1419  EXPECT_EQ(x.m_isSignal, true);
1420 
1421  x.loadEvent(1);
1422  EXPECT_EQ(x.m_input.size(), 4);
1423  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1424  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1425  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1426  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1427  EXPECT_EQ(x.m_spectators.size(), 2);
1428  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1429  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1430  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1431  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1432  EXPECT_EQ(x.m_isSignal, false);
1433 
1434  x.loadEvent(6);
1435  EXPECT_EQ(x.m_input.size(), 4);
1436  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1437  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1438  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1439  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1440  EXPECT_EQ(x.m_spectators.size(), 2);
1441  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1442  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1443  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1444  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1445  EXPECT_EQ(x.m_isSignal, false);
1446 
1447  x.loadEvent(2);
1448  EXPECT_EQ(x.m_input.size(), 4);
1449  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1450  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1451  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1452  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1453  EXPECT_EQ(x.m_spectators.size(), 2);
1454  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1455  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1456  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1457  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1458  EXPECT_EQ(x.m_isSignal, true);
1459 
1460  x.loadEvent(7);
1461  EXPECT_EQ(x.m_input.size(), 4);
1462  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1463  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1464  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1465  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1466  EXPECT_EQ(x.m_spectators.size(), 2);
1467  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1468  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1469  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1470  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1471  EXPECT_EQ(x.m_isSignal, true);
1472 
1473  x.loadEvent(3);
1474  EXPECT_EQ(x.m_input.size(), 4);
1475  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1476  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1477  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1478  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1479  EXPECT_EQ(x.m_spectators.size(), 2);
1480  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1481  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1482  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1483  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1484  EXPECT_EQ(x.m_isSignal, false);
1485 
1486  x.loadEvent(8);
1487  EXPECT_EQ(x.m_input.size(), 4);
1488  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1489  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1490  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1491  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1492  EXPECT_EQ(x.m_spectators.size(), 2);
1493  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1494  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1495  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1496  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1497  EXPECT_EQ(x.m_isSignal, false);
1498 
1499  x.loadEvent(4);
1500  EXPECT_EQ(x.m_input.size(), 4);
1501  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1502  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1503  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1504  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1505  EXPECT_EQ(x.m_spectators.size(), 2);
1506  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1507  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1508  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1509  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1510  EXPECT_EQ(x.m_isSignal, true);
1511 
1512  x.loadEvent(9);
1513  EXPECT_EQ(x.m_input.size(), 4);
1514  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1515  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1516  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1517  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1518  EXPECT_EQ(x.m_spectators.size(), 2);
1519  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1520  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1521  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1522  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1523  EXPECT_EQ(x.m_isSignal, true);
1524 
1525  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1526 
1527  auto feature = x.getFeature(1);
1528  EXPECT_EQ(feature.size(), 10);
1529  EXPECT_FLOAT_EQ(feature[0], 1.1);
1530  EXPECT_FLOAT_EQ(feature[1], 2.1);
1531  EXPECT_FLOAT_EQ(feature[2], 3.1);
1532  EXPECT_FLOAT_EQ(feature[3], 4.1);
1533  EXPECT_FLOAT_EQ(feature[4], 5.1);
1534  EXPECT_FLOAT_EQ(feature[5], 1.1);
1535  EXPECT_FLOAT_EQ(feature[6], 2.1);
1536  EXPECT_FLOAT_EQ(feature[7], 3.1);
1537  EXPECT_FLOAT_EQ(feature[8], 4.1);
1538  EXPECT_FLOAT_EQ(feature[9], 5.1);
1539 
1540  // Same result for mother class implementation
1541  feature = x.Dataset::getFeature(1);
1542  EXPECT_EQ(feature.size(), 10);
1543  EXPECT_FLOAT_EQ(feature[0], 1.1);
1544  EXPECT_FLOAT_EQ(feature[1], 2.1);
1545  EXPECT_FLOAT_EQ(feature[2], 3.1);
1546  EXPECT_FLOAT_EQ(feature[3], 4.1);
1547  EXPECT_FLOAT_EQ(feature[4], 5.1);
1548  EXPECT_FLOAT_EQ(feature[5], 1.1);
1549  EXPECT_FLOAT_EQ(feature[6], 2.1);
1550  EXPECT_FLOAT_EQ(feature[7], 3.1);
1551  EXPECT_FLOAT_EQ(feature[8], 4.1);
1552  EXPECT_FLOAT_EQ(feature[9], 5.1);
1553 
1554  auto spectator = x.getSpectator(1);
1555  EXPECT_EQ(spectator.size(), 10);
1556  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1557  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1558  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1559  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1560  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1561  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1562  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1563  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1564  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1565  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1566 
1567  // Same result for mother class implementation
1568  spectator = x.Dataset::getSpectator(1);
1569  EXPECT_EQ(spectator.size(), 10);
1570  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1571  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1572  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1573  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1574  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1575  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1576  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1577  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1578  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1579  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1580 
1581  auto weights = x.getWeights();
1582  EXPECT_EQ(weights.size(), 10);
1583  EXPECT_FLOAT_EQ(weights[0], 1.2);
1584  EXPECT_FLOAT_EQ(weights[1], 2.2);
1585  EXPECT_FLOAT_EQ(weights[2], 3.2);
1586  EXPECT_FLOAT_EQ(weights[3], 4.2);
1587  EXPECT_FLOAT_EQ(weights[4], 5.2);
1588  EXPECT_FLOAT_EQ(weights[5], 1.2);
1589  EXPECT_FLOAT_EQ(weights[6], 2.2);
1590  EXPECT_FLOAT_EQ(weights[7], 3.2);
1591  EXPECT_FLOAT_EQ(weights[8], 4.2);
1592  EXPECT_FLOAT_EQ(weights[9], 5.2);
1593 
1594  // Same result for mother class implementation
1595  weights = x.Dataset::getWeights();
1596  EXPECT_EQ(weights.size(), 10);
1597  EXPECT_FLOAT_EQ(weights[0], 1.2);
1598  EXPECT_FLOAT_EQ(weights[1], 2.2);
1599  EXPECT_FLOAT_EQ(weights[2], 3.2);
1600  EXPECT_FLOAT_EQ(weights[3], 4.2);
1601  EXPECT_FLOAT_EQ(weights[4], 5.2);
1602  EXPECT_FLOAT_EQ(weights[5], 1.2);
1603  EXPECT_FLOAT_EQ(weights[6], 2.2);
1604  EXPECT_FLOAT_EQ(weights[7], 3.2);
1605  EXPECT_FLOAT_EQ(weights[8], 4.2);
1606  EXPECT_FLOAT_EQ(weights[9], 5.2);
1607 
1608  auto targets = x.getTargets();
1609  EXPECT_EQ(targets.size(), 10);
1610  EXPECT_FLOAT_EQ(targets[0], 1.0);
1611  EXPECT_FLOAT_EQ(targets[1], 0.0);
1612  EXPECT_FLOAT_EQ(targets[2], 1.0);
1613  EXPECT_FLOAT_EQ(targets[3], 0.0);
1614  EXPECT_FLOAT_EQ(targets[4], 1.0);
1615  EXPECT_FLOAT_EQ(targets[5], 1.0);
1616  EXPECT_FLOAT_EQ(targets[6], 0.0);
1617  EXPECT_FLOAT_EQ(targets[7], 1.0);
1618  EXPECT_FLOAT_EQ(targets[8], 0.0);
1619  EXPECT_FLOAT_EQ(targets[9], 1.0);
1620 
1621  auto signals = x.getSignals();
1622  EXPECT_EQ(signals.size(), 10);
1623  EXPECT_EQ(signals[0], true);
1624  EXPECT_EQ(signals[1], false);
1625  EXPECT_EQ(signals[2], true);
1626  EXPECT_EQ(signals[3], false);
1627  EXPECT_EQ(signals[4], true);
1628  EXPECT_EQ(signals[5], true);
1629  EXPECT_EQ(signals[6], false);
1630  EXPECT_EQ(signals[7], true);
1631  EXPECT_EQ(signals[8], false);
1632  EXPECT_EQ(signals[9], true);
1633 
1634  // Using __weight__ should work as well,
1635  // the only difference to using _weight__ instead of g is
1636  // in setBranchAddresses which avoids calling makeROOTCompatible
1637  // So we have to check the behaviour using __weight__ as well
1638  general_options.m_weight_variable = "__weight__";
1639  MVA::ROOTDataset y(general_options);
1640 
1641  weights = y.getWeights();
1642  EXPECT_EQ(weights.size(), 10);
1643  EXPECT_FLOAT_EQ(weights[0], 1.2);
1644  EXPECT_FLOAT_EQ(weights[1], 2.2);
1645  EXPECT_FLOAT_EQ(weights[2], 3.2);
1646  EXPECT_FLOAT_EQ(weights[3], 4.2);
1647  EXPECT_FLOAT_EQ(weights[4], 5.2);
1648  EXPECT_FLOAT_EQ(weights[5], 1.2);
1649  EXPECT_FLOAT_EQ(weights[6], 2.2);
1650  EXPECT_FLOAT_EQ(weights[7], 3.2);
1651  EXPECT_FLOAT_EQ(weights[8], 4.2);
1652  EXPECT_FLOAT_EQ(weights[9], 5.2);
1653 
1654  // Check TChain expansion
1655  general_options.m_datafiles = {"datafile*.root"};
1656  {
1657  MVA::ROOTDataset chain_test(general_options);
1658  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1659  }
1660  std::filesystem::copy_file("datafile.root", "datafile3.root");
1661  {
1662  MVA::ROOTDataset chain_test(general_options);
1663  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1664  }
1665  std::filesystem::copy_file("datafile.root", "datafile4.root");
1666  {
1667  MVA::ROOTDataset chain_test(general_options);
1668  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1669  }
1670  // Test m_max_events feature
1671  {
1672  general_options.m_max_events = 10;
1673  MVA::ROOTDataset chain_test(general_options);
1674  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1675  general_options.m_max_events = 0;
1676  }
1677 
1678  // If a file exists with the specified expansion
1679  // the file takes precedence over the expansion
1680  std::filesystem::copy_file("datafile.root", "datafile*.root");
1681  {
1682  general_options.m_max_events = 0;
1683  MVA::ROOTDataset chain_test(general_options);
1684  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1685  }
1686  }
1687 }
Wraps two other Datasets, one containing signal, the other background events. Used by the reweighting ...
Definition: Dataset.h:294
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
Wraps the data of a multiple event into a Dataset.
Definition: Dataset.h:186
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
Definition: Dataset.h:349
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:135
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:234
changes working directory into a newly created directory, and removes it (and contents) on destructio...
Definition: TestHelpers.h:66
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.