Belle II Software  release-06-02-00
test_Dataset.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/interface/Dataset.h>
10 #include <framework/utilities/TestHelpers.h>
11 
12 #include <boost/filesystem/operations.hpp>
13 
14 #include <gtest/gtest.h>
15 
16 #include <fstream>
17 #include <numeric>
18 
19 using namespace Belle2;
20 
21 namespace {
22 
23  class TestDataset : public MVA::Dataset {
24  public:
25  explicit TestDataset(MVA::GeneralOptions& general_options) : MVA::Dataset(general_options)
26  {
27  m_input = {1.0, 2.0, 3.0, 4.0, 5.0};
28  m_target = 3.0;
29  m_isSignal = true;
30  m_weight = -3.0;
31  }
32 
33  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 5; }
34  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 2; }
35  [[nodiscard]] unsigned int getNumberOfEvents() const override { return 20; }
36  void loadEvent(unsigned int iEvent) override
37  {
38  auto f = static_cast<float>(iEvent);
39  m_input = {f + 1, f + 2, f + 3, f + 4, f + 5};
40  m_spectators = {f + 6, f + 7};
41  };
42  float getSignalFraction() override { return 0.1; };
43  std::vector<float> getFeature(unsigned int iFeature) override
44  {
45  std::vector<float> a(20, 0.0);
46  std::iota(a.begin(), a.end(), iFeature + 1);
47  return a;
48  }
49  std::vector<float> getSpectator(unsigned int iSpectator) override
50  {
51  std::vector<float> a(20, 0.0);
52  std::iota(a.begin(), a.end(), iSpectator + 6);
53  return a;
54  }
55 
56  };
57 
58  TEST(DatasetTest, SingleDataset)
59  {
60 
61  MVA::GeneralOptions general_options;
62  general_options.m_variables = {"a", "b", "c"};
63  general_options.m_spectators = {"e", "f"};
64  general_options.m_signal_class = 2;
65  std::vector<float> input = {1.0, 2.0, 3.0};
66  std::vector<float> spectators = {4.0, 5.0};
67 
68  MVA::SingleDataset x(general_options, input, 2.0, spectators);
69 
70  EXPECT_EQ(x.getNumberOfFeatures(), 3);
71  EXPECT_EQ(x.getNumberOfSpectators(), 2);
72  EXPECT_EQ(x.getNumberOfEvents(), 1);
73 
74  EXPECT_EQ(x.getFeatureIndex("a"), 0);
75  EXPECT_EQ(x.getFeatureIndex("b"), 1);
76  EXPECT_EQ(x.getFeatureIndex("c"), 2);
77  EXPECT_B2ERROR(x.getFeatureIndex("bla"));
78 
79  EXPECT_EQ(x.getSpectatorIndex("e"), 0);
80  EXPECT_EQ(x.getSpectatorIndex("f"), 1);
81  EXPECT_B2ERROR(x.getSpectatorIndex("bla"));
82 
83  // Should just work
84  x.loadEvent(0);
85 
86  EXPECT_EQ(x.m_input.size(), 3);
87  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
88  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
89  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
90 
91  EXPECT_EQ(x.m_spectators.size(), 2);
92  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.0);
93  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.0);
94 
95  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
96  EXPECT_FLOAT_EQ(x.m_target, 2.0);
97  EXPECT_EQ(x.m_isSignal, true);
98 
99  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1.0);
100 
101  auto feature = x.getFeature(1);
102  EXPECT_EQ(feature.size(), 1);
103  EXPECT_FLOAT_EQ(feature[0], 2.0);
104 
105  auto spectator = x.getSpectator(1);
106  EXPECT_EQ(spectator.size(), 1);
107  EXPECT_FLOAT_EQ(spectator[0], 5.0);
108 
109  // Same result for mother class implementation
110  feature = x.Dataset::getFeature(1);
111  EXPECT_EQ(feature.size(), 1);
112  EXPECT_FLOAT_EQ(feature[0], 2.0);
113 
114  }
115 
116  TEST(DatasetTest, MultiDataset)
117  {
118 
119  MVA::GeneralOptions general_options;
120  general_options.m_variables = {"a", "b", "c"};
121  general_options.m_spectators = {"e", "f"};
122  general_options.m_signal_class = 2;
123  std::vector<std::vector<float>> matrix = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}};
124  std::vector<std::vector<float>> spectator_matrix = {{12.0, 13.0}, {15.0, 16.0}, {18.0, 19.0}};
125  std::vector<float> targets = {2.0, 0.0, 2.0};
126  std::vector<float> weights = {1.0, 2.0, 3.0};
127 
128  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, {1.0}, weights));
129 
130  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, targets, {1.0}));
131 
132  MVA::MultiDataset x(general_options, matrix, spectator_matrix, targets, weights);
133 
134  EXPECT_EQ(x.getNumberOfFeatures(), 3);
135  EXPECT_EQ(x.getNumberOfEvents(), 3);
136 
137  // Should just work
138  x.loadEvent(0);
139 
140  EXPECT_EQ(x.m_input.size(), 3);
141  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
142  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
143  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
144 
145  EXPECT_EQ(x.m_spectators.size(), 2);
146  EXPECT_FLOAT_EQ(x.m_spectators[0], 12.0);
147  EXPECT_FLOAT_EQ(x.m_spectators[1], 13.0);
148 
149  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
150  EXPECT_FLOAT_EQ(x.m_target, 2.0);
151  EXPECT_EQ(x.m_isSignal, true);
152 
153  // Should just work
154  x.loadEvent(1);
155 
156  EXPECT_EQ(x.m_input.size(), 3);
157  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
158  EXPECT_FLOAT_EQ(x.m_input[1], 5.0);
159  EXPECT_FLOAT_EQ(x.m_input[2], 6.0);
160 
161  EXPECT_EQ(x.m_spectators.size(), 2);
162  EXPECT_FLOAT_EQ(x.m_spectators[0], 15.0);
163  EXPECT_FLOAT_EQ(x.m_spectators[1], 16.0);
164 
165  EXPECT_FLOAT_EQ(x.m_weight, 2.0);
166  EXPECT_FLOAT_EQ(x.m_target, 0.0);
167  EXPECT_EQ(x.m_isSignal, false);
168 
169  // Should just work
170  x.loadEvent(2);
171 
172  EXPECT_EQ(x.m_input.size(), 3);
173  EXPECT_FLOAT_EQ(x.m_input[0], 7.0);
174  EXPECT_FLOAT_EQ(x.m_input[1], 8.0);
175  EXPECT_FLOAT_EQ(x.m_input[2], 9.0);
176 
177  EXPECT_EQ(x.m_spectators.size(), 2);
178  EXPECT_FLOAT_EQ(x.m_spectators[0], 18.0);
179  EXPECT_FLOAT_EQ(x.m_spectators[1], 19.0);
180 
181  EXPECT_FLOAT_EQ(x.m_weight, 3.0);
182  EXPECT_FLOAT_EQ(x.m_target, 2.0);
183  EXPECT_EQ(x.m_isSignal, true);
184 
185  EXPECT_FLOAT_EQ(x.getSignalFraction(), 4.0 / 6.0);
186 
187  auto feature = x.getFeature(1);
188  EXPECT_EQ(feature.size(), 3);
189  EXPECT_FLOAT_EQ(feature[0], 2.0);
190  EXPECT_FLOAT_EQ(feature[1], 5.0);
191  EXPECT_FLOAT_EQ(feature[2], 8.0);
192 
193  auto spectator = x.getSpectator(1);
194  EXPECT_EQ(spectator.size(), 3);
195  EXPECT_FLOAT_EQ(spectator[0], 13.0);
196  EXPECT_FLOAT_EQ(spectator[1], 16.0);
197  EXPECT_FLOAT_EQ(spectator[2], 19.0);
198 
199 
200  }
201 
202  TEST(DatasetTest, SubDataset)
203  {
204 
205  MVA::GeneralOptions general_options;
206  general_options.m_signal_class = 3;
207  general_options.m_variables = {"a", "b", "c", "d", "e"};
208  general_options.m_spectators = {"f", "g"};
209  TestDataset test_dataset(general_options);
210 
211  general_options.m_variables = {"a", "d", "e"};
212  general_options.m_spectators = {"g"};
213  std::vector<bool> events = {true, false, true, false, true, false, true, false, true, false,
214  true, false, true, false, true, false, true, false, true, false
215  };
216  MVA::SubDataset x(general_options, events, test_dataset);
217 
218  EXPECT_EQ(x.getNumberOfFeatures(), 3);
219  EXPECT_EQ(x.getNumberOfEvents(), 10);
220 
221  // Should just work
222  x.loadEvent(0);
223 
224  EXPECT_EQ(x.m_input.size(), 3);
225  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
226  EXPECT_FLOAT_EQ(x.m_input[1], 4.0);
227  EXPECT_FLOAT_EQ(x.m_input[2], 5.0);
228 
229  EXPECT_EQ(x.m_spectators.size(), 1);
230  EXPECT_FLOAT_EQ(x.m_spectators[0], 7.0);
231 
232  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
233  EXPECT_FLOAT_EQ(x.m_target, 3.0);
234  EXPECT_EQ(x.m_isSignal, true);
235 
236  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1);
237 
238  auto feature = x.getFeature(1);
239  EXPECT_EQ(feature.size(), 10);
240  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
241  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
242  };
243 
244  auto spectator = x.getSpectator(0);
245  EXPECT_EQ(spectator.size(), 10);
246  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
247  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
248  };
249 
250  // Same result for mother class implementation
251  feature = x.Dataset::getFeature(1);
252  EXPECT_EQ(feature.size(), 10);
253  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
254  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
255  };
256 
257  spectator = x.Dataset::getSpectator(0);
258  EXPECT_EQ(spectator.size(), 10);
259  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
260  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
261  };
262 
263  // Test without event indices
264  MVA::SubDataset y(general_options, {}, test_dataset);
265  feature = y.getFeature(1);
266  EXPECT_EQ(feature.size(), 20);
267  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
268  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
269  };
270 
271  spectator = y.getSpectator(0);
272  EXPECT_EQ(spectator.size(), 20);
273  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
274  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
275  };
276 
277  // Same result for mother class implementation
278  feature = y.Dataset::getFeature(1);
279  EXPECT_EQ(feature.size(), 20);
280  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
281  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
282  };
283 
284  spectator = y.Dataset::getSpectator(0);
285  EXPECT_EQ(spectator.size(), 20);
286  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
287  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
288  };
289 
290  general_options.m_variables = {"a", "d", "e", "DOESNOTEXIST"};
291  try {
292  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
293  } catch (...) {
294 
295  }
296  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
297 
298  general_options.m_variables = {"a", "d", "e"};
299  general_options.m_spectators = {"DOESNOTEXIST"};
300  try {
301  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
302  } catch (...) {
303 
304  }
305  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
306 
307  }
308 
309  TEST(DatasetTest, CombinedDataset)
310  {
311 
312  MVA::GeneralOptions general_options;
313  general_options.m_signal_class = 1;
314  general_options.m_variables = {"a", "b", "c", "d", "e"};
315  general_options.m_spectators = {"f", "g"};
316  TestDataset signal_dataset(general_options);
317  TestDataset bckgrd_dataset(general_options);
318 
319  MVA::CombinedDataset x(general_options, signal_dataset, bckgrd_dataset);
320 
321  EXPECT_EQ(x.getNumberOfFeatures(), 5);
322  EXPECT_EQ(x.getNumberOfEvents(), 40);
323 
324  // Should just work
325  x.loadEvent(0);
326 
327  EXPECT_EQ(x.m_input.size(), 5);
328  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
329  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
330  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
331  EXPECT_FLOAT_EQ(x.m_input[3], 4.0);
332  EXPECT_FLOAT_EQ(x.m_input[4], 5.0);
333 
334  EXPECT_EQ(x.m_spectators.size(), 2);
335  EXPECT_FLOAT_EQ(x.m_spectators[0], 6.0);
336  EXPECT_FLOAT_EQ(x.m_spectators[1], 7.0);
337 
338  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
339  EXPECT_FLOAT_EQ(x.m_target, 1.0);
340  EXPECT_EQ(x.m_isSignal, true);
341 
342  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.5);
343 
344  auto feature = x.getFeature(1);
345  EXPECT_EQ(feature.size(), 40);
346  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
347  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
348  };
349 
350  auto spectator = x.getSpectator(0);
351  EXPECT_EQ(spectator.size(), 40);
352  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
353  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
354  };
355 
356  // Same result for mother class implementation
357  feature = x.Dataset::getFeature(1);
358  EXPECT_EQ(feature.size(), 40);
359  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
360  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
361  };
362 
363  spectator = x.Dataset::getSpectator(0);
364  EXPECT_EQ(spectator.size(), 40);
365  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
366  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
367  };
368 
369  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
370  x.loadEvent(iEvent);
371  EXPECT_EQ(x.m_isSignal, iEvent < 20);
372  EXPECT_FLOAT_EQ(x.m_target, iEvent < 20 ? 1.0 : 0.0);
373  }
374 
375  }
376 
377  TEST(DatasetTest, ROOTDataset)
378  {
379 
381  TFile file("datafile.root", "RECREATE");
382  file.cd();
383  TTree tree("tree", "TreeTitle");
384  float a, b, c, d, e, f, g, v, w = 0;
385  tree.Branch("a", &a);
386  tree.Branch("b", &b);
387  tree.Branch("c", &c);
388  tree.Branch("d", &d);
389  tree.Branch("e__bo__bc", &e);
390  tree.Branch("f__bo__bc", &f);
391  tree.Branch("g", &g);
392  tree.Branch("__weight__", &c);
393  tree.Branch("v__bo__bc", &v);
394  tree.Branch("w", &w);
395 
396  for (unsigned int i = 0; i < 5; ++i) {
397  a = i + 1.0;
398  b = i + 1.1;
399  c = i + 1.2;
400  d = i + 1.3;
401  e = i + 1.4;
402  f = i + 1.5;
403  g = float(i % 2 == 0);
404  w = i + 1.6;
405  v = i + 1.7;
406  tree.Fill();
407  }
408 
409  file.Write("tree");
410 
411  MVA::GeneralOptions general_options;
412  // Both names with and without makeROOTCompatible should work
413  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
414  general_options.m_spectators = {"w", "v()"};
415  general_options.m_signal_class = 1;
416  general_options.m_datafiles = {"datafile.root"};
417  general_options.m_treename = "tree";
418  general_options.m_target_variable = "g";
419  general_options.m_weight_variable = "c";
420  MVA::ROOTDataset x(general_options);
421 
422  EXPECT_EQ(x.getNumberOfFeatures(), 4);
423  EXPECT_EQ(x.getNumberOfSpectators(), 2);
424  EXPECT_EQ(x.getNumberOfEvents(), 5);
425 
426  // Should just work
427  x.loadEvent(0);
428  EXPECT_EQ(x.m_input.size(), 4);
429  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
430  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
431  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
432  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
433  EXPECT_EQ(x.m_spectators.size(), 2);
434  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
435  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
436  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
437  EXPECT_FLOAT_EQ(x.m_target, 1.0);
438  EXPECT_EQ(x.m_isSignal, true);
439 
440  x.loadEvent(1);
441  EXPECT_EQ(x.m_input.size(), 4);
442  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
443  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
444  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
445  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
446  EXPECT_EQ(x.m_spectators.size(), 2);
447  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
448  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
449  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
450  EXPECT_FLOAT_EQ(x.m_target, 0.0);
451  EXPECT_EQ(x.m_isSignal, false);
452 
453  x.loadEvent(2);
454  EXPECT_EQ(x.m_input.size(), 4);
455  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
456  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
457  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
458  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
459  EXPECT_EQ(x.m_spectators.size(), 2);
460  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
461  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
462  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
463  EXPECT_FLOAT_EQ(x.m_target, 1.0);
464  EXPECT_EQ(x.m_isSignal, true);
465 
466  x.loadEvent(3);
467  EXPECT_EQ(x.m_input.size(), 4);
468  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
469  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
470  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
471  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
472  EXPECT_EQ(x.m_spectators.size(), 2);
473  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
474  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
475  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
476  EXPECT_FLOAT_EQ(x.m_target, 0.0);
477  EXPECT_EQ(x.m_isSignal, false);
478 
479  x.loadEvent(4);
480  EXPECT_EQ(x.m_input.size(), 4);
481  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
482  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
483  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
484  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
485  EXPECT_EQ(x.m_spectators.size(), 2);
486  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
487  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
488  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
489  EXPECT_FLOAT_EQ(x.m_target, 1.0);
490  EXPECT_EQ(x.m_isSignal, true);
491 
492  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
493 
494  auto feature = x.getFeature(1);
495  EXPECT_EQ(feature.size(), 5);
496  EXPECT_FLOAT_EQ(feature[0], 1.1);
497  EXPECT_FLOAT_EQ(feature[1], 2.1);
498  EXPECT_FLOAT_EQ(feature[2], 3.1);
499  EXPECT_FLOAT_EQ(feature[3], 4.1);
500  EXPECT_FLOAT_EQ(feature[4], 5.1);
501 
502  // Same result for mother class implementation
503  feature = x.Dataset::getFeature(1);
504  EXPECT_EQ(feature.size(), 5);
505  EXPECT_FLOAT_EQ(feature[0], 1.1);
506  EXPECT_FLOAT_EQ(feature[1], 2.1);
507  EXPECT_FLOAT_EQ(feature[2], 3.1);
508  EXPECT_FLOAT_EQ(feature[3], 4.1);
509  EXPECT_FLOAT_EQ(feature[4], 5.1);
510 
511  auto spectator = x.getSpectator(1);
512  EXPECT_EQ(spectator.size(), 5);
513  EXPECT_FLOAT_EQ(spectator[0], 1.7);
514  EXPECT_FLOAT_EQ(spectator[1], 2.7);
515  EXPECT_FLOAT_EQ(spectator[2], 3.7);
516  EXPECT_FLOAT_EQ(spectator[3], 4.7);
517  EXPECT_FLOAT_EQ(spectator[4], 5.7);
518 
519  // Same result for mother class implementation
520  spectator = x.Dataset::getSpectator(1);
521  EXPECT_EQ(spectator.size(), 5);
522  EXPECT_FLOAT_EQ(spectator[0], 1.7);
523  EXPECT_FLOAT_EQ(spectator[1], 2.7);
524  EXPECT_FLOAT_EQ(spectator[2], 3.7);
525  EXPECT_FLOAT_EQ(spectator[3], 4.7);
526  EXPECT_FLOAT_EQ(spectator[4], 5.7);
527 
528  auto weights = x.getWeights();
529  EXPECT_EQ(weights.size(), 5);
530  EXPECT_FLOAT_EQ(weights[0], 1.2);
531  EXPECT_FLOAT_EQ(weights[1], 2.2);
532  EXPECT_FLOAT_EQ(weights[2], 3.2);
533  EXPECT_FLOAT_EQ(weights[3], 4.2);
534  EXPECT_FLOAT_EQ(weights[4], 5.2);
535 
536  // Same result for mother class implementation
537  weights = x.Dataset::getWeights();
538  EXPECT_EQ(weights.size(), 5);
539  EXPECT_FLOAT_EQ(weights[0], 1.2);
540  EXPECT_FLOAT_EQ(weights[1], 2.2);
541  EXPECT_FLOAT_EQ(weights[2], 3.2);
542  EXPECT_FLOAT_EQ(weights[3], 4.2);
543  EXPECT_FLOAT_EQ(weights[4], 5.2);
544 
545  auto targets = x.getTargets();
546  EXPECT_EQ(targets.size(), 5);
547  EXPECT_FLOAT_EQ(targets[0], 1.0);
548  EXPECT_FLOAT_EQ(targets[1], 0.0);
549  EXPECT_FLOAT_EQ(targets[2], 1.0);
550  EXPECT_FLOAT_EQ(targets[3], 0.0);
551  EXPECT_FLOAT_EQ(targets[4], 1.0);
552 
553  auto signals = x.getSignals();
554  EXPECT_EQ(signals.size(), 5);
555  EXPECT_EQ(signals[0], true);
556  EXPECT_EQ(signals[1], false);
557  EXPECT_EQ(signals[2], true);
558  EXPECT_EQ(signals[3], false);
559  EXPECT_EQ(signals[4], true);
560 
561  // Using __weight__ should work as well,
562  // the only difference to using _weight__ instead of g is
563  // in setBranchAddresses which avoids calling makeROOTCompatible
564  // So we have to check the behaviour using __weight__ as well
565  general_options.m_weight_variable = "__weight__";
566  MVA::ROOTDataset y(general_options);
567 
568  weights = y.getWeights();
569  EXPECT_EQ(weights.size(), 5);
570  EXPECT_FLOAT_EQ(weights[0], 1.2);
571  EXPECT_FLOAT_EQ(weights[1], 2.2);
572  EXPECT_FLOAT_EQ(weights[2], 3.2);
573  EXPECT_FLOAT_EQ(weights[3], 4.2);
574  EXPECT_FLOAT_EQ(weights[4], 5.2);
575 
576  // Check TChain expansion
577  general_options.m_datafiles = {"datafile*.root"};
578  {
579  MVA::ROOTDataset chain_test(general_options);
580  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
581  }
582  boost::filesystem::copy_file("datafile.root", "datafile2.root");
583  {
584  MVA::ROOTDataset chain_test(general_options);
585  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
586  }
587  boost::filesystem::copy_file("datafile.root", "datafile3.root");
588  {
589  MVA::ROOTDataset chain_test(general_options);
590  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
591  }
592  // Test m_max_events feature
593  {
594  general_options.m_max_events = 10;
595  MVA::ROOTDataset chain_test(general_options);
596  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
597  general_options.m_max_events = 0;
598  }
599 
600  // Check for missing tree
601  general_options.m_treename = "missing tree";
602  try {
603  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
604  } catch (...) {
605 
606  }
607  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
608 
609  // Check for missing branch
610  general_options.m_treename = "tree";
611  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
612  try {
613  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
614  } catch (...) {
615 
616  }
617  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
618 
619  // Check for missing branch
620  general_options.m_treename = "tree";
621  general_options.m_variables = {"a", "b", "e", "f"};
622  general_options.m_spectators = {"missing branch"};
623  try {
624  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
625  } catch (...) {
626 
627  }
628  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
629 
630  // Check for missing file
631  general_options.m_spectators = {};
632  general_options.m_datafiles = {"DOESNOTEXIST.root"};
633  general_options.m_treename = "tree";
634  try {
635  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
636  } catch (...) {
637 
638  }
639  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
640 
641  // Check for invalid file
642  general_options.m_datafiles = {"ISNotAValidROOTFile"};
643  general_options.m_treename = "tree";
644 
645  {
646  std::ofstream(general_options.m_datafiles[0]);
647  }
648  EXPECT_TRUE(boost::filesystem::exists(general_options.m_datafiles[0]));
649 
650  try {
651  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
652  } catch (...) {
653 
654  }
655  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
656  }
657 
658 
659  TEST(DatasetTest, ROOTDatasetDouble)
660  {
661 
663  TFile file("datafile.root", "RECREATE");
664  file.cd();
665  TTree tree("tree", "TreeTitle");
666  double a, b, c, d, e, f, g, v, w = 0;
667  tree.Branch("a", &a, "a/D");
668  tree.Branch("b", &b, "b/D");
669  tree.Branch("c", &c, "c/D");
670  tree.Branch("d", &d, "d/D");
671  tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
672  tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
673  tree.Branch("g", &g, "g/D");
674  tree.Branch("__weight__", &c, "__weight__/D");
675  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
676  tree.Branch("w", &w, "w/D");
677 
678  for (unsigned int i = 0; i < 5; ++i) {
679  a = i + 1.0;
680  b = i + 1.1;
681  c = i + 1.2;
682  d = i + 1.3;
683  e = i + 1.4;
684  f = i + 1.5;
685  g = float(i % 2 == 0);
686  w = i + 1.6;
687  v = i + 1.7;
688  tree.Fill();
689  }
690 
691  file.Write("tree");
692 
693  MVA::GeneralOptions general_options;
694  // Both names with and without makeROOTCompatible should work
695  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
696  general_options.m_spectators = {"w", "v()"};
697  general_options.m_signal_class = 1;
698  general_options.m_datafiles = {"datafile.root"};
699  general_options.m_treename = "tree";
700  general_options.m_target_variable = "g";
701  general_options.m_weight_variable = "c";
702  MVA::ROOTDataset x(general_options);
703 
704  EXPECT_EQ(x.getNumberOfFeatures(), 4);
705  EXPECT_EQ(x.getNumberOfSpectators(), 2);
706  EXPECT_EQ(x.getNumberOfEvents(), 5);
707 
708  // Should just work
709  x.loadEvent(0);
710  EXPECT_EQ(x.m_input.size(), 4);
711  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
712  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
713  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
714  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
715  EXPECT_EQ(x.m_spectators.size(), 2);
716  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
717  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
718  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
719  EXPECT_FLOAT_EQ(x.m_target, 1.0);
720  EXPECT_EQ(x.m_isSignal, true);
721 
722  x.loadEvent(1);
723  EXPECT_EQ(x.m_input.size(), 4);
724  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
725  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
726  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
727  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
728  EXPECT_EQ(x.m_spectators.size(), 2);
729  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
730  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
731  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
732  EXPECT_FLOAT_EQ(x.m_target, 0.0);
733  EXPECT_EQ(x.m_isSignal, false);
734 
735  x.loadEvent(2);
736  EXPECT_EQ(x.m_input.size(), 4);
737  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
738  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
739  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
740  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
741  EXPECT_EQ(x.m_spectators.size(), 2);
742  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
743  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
744  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
745  EXPECT_FLOAT_EQ(x.m_target, 1.0);
746  EXPECT_EQ(x.m_isSignal, true);
747 
748  x.loadEvent(3);
749  EXPECT_EQ(x.m_input.size(), 4);
750  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
751  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
752  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
753  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
754  EXPECT_EQ(x.m_spectators.size(), 2);
755  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
756  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
757  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
758  EXPECT_FLOAT_EQ(x.m_target, 0.0);
759  EXPECT_EQ(x.m_isSignal, false);
760 
761  x.loadEvent(4);
762  EXPECT_EQ(x.m_input.size(), 4);
763  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
764  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
765  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
766  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
767  EXPECT_EQ(x.m_spectators.size(), 2);
768  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
769  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
770  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
771  EXPECT_FLOAT_EQ(x.m_target, 1.0);
772  EXPECT_EQ(x.m_isSignal, true);
773 
774  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
775 
776  auto feature = x.getFeature(1);
777  EXPECT_EQ(feature.size(), 5);
778  EXPECT_FLOAT_EQ(feature[0], 1.1);
779  EXPECT_FLOAT_EQ(feature[1], 2.1);
780  EXPECT_FLOAT_EQ(feature[2], 3.1);
781  EXPECT_FLOAT_EQ(feature[3], 4.1);
782  EXPECT_FLOAT_EQ(feature[4], 5.1);
783 
784  // Same result for mother class implementation
785  feature = x.Dataset::getFeature(1);
786  EXPECT_EQ(feature.size(), 5);
787  EXPECT_FLOAT_EQ(feature[0], 1.1);
788  EXPECT_FLOAT_EQ(feature[1], 2.1);
789  EXPECT_FLOAT_EQ(feature[2], 3.1);
790  EXPECT_FLOAT_EQ(feature[3], 4.1);
791  EXPECT_FLOAT_EQ(feature[4], 5.1);
792 
793  auto spectator = x.getSpectator(1);
794  EXPECT_EQ(spectator.size(), 5);
795  EXPECT_FLOAT_EQ(spectator[0], 1.7);
796  EXPECT_FLOAT_EQ(spectator[1], 2.7);
797  EXPECT_FLOAT_EQ(spectator[2], 3.7);
798  EXPECT_FLOAT_EQ(spectator[3], 4.7);
799  EXPECT_FLOAT_EQ(spectator[4], 5.7);
800 
801  // Same result for mother class implementation
802  spectator = x.Dataset::getSpectator(1);
803  EXPECT_EQ(spectator.size(), 5);
804  EXPECT_FLOAT_EQ(spectator[0], 1.7);
805  EXPECT_FLOAT_EQ(spectator[1], 2.7);
806  EXPECT_FLOAT_EQ(spectator[2], 3.7);
807  EXPECT_FLOAT_EQ(spectator[3], 4.7);
808  EXPECT_FLOAT_EQ(spectator[4], 5.7);
809 
810  auto weights = x.getWeights();
811  EXPECT_EQ(weights.size(), 5);
812  EXPECT_FLOAT_EQ(weights[0], 1.2);
813  EXPECT_FLOAT_EQ(weights[1], 2.2);
814  EXPECT_FLOAT_EQ(weights[2], 3.2);
815  EXPECT_FLOAT_EQ(weights[3], 4.2);
816  EXPECT_FLOAT_EQ(weights[4], 5.2);
817 
818  // Same result for mother class implementation
819  weights = x.Dataset::getWeights();
820  EXPECT_EQ(weights.size(), 5);
821  EXPECT_FLOAT_EQ(weights[0], 1.2);
822  EXPECT_FLOAT_EQ(weights[1], 2.2);
823  EXPECT_FLOAT_EQ(weights[2], 3.2);
824  EXPECT_FLOAT_EQ(weights[3], 4.2);
825  EXPECT_FLOAT_EQ(weights[4], 5.2);
826 
827  auto targets = x.getTargets();
828  EXPECT_EQ(targets.size(), 5);
829  EXPECT_FLOAT_EQ(targets[0], 1.0);
830  EXPECT_FLOAT_EQ(targets[1], 0.0);
831  EXPECT_FLOAT_EQ(targets[2], 1.0);
832  EXPECT_FLOAT_EQ(targets[3], 0.0);
833  EXPECT_FLOAT_EQ(targets[4], 1.0);
834 
835  auto signals = x.getSignals();
836  EXPECT_EQ(signals.size(), 5);
837  EXPECT_EQ(signals[0], true);
838  EXPECT_EQ(signals[1], false);
839  EXPECT_EQ(signals[2], true);
840  EXPECT_EQ(signals[3], false);
841  EXPECT_EQ(signals[4], true);
842 
843  // Using __weight__ should work as well,
844  // the only difference to using _weight__ instead of g is
845  // in setBranchAddresses which avoids calling makeROOTCompatible
846  // So we have to check the behaviour using __weight__ as well
847  general_options.m_weight_variable = "__weight__";
848  MVA::ROOTDataset y(general_options);
849 
850  weights = y.getWeights();
851  EXPECT_EQ(weights.size(), 5);
852  EXPECT_FLOAT_EQ(weights[0], 1.2);
853  EXPECT_FLOAT_EQ(weights[1], 2.2);
854  EXPECT_FLOAT_EQ(weights[2], 3.2);
855  EXPECT_FLOAT_EQ(weights[3], 4.2);
856  EXPECT_FLOAT_EQ(weights[4], 5.2);
857 
858  // Check TChain expansion
859  general_options.m_datafiles = {"datafile*.root"};
860  {
861  MVA::ROOTDataset chain_test(general_options);
862  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
863  }
864  boost::filesystem::copy_file("datafile.root", "datafile2.root");
865  {
866  MVA::ROOTDataset chain_test(general_options);
867  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
868  }
869  boost::filesystem::copy_file("datafile.root", "datafile3.root");
870  {
871  MVA::ROOTDataset chain_test(general_options);
872  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
873  }
874  // Test m_max_events feature
875  {
876  general_options.m_max_events = 10;
877  MVA::ROOTDataset chain_test(general_options);
878  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
879  general_options.m_max_events = 0;
880  }
881 
882  // Check for missing tree
883  general_options.m_treename = "missing tree";
884  try {
885  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
886  } catch (...) {
887 
888  }
889  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
890 
891  // Check for missing branch
892  general_options.m_treename = "tree";
893  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
894  try {
895  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
896  } catch (...) {
897 
898  }
899  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
900 
901  // Check for missing branch
902  general_options.m_treename = "tree";
903  general_options.m_variables = {"a", "b", "e", "f"};
904  general_options.m_spectators = {"missing branch"};
905  try {
906  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
907  } catch (...) {
908 
909  }
910  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
911 
912  // Check for missing file
913  general_options.m_spectators = {};
914  general_options.m_datafiles = {"DOESNOTEXIST.root"};
915  general_options.m_treename = "tree";
916  try {
917  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
918  } catch (...) {
919 
920  }
921  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
922 
923  // Check for invalid file
924  general_options.m_datafiles = {"ISNotAValidROOTFile"};
925  general_options.m_treename = "tree";
926 
927  {
928  std::ofstream(general_options.m_datafiles[0]);
929  }
930  EXPECT_TRUE(boost::filesystem::exists(general_options.m_datafiles[0]));
931 
932  try {
933  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
934  } catch (...) {
935 
936  }
937  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
938  }
939 
940  TEST(DatasetTest, ROOTMultiDataset)
941  {
942 
944  TFile file("datafile.root", "RECREATE");
945  file.cd();
946  TTree tree("tree", "TreeTitle");
947  float a, b, c, d, e, f, g, v, w = 0;
948  tree.Branch("a", &a);
949  tree.Branch("b", &b);
950  tree.Branch("c", &c);
951  tree.Branch("d", &d);
952  tree.Branch("e__bo__bc", &e);
953  tree.Branch("f__bo__bc", &f);
954  tree.Branch("g", &g);
955  tree.Branch("__weight__", &c);
956  tree.Branch("v__bo__bc", &v);
957  tree.Branch("w", &w);
958 
959  for (unsigned int i = 0; i < 5; ++i) {
960  a = i + 1.0;
961  b = i + 1.1;
962  c = i + 1.2;
963  d = i + 1.3;
964  e = i + 1.4;
965  f = i + 1.5;
966  g = float(i % 2 == 0);
967  w = i + 1.6;
968  v = i + 1.7;
969  tree.Fill();
970  }
971 
972  file.Write("tree");
973 
974  TFile file2("datafile2.root", "RECREATE");
975  file2.cd();
976  TTree tree2("tree", "TreeTitle");
977  tree2.Branch("a", &a);
978  tree2.Branch("b", &b);
979  tree2.Branch("c", &c);
980  tree2.Branch("d", &d);
981  tree2.Branch("e__bo__bc", &e);
982  tree2.Branch("f__bo__bc", &f);
983  tree2.Branch("g", &g);
984  tree2.Branch("__weight__", &c);
985  tree2.Branch("v__bo__bc", &v);
986  tree2.Branch("w", &w);
987 
988  for (unsigned int i = 0; i < 5; ++i) {
989  a = i + 1.0;
990  b = i + 1.1;
991  c = i + 1.2;
992  d = i + 1.3;
993  e = i + 1.4;
994  f = i + 1.5;
995  g = float(i % 2 == 0);
996  w = i + 1.6;
997  v = i + 1.7;
998  tree2.Fill();
999  }
1000 
1001  file2.Write("tree");
1002 
1003  MVA::GeneralOptions general_options;
1004  // Both names with and without makeROOTCompatible should work
1005  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1006  general_options.m_spectators = {"w", "v()"};
1007  general_options.m_signal_class = 1;
1008  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1009  general_options.m_treename = "tree";
1010  general_options.m_target_variable = "g";
1011  general_options.m_weight_variable = "c";
1012  MVA::ROOTDataset x(general_options);
1013 
1014  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1015  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1016  EXPECT_EQ(x.getNumberOfEvents(), 10);
1017 
1018  // Should just work
1019  x.loadEvent(0);
1020  EXPECT_EQ(x.m_input.size(), 4);
1021  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1022  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1023  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1024  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1025  EXPECT_EQ(x.m_spectators.size(), 2);
1026  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1027  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1028  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1029  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1030  EXPECT_EQ(x.m_isSignal, true);
1031 
1032  x.loadEvent(5);
1033  EXPECT_EQ(x.m_input.size(), 4);
1034  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1035  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1036  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1037  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1038  EXPECT_EQ(x.m_spectators.size(), 2);
1039  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1040  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1041  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1042  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1043  EXPECT_EQ(x.m_isSignal, true);
1044 
1045  x.loadEvent(1);
1046  EXPECT_EQ(x.m_input.size(), 4);
1047  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1048  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1049  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1050  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1051  EXPECT_EQ(x.m_spectators.size(), 2);
1052  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1053  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1054  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1055  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1056  EXPECT_EQ(x.m_isSignal, false);
1057 
1058  x.loadEvent(6);
1059  EXPECT_EQ(x.m_input.size(), 4);
1060  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1061  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1062  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1063  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1064  EXPECT_EQ(x.m_spectators.size(), 2);
1065  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1066  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1067  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1068  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1069  EXPECT_EQ(x.m_isSignal, false);
1070 
1071  x.loadEvent(2);
1072  EXPECT_EQ(x.m_input.size(), 4);
1073  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1074  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1075  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1076  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1077  EXPECT_EQ(x.m_spectators.size(), 2);
1078  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1079  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1080  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1081  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1082  EXPECT_EQ(x.m_isSignal, true);
1083 
1084  x.loadEvent(7);
1085  EXPECT_EQ(x.m_input.size(), 4);
1086  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1087  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1088  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1089  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1090  EXPECT_EQ(x.m_spectators.size(), 2);
1091  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1092  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1093  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1094  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1095  EXPECT_EQ(x.m_isSignal, true);
1096 
1097  x.loadEvent(3);
1098  EXPECT_EQ(x.m_input.size(), 4);
1099  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1100  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1101  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1102  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1103  EXPECT_EQ(x.m_spectators.size(), 2);
1104  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1105  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1106  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1107  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1108  EXPECT_EQ(x.m_isSignal, false);
1109 
1110  x.loadEvent(8);
1111  EXPECT_EQ(x.m_input.size(), 4);
1112  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1113  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1114  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1115  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1116  EXPECT_EQ(x.m_spectators.size(), 2);
1117  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1118  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1119  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1120  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1121  EXPECT_EQ(x.m_isSignal, false);
1122 
1123  x.loadEvent(4);
1124  EXPECT_EQ(x.m_input.size(), 4);
1125  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1126  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1127  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1128  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1129  EXPECT_EQ(x.m_spectators.size(), 2);
1130  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1131  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1132  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1133  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1134  EXPECT_EQ(x.m_isSignal, true);
1135 
1136  x.loadEvent(9);
1137  EXPECT_EQ(x.m_input.size(), 4);
1138  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1139  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1140  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1141  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1142  EXPECT_EQ(x.m_spectators.size(), 2);
1143  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1144  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1145  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1146  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1147  EXPECT_EQ(x.m_isSignal, true);
1148 
1149  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1150 
1151  auto feature = x.getFeature(1);
1152  EXPECT_EQ(feature.size(), 10);
1153  EXPECT_FLOAT_EQ(feature[0], 1.1);
1154  EXPECT_FLOAT_EQ(feature[1], 2.1);
1155  EXPECT_FLOAT_EQ(feature[2], 3.1);
1156  EXPECT_FLOAT_EQ(feature[3], 4.1);
1157  EXPECT_FLOAT_EQ(feature[4], 5.1);
1158  EXPECT_FLOAT_EQ(feature[5], 1.1);
1159  EXPECT_FLOAT_EQ(feature[6], 2.1);
1160  EXPECT_FLOAT_EQ(feature[7], 3.1);
1161  EXPECT_FLOAT_EQ(feature[8], 4.1);
1162  EXPECT_FLOAT_EQ(feature[9], 5.1);
1163 
1164  // Same result for mother class implementation
1165  feature = x.Dataset::getFeature(1);
1166  EXPECT_EQ(feature.size(), 10);
1167  EXPECT_FLOAT_EQ(feature[0], 1.1);
1168  EXPECT_FLOAT_EQ(feature[1], 2.1);
1169  EXPECT_FLOAT_EQ(feature[2], 3.1);
1170  EXPECT_FLOAT_EQ(feature[3], 4.1);
1171  EXPECT_FLOAT_EQ(feature[4], 5.1);
1172  EXPECT_FLOAT_EQ(feature[5], 1.1);
1173  EXPECT_FLOAT_EQ(feature[6], 2.1);
1174  EXPECT_FLOAT_EQ(feature[7], 3.1);
1175  EXPECT_FLOAT_EQ(feature[8], 4.1);
1176  EXPECT_FLOAT_EQ(feature[9], 5.1);
1177 
1178  auto spectator = x.getSpectator(1);
1179  EXPECT_EQ(spectator.size(), 10);
1180  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1181  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1182  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1183  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1184  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1185  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1186  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1187  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1188  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1189  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1190 
1191  // Same result for mother class implementation
1192  spectator = x.Dataset::getSpectator(1);
1193  EXPECT_EQ(spectator.size(), 10);
1194  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1195  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1196  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1197  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1198  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1199  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1200  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1201  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1202  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1203  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1204 
1205  auto weights = x.getWeights();
1206  EXPECT_EQ(weights.size(), 10);
1207  EXPECT_FLOAT_EQ(weights[0], 1.2);
1208  EXPECT_FLOAT_EQ(weights[1], 2.2);
1209  EXPECT_FLOAT_EQ(weights[2], 3.2);
1210  EXPECT_FLOAT_EQ(weights[3], 4.2);
1211  EXPECT_FLOAT_EQ(weights[4], 5.2);
1212  EXPECT_FLOAT_EQ(weights[5], 1.2);
1213  EXPECT_FLOAT_EQ(weights[6], 2.2);
1214  EXPECT_FLOAT_EQ(weights[7], 3.2);
1215  EXPECT_FLOAT_EQ(weights[8], 4.2);
1216  EXPECT_FLOAT_EQ(weights[9], 5.2);
1217 
1218  // Same result for mother class implementation
1219  weights = x.Dataset::getWeights();
1220  EXPECT_EQ(weights.size(), 10);
1221  EXPECT_FLOAT_EQ(weights[0], 1.2);
1222  EXPECT_FLOAT_EQ(weights[1], 2.2);
1223  EXPECT_FLOAT_EQ(weights[2], 3.2);
1224  EXPECT_FLOAT_EQ(weights[3], 4.2);
1225  EXPECT_FLOAT_EQ(weights[4], 5.2);
1226  EXPECT_FLOAT_EQ(weights[5], 1.2);
1227  EXPECT_FLOAT_EQ(weights[6], 2.2);
1228  EXPECT_FLOAT_EQ(weights[7], 3.2);
1229  EXPECT_FLOAT_EQ(weights[8], 4.2);
1230  EXPECT_FLOAT_EQ(weights[9], 5.2);
1231 
1232  auto targets = x.getTargets();
1233  EXPECT_EQ(targets.size(), 10);
1234  EXPECT_FLOAT_EQ(targets[0], 1.0);
1235  EXPECT_FLOAT_EQ(targets[1], 0.0);
1236  EXPECT_FLOAT_EQ(targets[2], 1.0);
1237  EXPECT_FLOAT_EQ(targets[3], 0.0);
1238  EXPECT_FLOAT_EQ(targets[4], 1.0);
1239  EXPECT_FLOAT_EQ(targets[5], 1.0);
1240  EXPECT_FLOAT_EQ(targets[6], 0.0);
1241  EXPECT_FLOAT_EQ(targets[7], 1.0);
1242  EXPECT_FLOAT_EQ(targets[8], 0.0);
1243  EXPECT_FLOAT_EQ(targets[9], 1.0);
1244 
1245  auto signals = x.getSignals();
1246  EXPECT_EQ(signals.size(), 10);
1247  EXPECT_EQ(signals[0], true);
1248  EXPECT_EQ(signals[1], false);
1249  EXPECT_EQ(signals[2], true);
1250  EXPECT_EQ(signals[3], false);
1251  EXPECT_EQ(signals[4], true);
1252  EXPECT_EQ(signals[5], true);
1253  EXPECT_EQ(signals[6], false);
1254  EXPECT_EQ(signals[7], true);
1255  EXPECT_EQ(signals[8], false);
1256  EXPECT_EQ(signals[9], true);
1257 
1258  // Using __weight__ should work as well,
1259  // the only difference to using _weight__ instead of g is
1260  // in setBranchAddresses which avoids calling makeROOTCompatible
1261  // So we have to check the behaviour using __weight__ as well
1262  general_options.m_weight_variable = "__weight__";
1263  MVA::ROOTDataset y(general_options);
1264 
1265  weights = y.getWeights();
1266  EXPECT_EQ(weights.size(), 10);
1267  EXPECT_FLOAT_EQ(weights[0], 1.2);
1268  EXPECT_FLOAT_EQ(weights[1], 2.2);
1269  EXPECT_FLOAT_EQ(weights[2], 3.2);
1270  EXPECT_FLOAT_EQ(weights[3], 4.2);
1271  EXPECT_FLOAT_EQ(weights[4], 5.2);
1272  EXPECT_FLOAT_EQ(weights[5], 1.2);
1273  EXPECT_FLOAT_EQ(weights[6], 2.2);
1274  EXPECT_FLOAT_EQ(weights[7], 3.2);
1275  EXPECT_FLOAT_EQ(weights[8], 4.2);
1276  EXPECT_FLOAT_EQ(weights[9], 5.2);
1277 
1278  // Check TChain expansion
1279  general_options.m_datafiles = {"datafile*.root"};
1280  {
1281  MVA::ROOTDataset chain_test(general_options);
1282  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1283  }
1284  boost::filesystem::copy_file("datafile.root", "datafile3.root");
1285  {
1286  MVA::ROOTDataset chain_test(general_options);
1287  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1288  }
1289  boost::filesystem::copy_file("datafile.root", "datafile4.root");
1290  {
1291  MVA::ROOTDataset chain_test(general_options);
1292  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1293  }
1294  // Test m_max_events feature
1295  {
1296  general_options.m_max_events = 10;
1297  MVA::ROOTDataset chain_test(general_options);
1298  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1299  general_options.m_max_events = 0;
1300  }
1301 
1302  // If a file exists with the specified expansion
1303  // the file takes precedence over the expansion
1304  boost::filesystem::copy_file("datafile.root", "datafile*.root");
1305  {
1306  general_options.m_max_events = 0;
1307  MVA::ROOTDataset chain_test(general_options);
1308  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1309  }
1310 
1311  }
1312  TEST(DatasetTest, ROOTMultiDatasetDouble)
1313  {
1314 
1316  TFile file("datafile.root", "RECREATE");
1317  file.cd();
1318  TTree tree("tree", "TreeTitle");
1319  double a, b, c, d, e, f, g, v, w = 0;
1320  tree.Branch("a", &a, "a/D");
1321  tree.Branch("b", &b, "b/D");
1322  tree.Branch("c", &c, "c/D");
1323  tree.Branch("d", &d, "d/D");
1324  tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
1325  tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
1326  tree.Branch("g", &g, "g/D");
1327  tree.Branch("__weight__", &c, "__weight__/D");
1328  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
1329  tree.Branch("w", &w, "w/D");
1330 
1331  for (unsigned int i = 0; i < 5; ++i) {
1332  a = i + 1.0;
1333  b = i + 1.1;
1334  c = i + 1.2;
1335  d = i + 1.3;
1336  e = i + 1.4;
1337  f = i + 1.5;
1338  g = float(i % 2 == 0);
1339  w = i + 1.6;
1340  v = i + 1.7;
1341  tree.Fill();
1342  }
1343 
1344  file.Write("tree");
1345 
1346  TFile file2("datafile2.root", "RECREATE");
1347  file2.cd();
1348  TTree tree2("tree", "TreeTitle");
1349  tree2.Branch("a", &a);
1350  tree2.Branch("b", &b);
1351  tree2.Branch("c", &c);
1352  tree2.Branch("d", &d);
1353  tree2.Branch("e__bo__bc", &e);
1354  tree2.Branch("f__bo__bc", &f);
1355  tree2.Branch("g", &g);
1356  tree2.Branch("__weight__", &c);
1357  tree2.Branch("v__bo__bc", &v);
1358  tree2.Branch("w", &w);
1359 
1360  for (unsigned int i = 0; i < 5; ++i) {
1361  a = i + 1.0;
1362  b = i + 1.1;
1363  c = i + 1.2;
1364  d = i + 1.3;
1365  e = i + 1.4;
1366  f = i + 1.5;
1367  g = float(i % 2 == 0);
1368  w = i + 1.6;
1369  v = i + 1.7;
1370  tree2.Fill();
1371  }
1372 
1373  file2.Write("tree");
1374 
1375  MVA::GeneralOptions general_options;
1376  // Both names with and without makeROOTCompatible should work
1377  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1378  general_options.m_spectators = {"w", "v()"};
1379  general_options.m_signal_class = 1;
1380  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1381  general_options.m_treename = "tree";
1382  general_options.m_target_variable = "g";
1383  general_options.m_weight_variable = "c";
1384  MVA::ROOTDataset x(general_options);
1385 
1386  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1387  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1388  EXPECT_EQ(x.getNumberOfEvents(), 10);
1389 
1390  // Should just work
1391  x.loadEvent(0);
1392  EXPECT_EQ(x.m_input.size(), 4);
1393  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1394  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1395  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1396  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1397  EXPECT_EQ(x.m_spectators.size(), 2);
1398  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1399  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1400  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1401  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1402  EXPECT_EQ(x.m_isSignal, true);
1403 
1404  x.loadEvent(5);
1405  EXPECT_EQ(x.m_input.size(), 4);
1406  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1407  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1408  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1409  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1410  EXPECT_EQ(x.m_spectators.size(), 2);
1411  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1412  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1413  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1414  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1415  EXPECT_EQ(x.m_isSignal, true);
1416 
1417  x.loadEvent(1);
1418  EXPECT_EQ(x.m_input.size(), 4);
1419  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1420  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1421  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1422  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1423  EXPECT_EQ(x.m_spectators.size(), 2);
1424  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1425  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1426  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1427  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1428  EXPECT_EQ(x.m_isSignal, false);
1429 
1430  x.loadEvent(6);
1431  EXPECT_EQ(x.m_input.size(), 4);
1432  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1433  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1434  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1435  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1436  EXPECT_EQ(x.m_spectators.size(), 2);
1437  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1438  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1439  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1440  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1441  EXPECT_EQ(x.m_isSignal, false);
1442 
1443  x.loadEvent(2);
1444  EXPECT_EQ(x.m_input.size(), 4);
1445  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1446  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1447  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1448  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1449  EXPECT_EQ(x.m_spectators.size(), 2);
1450  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1451  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1452  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1453  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1454  EXPECT_EQ(x.m_isSignal, true);
1455 
1456  x.loadEvent(7);
1457  EXPECT_EQ(x.m_input.size(), 4);
1458  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1459  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1460  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1461  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1462  EXPECT_EQ(x.m_spectators.size(), 2);
1463  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1464  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1465  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1466  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1467  EXPECT_EQ(x.m_isSignal, true);
1468 
1469  x.loadEvent(3);
1470  EXPECT_EQ(x.m_input.size(), 4);
1471  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1472  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1473  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1474  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1475  EXPECT_EQ(x.m_spectators.size(), 2);
1476  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1477  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1478  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1479  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1480  EXPECT_EQ(x.m_isSignal, false);
1481 
1482  x.loadEvent(8);
1483  EXPECT_EQ(x.m_input.size(), 4);
1484  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1485  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1486  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1487  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1488  EXPECT_EQ(x.m_spectators.size(), 2);
1489  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1490  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1491  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1492  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1493  EXPECT_EQ(x.m_isSignal, false);
1494 
1495  x.loadEvent(4);
1496  EXPECT_EQ(x.m_input.size(), 4);
1497  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1498  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1499  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1500  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1501  EXPECT_EQ(x.m_spectators.size(), 2);
1502  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1503  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1504  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1505  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1506  EXPECT_EQ(x.m_isSignal, true);
1507 
1508  x.loadEvent(9);
1509  EXPECT_EQ(x.m_input.size(), 4);
1510  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1511  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1512  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1513  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1514  EXPECT_EQ(x.m_spectators.size(), 2);
1515  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1516  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1517  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1518  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1519  EXPECT_EQ(x.m_isSignal, true);
1520 
1521  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1522 
1523  auto feature = x.getFeature(1);
1524  EXPECT_EQ(feature.size(), 10);
1525  EXPECT_FLOAT_EQ(feature[0], 1.1);
1526  EXPECT_FLOAT_EQ(feature[1], 2.1);
1527  EXPECT_FLOAT_EQ(feature[2], 3.1);
1528  EXPECT_FLOAT_EQ(feature[3], 4.1);
1529  EXPECT_FLOAT_EQ(feature[4], 5.1);
1530  EXPECT_FLOAT_EQ(feature[5], 1.1);
1531  EXPECT_FLOAT_EQ(feature[6], 2.1);
1532  EXPECT_FLOAT_EQ(feature[7], 3.1);
1533  EXPECT_FLOAT_EQ(feature[8], 4.1);
1534  EXPECT_FLOAT_EQ(feature[9], 5.1);
1535 
1536  // Same result for mother class implementation
1537  feature = x.Dataset::getFeature(1);
1538  EXPECT_EQ(feature.size(), 10);
1539  EXPECT_FLOAT_EQ(feature[0], 1.1);
1540  EXPECT_FLOAT_EQ(feature[1], 2.1);
1541  EXPECT_FLOAT_EQ(feature[2], 3.1);
1542  EXPECT_FLOAT_EQ(feature[3], 4.1);
1543  EXPECT_FLOAT_EQ(feature[4], 5.1);
1544  EXPECT_FLOAT_EQ(feature[5], 1.1);
1545  EXPECT_FLOAT_EQ(feature[6], 2.1);
1546  EXPECT_FLOAT_EQ(feature[7], 3.1);
1547  EXPECT_FLOAT_EQ(feature[8], 4.1);
1548  EXPECT_FLOAT_EQ(feature[9], 5.1);
1549 
1550  auto spectator = x.getSpectator(1);
1551  EXPECT_EQ(spectator.size(), 10);
1552  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1553  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1554  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1555  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1556  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1557  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1558  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1559  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1560  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1561  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1562 
1563  // Same result for mother class implementation
1564  spectator = x.Dataset::getSpectator(1);
1565  EXPECT_EQ(spectator.size(), 10);
1566  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1567  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1568  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1569  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1570  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1571  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1572  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1573  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1574  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1575  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1576 
1577  auto weights = x.getWeights();
1578  EXPECT_EQ(weights.size(), 10);
1579  EXPECT_FLOAT_EQ(weights[0], 1.2);
1580  EXPECT_FLOAT_EQ(weights[1], 2.2);
1581  EXPECT_FLOAT_EQ(weights[2], 3.2);
1582  EXPECT_FLOAT_EQ(weights[3], 4.2);
1583  EXPECT_FLOAT_EQ(weights[4], 5.2);
1584  EXPECT_FLOAT_EQ(weights[5], 1.2);
1585  EXPECT_FLOAT_EQ(weights[6], 2.2);
1586  EXPECT_FLOAT_EQ(weights[7], 3.2);
1587  EXPECT_FLOAT_EQ(weights[8], 4.2);
1588  EXPECT_FLOAT_EQ(weights[9], 5.2);
1589 
1590  // Same result for mother class implementation
1591  weights = x.Dataset::getWeights();
1592  EXPECT_EQ(weights.size(), 10);
1593  EXPECT_FLOAT_EQ(weights[0], 1.2);
1594  EXPECT_FLOAT_EQ(weights[1], 2.2);
1595  EXPECT_FLOAT_EQ(weights[2], 3.2);
1596  EXPECT_FLOAT_EQ(weights[3], 4.2);
1597  EXPECT_FLOAT_EQ(weights[4], 5.2);
1598  EXPECT_FLOAT_EQ(weights[5], 1.2);
1599  EXPECT_FLOAT_EQ(weights[6], 2.2);
1600  EXPECT_FLOAT_EQ(weights[7], 3.2);
1601  EXPECT_FLOAT_EQ(weights[8], 4.2);
1602  EXPECT_FLOAT_EQ(weights[9], 5.2);
1603 
1604  auto targets = x.getTargets();
1605  EXPECT_EQ(targets.size(), 10);
1606  EXPECT_FLOAT_EQ(targets[0], 1.0);
1607  EXPECT_FLOAT_EQ(targets[1], 0.0);
1608  EXPECT_FLOAT_EQ(targets[2], 1.0);
1609  EXPECT_FLOAT_EQ(targets[3], 0.0);
1610  EXPECT_FLOAT_EQ(targets[4], 1.0);
1611  EXPECT_FLOAT_EQ(targets[5], 1.0);
1612  EXPECT_FLOAT_EQ(targets[6], 0.0);
1613  EXPECT_FLOAT_EQ(targets[7], 1.0);
1614  EXPECT_FLOAT_EQ(targets[8], 0.0);
1615  EXPECT_FLOAT_EQ(targets[9], 1.0);
1616 
1617  auto signals = x.getSignals();
1618  EXPECT_EQ(signals.size(), 10);
1619  EXPECT_EQ(signals[0], true);
1620  EXPECT_EQ(signals[1], false);
1621  EXPECT_EQ(signals[2], true);
1622  EXPECT_EQ(signals[3], false);
1623  EXPECT_EQ(signals[4], true);
1624  EXPECT_EQ(signals[5], true);
1625  EXPECT_EQ(signals[6], false);
1626  EXPECT_EQ(signals[7], true);
1627  EXPECT_EQ(signals[8], false);
1628  EXPECT_EQ(signals[9], true);
1629 
1630  // Using __weight__ should work as well,
1631  // the only difference to using _weight__ instead of g is
1632  // in setBranchAddresses which avoids calling makeROOTCompatible
1633  // So we have to check the behaviour using __weight__ as well
1634  general_options.m_weight_variable = "__weight__";
1635  MVA::ROOTDataset y(general_options);
1636 
1637  weights = y.getWeights();
1638  EXPECT_EQ(weights.size(), 10);
1639  EXPECT_FLOAT_EQ(weights[0], 1.2);
1640  EXPECT_FLOAT_EQ(weights[1], 2.2);
1641  EXPECT_FLOAT_EQ(weights[2], 3.2);
1642  EXPECT_FLOAT_EQ(weights[3], 4.2);
1643  EXPECT_FLOAT_EQ(weights[4], 5.2);
1644  EXPECT_FLOAT_EQ(weights[5], 1.2);
1645  EXPECT_FLOAT_EQ(weights[6], 2.2);
1646  EXPECT_FLOAT_EQ(weights[7], 3.2);
1647  EXPECT_FLOAT_EQ(weights[8], 4.2);
1648  EXPECT_FLOAT_EQ(weights[9], 5.2);
1649 
1650  // Check TChain expansion
1651  general_options.m_datafiles = {"datafile*.root"};
1652  {
1653  MVA::ROOTDataset chain_test(general_options);
1654  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1655  }
1656  boost::filesystem::copy_file("datafile.root", "datafile3.root");
1657  {
1658  MVA::ROOTDataset chain_test(general_options);
1659  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1660  }
1661  boost::filesystem::copy_file("datafile.root", "datafile4.root");
1662  {
1663  MVA::ROOTDataset chain_test(general_options);
1664  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1665  }
1666  // Test m_max_events feature
1667  {
1668  general_options.m_max_events = 10;
1669  MVA::ROOTDataset chain_test(general_options);
1670  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1671  general_options.m_max_events = 0;
1672  }
1673 
1674  // If a file exists with the specified expansion
1675  // the file takes precedence over the expansion
1676  boost::filesystem::copy_file("datafile.root", "datafile*.root");
1677  {
1678  general_options.m_max_events = 0;
1679  MVA::ROOTDataset chain_test(general_options);
1680  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1681  }
1682  }
1683 }
Wraps two other Datasets, one containing signal, the other background events. Used by the reweighting ...
Definition: Dataset.h:292
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:31
General options which are shared by all MVA trainings.
Definition: Options.h:62
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:184
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
Definition: Dataset.h:347
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:133
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:232
changes working directory into a newly created directory, and removes it (and contents) on destructio...
Definition: TestHelpers.h:64
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.