Belle II Software  light-2303-iriomote
test_Dataset.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/interface/Dataset.h>
10 #include <framework/utilities/TestHelpers.h>
11 
12 #include <boost/filesystem/operations.hpp>
13 
14 #include <gtest/gtest.h>
15 
16 #include <fstream>
17 #include <numeric>
18 
19 using namespace Belle2;
20 
21 namespace {
22 
23  class TestDataset : public MVA::Dataset {
24  public:
25  explicit TestDataset(MVA::GeneralOptions& general_options) : MVA::Dataset(general_options)
26  {
27  m_input = {1.0, 2.0, 3.0, 4.0, 5.0};
28  m_target = 3.0;
29  m_isSignal = true;
30  m_weight = -3.0;
31  }
32 
33  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 5; }
34  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 2; }
35  [[nodiscard]] unsigned int getNumberOfEvents() const override { return 20; }
36  void loadEvent(unsigned int iEvent) override
37  {
38  auto f = static_cast<float>(iEvent);
39  m_input = {f + 1, f + 2, f + 3, f + 4, f + 5};
40  m_spectators = {f + 6, f + 7};
41  };
42  float getSignalFraction() override { return 0.1; };
43  std::vector<float> getFeature(unsigned int iFeature) override
44  {
45  std::vector<float> a(20, 0.0);
46  std::iota(a.begin(), a.end(), iFeature + 1);
47  return a;
48  }
49  std::vector<float> getSpectator(unsigned int iSpectator) override
50  {
51  std::vector<float> a(20, 0.0);
52  std::iota(a.begin(), a.end(), iSpectator + 6);
53  return a;
54  }
55 
56  };
57 
58  TEST(DatasetTest, SingleDataset)
59  {
60 
61  MVA::GeneralOptions general_options;
62  general_options.m_variables = {"a", "b", "c"};
63  general_options.m_spectators = {"e", "f"};
64  general_options.m_signal_class = 2;
65  std::vector<float> input = {1.0, 2.0, 3.0};
66  std::vector<float> spectators = {4.0, 5.0};
67 
68  MVA::SingleDataset x(general_options, input, 2.0, spectators);
69 
70  EXPECT_EQ(x.getNumberOfFeatures(), 3);
71  EXPECT_EQ(x.getNumberOfSpectators(), 2);
72  EXPECT_EQ(x.getNumberOfEvents(), 1);
73 
74  EXPECT_EQ(x.getFeatureIndex("a"), 0);
75  EXPECT_EQ(x.getFeatureIndex("b"), 1);
76  EXPECT_EQ(x.getFeatureIndex("c"), 2);
77  EXPECT_B2ERROR(x.getFeatureIndex("bla"));
78 
79  EXPECT_EQ(x.getSpectatorIndex("e"), 0);
80  EXPECT_EQ(x.getSpectatorIndex("f"), 1);
81  EXPECT_B2ERROR(x.getSpectatorIndex("bla"));
82 
83  // Should just work
84  x.loadEvent(0);
85 
86  EXPECT_EQ(x.m_input.size(), 3);
87  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
88  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
89  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
90 
91  EXPECT_EQ(x.m_spectators.size(), 2);
92  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.0);
93  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.0);
94 
95  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
96  EXPECT_FLOAT_EQ(x.m_target, 2.0);
97  EXPECT_EQ(x.m_isSignal, true);
98 
99  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1.0);
100 
101  auto feature = x.getFeature(1);
102  EXPECT_EQ(feature.size(), 1);
103  EXPECT_FLOAT_EQ(feature[0], 2.0);
104 
105  auto spectator = x.getSpectator(1);
106  EXPECT_EQ(spectator.size(), 1);
107  EXPECT_FLOAT_EQ(spectator[0], 5.0);
108 
109  // Same result for mother class implementation
110  feature = x.Dataset::getFeature(1);
111  EXPECT_EQ(feature.size(), 1);
112  EXPECT_FLOAT_EQ(feature[0], 2.0);
113 
114  }
115 
116  TEST(DatasetTest, MultiDataset)
117  {
118 
119  MVA::GeneralOptions general_options;
120  general_options.m_variables = {"a", "b", "c"};
121  general_options.m_spectators = {"e", "f"};
122  general_options.m_signal_class = 2;
123  std::vector<std::vector<float>> matrix = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}};
124  std::vector<std::vector<float>> spectator_matrix = {{12.0, 13.0}, {15.0, 16.0}, {18.0, 19.0}};
125  std::vector<float> targets = {2.0, 0.0, 2.0};
126  std::vector<float> weights = {1.0, 2.0, 3.0};
127 
128  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, {1.0}, weights));
129 
130  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, targets, {1.0}));
131 
132  MVA::MultiDataset x(general_options, matrix, spectator_matrix, targets, weights);
133 
134  EXPECT_EQ(x.getNumberOfFeatures(), 3);
135  EXPECT_EQ(x.getNumberOfEvents(), 3);
136 
137  // Should just work
138  x.loadEvent(0);
139 
140  EXPECT_EQ(x.m_input.size(), 3);
141  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
142  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
143  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
144 
145  EXPECT_EQ(x.m_spectators.size(), 2);
146  EXPECT_FLOAT_EQ(x.m_spectators[0], 12.0);
147  EXPECT_FLOAT_EQ(x.m_spectators[1], 13.0);
148 
149  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
150  EXPECT_FLOAT_EQ(x.m_target, 2.0);
151  EXPECT_EQ(x.m_isSignal, true);
152 
153  // Should just work
154  x.loadEvent(1);
155 
156  EXPECT_EQ(x.m_input.size(), 3);
157  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
158  EXPECT_FLOAT_EQ(x.m_input[1], 5.0);
159  EXPECT_FLOAT_EQ(x.m_input[2], 6.0);
160 
161  EXPECT_EQ(x.m_spectators.size(), 2);
162  EXPECT_FLOAT_EQ(x.m_spectators[0], 15.0);
163  EXPECT_FLOAT_EQ(x.m_spectators[1], 16.0);
164 
165  EXPECT_FLOAT_EQ(x.m_weight, 2.0);
166  EXPECT_FLOAT_EQ(x.m_target, 0.0);
167  EXPECT_EQ(x.m_isSignal, false);
168 
169  // Should just work
170  x.loadEvent(2);
171 
172  EXPECT_EQ(x.m_input.size(), 3);
173  EXPECT_FLOAT_EQ(x.m_input[0], 7.0);
174  EXPECT_FLOAT_EQ(x.m_input[1], 8.0);
175  EXPECT_FLOAT_EQ(x.m_input[2], 9.0);
176 
177  EXPECT_EQ(x.m_spectators.size(), 2);
178  EXPECT_FLOAT_EQ(x.m_spectators[0], 18.0);
179  EXPECT_FLOAT_EQ(x.m_spectators[1], 19.0);
180 
181  EXPECT_FLOAT_EQ(x.m_weight, 3.0);
182  EXPECT_FLOAT_EQ(x.m_target, 2.0);
183  EXPECT_EQ(x.m_isSignal, true);
184 
185  EXPECT_FLOAT_EQ(x.getSignalFraction(), 4.0 / 6.0);
186 
187  auto feature = x.getFeature(1);
188  EXPECT_EQ(feature.size(), 3);
189  EXPECT_FLOAT_EQ(feature[0], 2.0);
190  EXPECT_FLOAT_EQ(feature[1], 5.0);
191  EXPECT_FLOAT_EQ(feature[2], 8.0);
192 
193  auto spectator = x.getSpectator(1);
194  EXPECT_EQ(spectator.size(), 3);
195  EXPECT_FLOAT_EQ(spectator[0], 13.0);
196  EXPECT_FLOAT_EQ(spectator[1], 16.0);
197  EXPECT_FLOAT_EQ(spectator[2], 19.0);
198 
199 
200  }
201 
202  TEST(DatasetTest, SubDataset)
203  {
204 
205  MVA::GeneralOptions general_options;
206  general_options.m_signal_class = 3;
207  general_options.m_variables = {"a", "b", "c", "d", "e"};
208  general_options.m_spectators = {"f", "g"};
209  TestDataset test_dataset(general_options);
210 
211  general_options.m_variables = {"a", "d", "e"};
212  general_options.m_spectators = {"g"};
213  std::vector<bool> events = {true, false, true, false, true, false, true, false, true, false,
214  true, false, true, false, true, false, true, false, true, false
215  };
216  MVA::SubDataset x(general_options, events, test_dataset);
217 
218  EXPECT_EQ(x.getNumberOfFeatures(), 3);
219  EXPECT_EQ(x.getNumberOfEvents(), 10);
220 
221  // Should just work
222  x.loadEvent(0);
223 
224  EXPECT_EQ(x.m_input.size(), 3);
225  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
226  EXPECT_FLOAT_EQ(x.m_input[1], 4.0);
227  EXPECT_FLOAT_EQ(x.m_input[2], 5.0);
228 
229  EXPECT_EQ(x.m_spectators.size(), 1);
230  EXPECT_FLOAT_EQ(x.m_spectators[0], 7.0);
231 
232  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
233  EXPECT_FLOAT_EQ(x.m_target, 3.0);
234  EXPECT_EQ(x.m_isSignal, true);
235 
236  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1);
237 
238  auto feature = x.getFeature(1);
239  EXPECT_EQ(feature.size(), 10);
240  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
241  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
242  };
243 
244  auto spectator = x.getSpectator(0);
245  EXPECT_EQ(spectator.size(), 10);
246  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
247  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
248  };
249 
250  // Same result for mother class implementation
251  feature = x.Dataset::getFeature(1);
252  EXPECT_EQ(feature.size(), 10);
253  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
254  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
255  };
256 
257  spectator = x.Dataset::getSpectator(0);
258  EXPECT_EQ(spectator.size(), 10);
259  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
260  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
261  };
262 
263  // Test without event indices
264  MVA::SubDataset y(general_options, {}, test_dataset);
265  feature = y.getFeature(1);
266  EXPECT_EQ(feature.size(), 20);
267  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
268  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
269  };
270 
271  spectator = y.getSpectator(0);
272  EXPECT_EQ(spectator.size(), 20);
273  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
274  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
275  };
276 
277  // Same result for mother class implementation
278  feature = y.Dataset::getFeature(1);
279  EXPECT_EQ(feature.size(), 20);
280  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
281  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
282  };
283 
284  spectator = y.Dataset::getSpectator(0);
285  EXPECT_EQ(spectator.size(), 20);
286  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
287  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
288  };
289 
290  general_options.m_variables = {"a", "d", "e", "DOESNOTEXIST"};
291  try {
292  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
293  } catch (...) {
294 
295  }
296  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
297 
298  general_options.m_variables = {"a", "d", "e"};
299  general_options.m_spectators = {"DOESNOTEXIST"};
300  try {
301  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
302  } catch (...) {
303 
304  }
305  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
306 
307  }
308 
309  TEST(DatasetTest, CombinedDataset)
310  {
311 
312  MVA::GeneralOptions general_options;
313  general_options.m_signal_class = 1;
314  general_options.m_variables = {"a", "b", "c", "d", "e"};
315  general_options.m_spectators = {"f", "g"};
316  TestDataset signal_dataset(general_options);
317  TestDataset bckgrd_dataset(general_options);
318 
319  MVA::CombinedDataset x(general_options, signal_dataset, bckgrd_dataset);
320 
321  EXPECT_EQ(x.getNumberOfFeatures(), 5);
322  EXPECT_EQ(x.getNumberOfEvents(), 40);
323 
324  // Should just work
325  x.loadEvent(0);
326 
327  EXPECT_EQ(x.m_input.size(), 5);
328  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
329  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
330  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
331  EXPECT_FLOAT_EQ(x.m_input[3], 4.0);
332  EXPECT_FLOAT_EQ(x.m_input[4], 5.0);
333 
334  EXPECT_EQ(x.m_spectators.size(), 2);
335  EXPECT_FLOAT_EQ(x.m_spectators[0], 6.0);
336  EXPECT_FLOAT_EQ(x.m_spectators[1], 7.0);
337 
338  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
339  EXPECT_FLOAT_EQ(x.m_target, 1.0);
340  EXPECT_EQ(x.m_isSignal, true);
341 
342  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.5);
343 
344  auto feature = x.getFeature(1);
345  EXPECT_EQ(feature.size(), 40);
346  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
347  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
348  };
349 
350  auto spectator = x.getSpectator(0);
351  EXPECT_EQ(spectator.size(), 40);
352  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
353  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
354  };
355 
356  // Same result for mother class implementation
357  feature = x.Dataset::getFeature(1);
358  EXPECT_EQ(feature.size(), 40);
359  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
360  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
361  };
362 
363  spectator = x.Dataset::getSpectator(0);
364  EXPECT_EQ(spectator.size(), 40);
365  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
366  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
367  };
368 
369  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
370  x.loadEvent(iEvent);
371  EXPECT_EQ(x.m_isSignal, iEvent < 20);
372  EXPECT_FLOAT_EQ(x.m_target, iEvent < 20 ? 1.0 : 0.0);
373  }
374 
375  }
376 
377  TEST(DatasetTest, ROOTDataset)
378  {
379 
381  TFile file("datafile.root", "RECREATE");
382  file.cd();
383  TTree tree("tree", "TreeTitle");
384  float a, b, c, d, e, f, v, w = 0;
385  bool target;
386  tree.Branch("a", &a);
387  tree.Branch("b", &b);
388  tree.Branch("c", &c);
389  tree.Branch("d", &d);
390  tree.Branch("e__bo__bc", &e);
391  tree.Branch("f__bo__bc", &f);
392  tree.Branch("g", &target);
393  tree.Branch("__weight__", &c);
394  tree.Branch("v__bo__bc", &v);
395  tree.Branch("w", &w);
396 
397  for (unsigned int i = 0; i < 5; ++i) {
398  a = i + 1.0;
399  b = i + 1.1;
400  c = i + 1.2;
401  d = i + 1.3;
402  e = i + 1.4;
403  f = i + 1.5;
404  target = (i % 2 == 0);
405  w = i + 1.6;
406  v = i + 1.7;
407  tree.Fill();
408  }
409 
410  file.Write("tree");
411 
412  MVA::GeneralOptions general_options;
413  // Both names with and without makeROOTCompatible should work
414  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
415  general_options.m_spectators = {"w", "v()"};
416  general_options.m_signal_class = 1;
417  general_options.m_datafiles = {"datafile.root"};
418  general_options.m_treename = "tree";
419  general_options.m_target_variable = "g";
420  general_options.m_weight_variable = "c";
421  MVA::ROOTDataset x(general_options);
422 
423  EXPECT_EQ(x.getNumberOfFeatures(), 4);
424  EXPECT_EQ(x.getNumberOfSpectators(), 2);
425  EXPECT_EQ(x.getNumberOfEvents(), 5);
426 
427  // Should just work
428  x.loadEvent(0);
429  EXPECT_EQ(x.m_input.size(), 4);
430  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
431  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
432  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
433  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
434  EXPECT_EQ(x.m_spectators.size(), 2);
435  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
436  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
437  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
438  EXPECT_EQ(x.m_target, true);
439  EXPECT_EQ(x.m_isSignal, true);
440 
441  x.loadEvent(1);
442  EXPECT_EQ(x.m_input.size(), 4);
443  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
444  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
445  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
446  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
447  EXPECT_EQ(x.m_spectators.size(), 2);
448  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
449  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
450  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
451  EXPECT_EQ(x.m_target, false);
452  EXPECT_EQ(x.m_isSignal, false);
453 
454  x.loadEvent(2);
455  EXPECT_EQ(x.m_input.size(), 4);
456  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
457  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
458  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
459  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
460  EXPECT_EQ(x.m_spectators.size(), 2);
461  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
462  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
463  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
464  EXPECT_EQ(x.m_target, true);
465  EXPECT_EQ(x.m_isSignal, true);
466 
467  x.loadEvent(3);
468  EXPECT_EQ(x.m_input.size(), 4);
469  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
470  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
471  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
472  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
473  EXPECT_EQ(x.m_spectators.size(), 2);
474  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
475  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
476  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
477  EXPECT_EQ(x.m_target, false);
478  EXPECT_EQ(x.m_isSignal, false);
479 
480  x.loadEvent(4);
481  EXPECT_EQ(x.m_input.size(), 4);
482  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
483  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
484  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
485  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
486  EXPECT_EQ(x.m_spectators.size(), 2);
487  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
488  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
489  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
490  EXPECT_EQ(x.m_target, true);
491  EXPECT_EQ(x.m_isSignal, true);
492 
493  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
494 
495  auto feature = x.getFeature(1);
496  EXPECT_EQ(feature.size(), 5);
497  EXPECT_FLOAT_EQ(feature[0], 1.1);
498  EXPECT_FLOAT_EQ(feature[1], 2.1);
499  EXPECT_FLOAT_EQ(feature[2], 3.1);
500  EXPECT_FLOAT_EQ(feature[3], 4.1);
501  EXPECT_FLOAT_EQ(feature[4], 5.1);
502 
503  // Same result for mother class implementation
504  feature = x.Dataset::getFeature(1);
505  EXPECT_EQ(feature.size(), 5);
506  EXPECT_FLOAT_EQ(feature[0], 1.1);
507  EXPECT_FLOAT_EQ(feature[1], 2.1);
508  EXPECT_FLOAT_EQ(feature[2], 3.1);
509  EXPECT_FLOAT_EQ(feature[3], 4.1);
510  EXPECT_FLOAT_EQ(feature[4], 5.1);
511 
512  auto spectator = x.getSpectator(1);
513  EXPECT_EQ(spectator.size(), 5);
514  EXPECT_FLOAT_EQ(spectator[0], 1.7);
515  EXPECT_FLOAT_EQ(spectator[1], 2.7);
516  EXPECT_FLOAT_EQ(spectator[2], 3.7);
517  EXPECT_FLOAT_EQ(spectator[3], 4.7);
518  EXPECT_FLOAT_EQ(spectator[4], 5.7);
519 
520  // Same result for mother class implementation
521  spectator = x.Dataset::getSpectator(1);
522  EXPECT_EQ(spectator.size(), 5);
523  EXPECT_FLOAT_EQ(spectator[0], 1.7);
524  EXPECT_FLOAT_EQ(spectator[1], 2.7);
525  EXPECT_FLOAT_EQ(spectator[2], 3.7);
526  EXPECT_FLOAT_EQ(spectator[3], 4.7);
527  EXPECT_FLOAT_EQ(spectator[4], 5.7);
528 
529  auto weights = x.getWeights();
530  EXPECT_EQ(weights.size(), 5);
531  EXPECT_FLOAT_EQ(weights[0], 1.2);
532  EXPECT_FLOAT_EQ(weights[1], 2.2);
533  EXPECT_FLOAT_EQ(weights[2], 3.2);
534  EXPECT_FLOAT_EQ(weights[3], 4.2);
535  EXPECT_FLOAT_EQ(weights[4], 5.2);
536 
537  // Same result for mother class implementation
538  weights = x.Dataset::getWeights();
539  EXPECT_EQ(weights.size(), 5);
540  EXPECT_FLOAT_EQ(weights[0], 1.2);
541  EXPECT_FLOAT_EQ(weights[1], 2.2);
542  EXPECT_FLOAT_EQ(weights[2], 3.2);
543  EXPECT_FLOAT_EQ(weights[3], 4.2);
544  EXPECT_FLOAT_EQ(weights[4], 5.2);
545 
546  auto targets = x.getTargets();
547  EXPECT_EQ(targets.size(), 5);
548  EXPECT_EQ(targets[0], true);
549  EXPECT_EQ(targets[1], false);
550  EXPECT_EQ(targets[2], true);
551  EXPECT_EQ(targets[3], false);
552  EXPECT_EQ(targets[4], true);
553 
554  auto signals = x.getSignals();
555  EXPECT_EQ(signals.size(), 5);
556  EXPECT_EQ(signals[0], true);
557  EXPECT_EQ(signals[1], false);
558  EXPECT_EQ(signals[2], true);
559  EXPECT_EQ(signals[3], false);
560  EXPECT_EQ(signals[4], true);
561 
562  // Using __weight__ should work as well,
563  // the only difference to using _weight__ instead of c is
564  // in setBranchAddresses which avoids calling makeROOTCompatible
565  // So we have to check the behaviour using __weight__ as well
566  general_options.m_weight_variable = "__weight__";
567  MVA::ROOTDataset y(general_options);
568 
569  weights = y.getWeights();
570  EXPECT_EQ(weights.size(), 5);
571  EXPECT_FLOAT_EQ(weights[0], 1.2);
572  EXPECT_FLOAT_EQ(weights[1], 2.2);
573  EXPECT_FLOAT_EQ(weights[2], 3.2);
574  EXPECT_FLOAT_EQ(weights[3], 4.2);
575  EXPECT_FLOAT_EQ(weights[4], 5.2);
576 
577  // Check TChain expansion
578  general_options.m_datafiles = {"datafile*.root"};
579  {
580  MVA::ROOTDataset chain_test(general_options);
581  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
582  }
583  boost::filesystem::copy_file("datafile.root", "datafile2.root");
584  {
585  MVA::ROOTDataset chain_test(general_options);
586  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
587  }
588  boost::filesystem::copy_file("datafile.root", "datafile3.root");
589  {
590  MVA::ROOTDataset chain_test(general_options);
591  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
592  }
593  // Test m_max_events feature
594  {
595  general_options.m_max_events = 10;
596  MVA::ROOTDataset chain_test(general_options);
597  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
598  general_options.m_max_events = 0;
599  }
600 
601  // Check for missing tree
602  general_options.m_treename = "missing tree";
603  try {
604  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
605  } catch (...) {
606 
607  }
608  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
609 
610  // Check for missing branch
611  general_options.m_treename = "tree";
612  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
613  try {
614  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
615  } catch (...) {
616 
617  }
618  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
619 
620  // Check for missing branch
621  general_options.m_treename = "tree";
622  general_options.m_variables = {"a", "b", "e", "f"};
623  general_options.m_spectators = {"missing branch"};
624  try {
625  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
626  } catch (...) {
627 
628  }
629  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
630 
631  // Check for missing file
632  general_options.m_spectators = {};
633  general_options.m_datafiles = {"DOESNOTEXIST.root"};
634  general_options.m_treename = "tree";
635  try {
636  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
637  } catch (...) {
638 
639  }
640  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
641 
642  // Check for invalid file
643  general_options.m_datafiles = {"ISNotAValidROOTFile"};
644  general_options.m_treename = "tree";
645 
646  {
647  std::ofstream(general_options.m_datafiles[0]);
648  }
649  EXPECT_TRUE(boost::filesystem::exists(general_options.m_datafiles[0]));
650 
651  try {
652  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
653  } catch (...) {
654 
655  }
656  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
657  }
658 
659 
660  TEST(DatasetTest, ROOTDatasetMixedTypes)
661  {
662 
664  TFile file("datafile.root", "RECREATE");
665  file.cd();
666  TTree tree("tree", "TreeTitle");
667  double a, c, d, g, v, w = 0;
668  float b = 0.0f;
669  int e = 0;
670  bool f = false;
671  tree.Branch("a", &a, "a/D");
672  tree.Branch("b", &b, "b/F");
673  tree.Branch("c", &c, "c/D");
674  tree.Branch("d", &d, "d/D");
675  tree.Branch("e__bo__bc", &e, "e__bo__bc/I");
676  tree.Branch("f__bo__bc", &f, "f__bo__bc/O");
677  tree.Branch("g", &g, "g/D");
678  tree.Branch("__weight__", &c, "__weight__/D");
679  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
680  tree.Branch("w", &w, "w/D");
681 
682  for (unsigned int i = 0; i < 5; ++i) {
683  a = i + 1.0;
684  b = i + 1.1;
685  c = i + 1.2;
686  d = i + 1.3;
687  e = i + 1;
688  f = i % 2 == 0;
689  g = float(i % 2 == 0);
690  w = i + 1.6;
691  v = i + 1.7;
692  tree.Fill();
693  }
694 
695  file.Write("tree");
696 
697  MVA::GeneralOptions general_options;
698  // Both names with and without makeROOTCompatible should work
699  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
700  general_options.m_spectators = {"w", "v()"};
701  general_options.m_signal_class = 1;
702  general_options.m_datafiles = {"datafile.root"};
703  general_options.m_treename = "tree";
704  general_options.m_target_variable = "g";
705  general_options.m_weight_variable = "c";
706  MVA::ROOTDataset x(general_options);
707 
708  EXPECT_EQ(x.getNumberOfFeatures(), 4);
709  EXPECT_EQ(x.getNumberOfSpectators(), 2);
710  EXPECT_EQ(x.getNumberOfEvents(), 5);
711 
712  // Should just work
713  x.loadEvent(0);
714  EXPECT_EQ(x.m_input.size(), 4);
715  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
716  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
717  EXPECT_EQ(x.m_input[2], 1);
718  EXPECT_EQ(x.m_input[3], true);
719  EXPECT_EQ(x.m_spectators.size(), 2);
720  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
721  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
722  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
723  EXPECT_FLOAT_EQ(x.m_target, 1.0);
724  EXPECT_EQ(x.m_isSignal, true);
725 
726  x.loadEvent(1);
727  EXPECT_EQ(x.m_input.size(), 4);
728  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
729  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
730  EXPECT_EQ(x.m_input[2], 2);
731  EXPECT_EQ(x.m_input[3], false);
732  EXPECT_EQ(x.m_spectators.size(), 2);
733  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
734  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
735  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
736  EXPECT_FLOAT_EQ(x.m_target, 0.0);
737  EXPECT_EQ(x.m_isSignal, false);
738 
739  x.loadEvent(2);
740  EXPECT_EQ(x.m_input.size(), 4);
741  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
742  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
743  EXPECT_EQ(x.m_input[2], 3);
744  EXPECT_EQ(x.m_input[3], true);
745  EXPECT_EQ(x.m_spectators.size(), 2);
746  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
747  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
748  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
749  EXPECT_FLOAT_EQ(x.m_target, 1.0);
750  EXPECT_EQ(x.m_isSignal, true);
751 
752  x.loadEvent(3);
753  EXPECT_EQ(x.m_input.size(), 4);
754  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
755  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
756  EXPECT_EQ(x.m_input[2], 4);
757  EXPECT_EQ(x.m_input[3], false);
758  EXPECT_EQ(x.m_spectators.size(), 2);
759  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
760  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
761  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
762  EXPECT_FLOAT_EQ(x.m_target, 0.0);
763  EXPECT_EQ(x.m_isSignal, false);
764 
765  x.loadEvent(4);
766  EXPECT_EQ(x.m_input.size(), 4);
767  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
768  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
769  EXPECT_EQ(x.m_input[2], 5);
770  EXPECT_EQ(x.m_input[3], true);
771  EXPECT_EQ(x.m_spectators.size(), 2);
772  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
773  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
774  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
775  EXPECT_FLOAT_EQ(x.m_target, 1.0);
776  EXPECT_EQ(x.m_isSignal, true);
777 
778  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
779 
780  auto feature = x.getFeature(1);
781  EXPECT_EQ(feature.size(), 5);
782  EXPECT_FLOAT_EQ(feature[0], 1.1);
783  EXPECT_FLOAT_EQ(feature[1], 2.1);
784  EXPECT_FLOAT_EQ(feature[2], 3.1);
785  EXPECT_FLOAT_EQ(feature[3], 4.1);
786  EXPECT_FLOAT_EQ(feature[4], 5.1);
787 
788  // Same result for mother class implementation
789  feature = x.Dataset::getFeature(1);
790  EXPECT_EQ(feature.size(), 5);
791  EXPECT_FLOAT_EQ(feature[0], 1.1);
792  EXPECT_FLOAT_EQ(feature[1], 2.1);
793  EXPECT_FLOAT_EQ(feature[2], 3.1);
794  EXPECT_FLOAT_EQ(feature[3], 4.1);
795  EXPECT_FLOAT_EQ(feature[4], 5.1);
796 
797  auto spectator = x.getSpectator(1);
798  EXPECT_EQ(spectator.size(), 5);
799  EXPECT_FLOAT_EQ(spectator[0], 1.7);
800  EXPECT_FLOAT_EQ(spectator[1], 2.7);
801  EXPECT_FLOAT_EQ(spectator[2], 3.7);
802  EXPECT_FLOAT_EQ(spectator[3], 4.7);
803  EXPECT_FLOAT_EQ(spectator[4], 5.7);
804 
805  // Same result for mother class implementation
806  spectator = x.Dataset::getSpectator(1);
807  EXPECT_EQ(spectator.size(), 5);
808  EXPECT_FLOAT_EQ(spectator[0], 1.7);
809  EXPECT_FLOAT_EQ(spectator[1], 2.7);
810  EXPECT_FLOAT_EQ(spectator[2], 3.7);
811  EXPECT_FLOAT_EQ(spectator[3], 4.7);
812  EXPECT_FLOAT_EQ(spectator[4], 5.7);
813 
814  auto weights = x.getWeights();
815  EXPECT_EQ(weights.size(), 5);
816  EXPECT_FLOAT_EQ(weights[0], 1.2);
817  EXPECT_FLOAT_EQ(weights[1], 2.2);
818  EXPECT_FLOAT_EQ(weights[2], 3.2);
819  EXPECT_FLOAT_EQ(weights[3], 4.2);
820  EXPECT_FLOAT_EQ(weights[4], 5.2);
821 
822  // Same result for mother class implementation
823  weights = x.Dataset::getWeights();
824  EXPECT_EQ(weights.size(), 5);
825  EXPECT_FLOAT_EQ(weights[0], 1.2);
826  EXPECT_FLOAT_EQ(weights[1], 2.2);
827  EXPECT_FLOAT_EQ(weights[2], 3.2);
828  EXPECT_FLOAT_EQ(weights[3], 4.2);
829  EXPECT_FLOAT_EQ(weights[4], 5.2);
830 
831  auto targets = x.getTargets();
832  EXPECT_EQ(targets.size(), 5);
833  EXPECT_FLOAT_EQ(targets[0], 1.0);
834  EXPECT_FLOAT_EQ(targets[1], 0.0);
835  EXPECT_FLOAT_EQ(targets[2], 1.0);
836  EXPECT_FLOAT_EQ(targets[3], 0.0);
837  EXPECT_FLOAT_EQ(targets[4], 1.0);
838 
839  auto signals = x.getSignals();
840  EXPECT_EQ(signals.size(), 5);
841  EXPECT_EQ(signals[0], true);
842  EXPECT_EQ(signals[1], false);
843  EXPECT_EQ(signals[2], true);
844  EXPECT_EQ(signals[3], false);
845  EXPECT_EQ(signals[4], true);
846 
847  // Using __weight__ should work as well,
848  // the only difference to using _weight__ instead of g is
849  // in setBranchAddresses which avoids calling makeROOTCompatible
850  // So we have to check the behaviour using __weight__ as well
851  general_options.m_weight_variable = "__weight__";
852  MVA::ROOTDataset y(general_options);
853 
854  weights = y.getWeights();
855  EXPECT_EQ(weights.size(), 5);
856  EXPECT_FLOAT_EQ(weights[0], 1.2);
857  EXPECT_FLOAT_EQ(weights[1], 2.2);
858  EXPECT_FLOAT_EQ(weights[2], 3.2);
859  EXPECT_FLOAT_EQ(weights[3], 4.2);
860  EXPECT_FLOAT_EQ(weights[4], 5.2);
861 
862  // Check TChain expansion
863  general_options.m_datafiles = {"datafile*.root"};
864  {
865  MVA::ROOTDataset chain_test(general_options);
866  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
867  }
868  boost::filesystem::copy_file("datafile.root", "datafile2.root");
869  {
870  MVA::ROOTDataset chain_test(general_options);
871  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
872  }
873  boost::filesystem::copy_file("datafile.root", "datafile3.root");
874  {
875  MVA::ROOTDataset chain_test(general_options);
876  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
877  }
878  // Test m_max_events feature
879  {
880  general_options.m_max_events = 10;
881  MVA::ROOTDataset chain_test(general_options);
882  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
883  general_options.m_max_events = 0;
884  }
885 
886  // Check for missing tree
887  general_options.m_treename = "missing tree";
888  try {
889  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
890  } catch (...) {
891 
892  }
893  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
894 
895  // Check for missing branch
896  general_options.m_treename = "tree";
897  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
898  try {
899  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
900  } catch (...) {
901 
902  }
903  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
904 
905  // Check for missing branch
906  general_options.m_treename = "tree";
907  general_options.m_variables = {"a", "b", "e", "f"};
908  general_options.m_spectators = {"missing branch"};
909  try {
910  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
911  } catch (...) {
912 
913  }
914  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
915 
916  // Check for missing file
917  general_options.m_spectators = {};
918  general_options.m_datafiles = {"DOESNOTEXIST.root"};
919  general_options.m_treename = "tree";
920  try {
921  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
922  } catch (...) {
923 
924  }
925  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
926 
927  // Check for invalid file
928  general_options.m_datafiles = {"ISNotAValidROOTFile"};
929  general_options.m_treename = "tree";
930 
931  {
932  std::ofstream(general_options.m_datafiles[0]);
933  }
934  EXPECT_TRUE(boost::filesystem::exists(general_options.m_datafiles[0]));
935 
936  try {
937  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
938  } catch (...) {
939 
940  }
941  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
942  }
943 
944  TEST(DatasetTest, ROOTMultiDataset)
945  {
946 
948  TFile file("datafile.root", "RECREATE");
949  file.cd();
950  TTree tree("tree", "TreeTitle");
951  double a, b, c, d, e, f, v, w = 0;
952  bool target;
953  tree.Branch("a", &a);
954  tree.Branch("b", &b);
955  tree.Branch("c", &c);
956  tree.Branch("d", &d);
957  tree.Branch("e__bo__bc", &e);
958  tree.Branch("f__bo__bc", &f);
959  tree.Branch("g", &target);
960  tree.Branch("__weight__", &c);
961  tree.Branch("v__bo__bc", &v);
962  tree.Branch("w", &w);
963 
964  for (unsigned int i = 0; i < 5; ++i) {
965  a = i + 1.0;
966  b = i + 1.1;
967  c = i + 1.2;
968  d = i + 1.3;
969  e = i + 1.4;
970  f = i + 1.5;
971  target = (i % 2 == 0);
972  w = i + 1.6;
973  v = i + 1.7;
974  tree.Fill();
975  }
976 
977  file.Write("tree");
978 
979  TFile file2("datafile2.root", "RECREATE");
980  file2.cd();
981  TTree tree2("tree", "TreeTitle");
982  tree2.Branch("a", &a);
983  tree2.Branch("b", &b);
984  tree2.Branch("c", &c);
985  tree2.Branch("d", &d);
986  tree2.Branch("e__bo__bc", &e);
987  tree2.Branch("f__bo__bc", &f);
988  tree2.Branch("g", &target);
989  tree2.Branch("__weight__", &c);
990  tree2.Branch("v__bo__bc", &v);
991  tree2.Branch("w", &w);
992 
993  for (unsigned int i = 0; i < 5; ++i) {
994  a = i + 1.0;
995  b = i + 1.1;
996  c = i + 1.2;
997  d = i + 1.3;
998  e = i + 1.4;
999  f = i + 1.5;
1000  target = (i % 2 == 0);
1001  w = i + 1.6;
1002  v = i + 1.7;
1003  tree2.Fill();
1004  }
1005 
1006  file2.Write("tree");
1007 
1008  MVA::GeneralOptions general_options;
1009  // Both names with and without makeROOTCompatible should work
1010  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1011  general_options.m_spectators = {"w", "v()"};
1012  general_options.m_signal_class = 1;
1013  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1014  general_options.m_treename = "tree";
1015  general_options.m_target_variable = "g";
1016  general_options.m_weight_variable = "c";
1017  MVA::ROOTDataset x(general_options);
1018 
1019  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1020  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1021  EXPECT_EQ(x.getNumberOfEvents(), 10);
1022 
1023  // Should just work
1024  x.loadEvent(0);
1025  EXPECT_EQ(x.m_input.size(), 4);
1026  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1027  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1028  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1029  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1030  EXPECT_EQ(x.m_spectators.size(), 2);
1031  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1032  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1033  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1034  EXPECT_EQ(x.m_target, true);
1035  EXPECT_EQ(x.m_isSignal, true);
1036 
1037  x.loadEvent(5);
1038  EXPECT_EQ(x.m_input.size(), 4);
1039  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1040  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1041  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1042  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1043  EXPECT_EQ(x.m_spectators.size(), 2);
1044  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1045  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1046  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1047  EXPECT_EQ(x.m_target, true);
1048  EXPECT_EQ(x.m_isSignal, true);
1049 
1050  x.loadEvent(1);
1051  EXPECT_EQ(x.m_input.size(), 4);
1052  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1053  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1054  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1055  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1056  EXPECT_EQ(x.m_spectators.size(), 2);
1057  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1058  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1059  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1060  EXPECT_EQ(x.m_target, false);
1061  EXPECT_EQ(x.m_isSignal, false);
1062 
1063  x.loadEvent(6);
1064  EXPECT_EQ(x.m_input.size(), 4);
1065  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1066  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1067  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1068  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1069  EXPECT_EQ(x.m_spectators.size(), 2);
1070  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1071  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1072  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1073  EXPECT_EQ(x.m_target, false);
1074  EXPECT_EQ(x.m_isSignal, false);
1075 
1076  x.loadEvent(2);
1077  EXPECT_EQ(x.m_input.size(), 4);
1078  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1079  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1080  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1081  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1082  EXPECT_EQ(x.m_spectators.size(), 2);
1083  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1084  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1085  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1086  EXPECT_EQ(x.m_target, true);
1087  EXPECT_EQ(x.m_isSignal, true);
1088 
1089  x.loadEvent(7);
1090  EXPECT_EQ(x.m_input.size(), 4);
1091  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1092  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1093  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1094  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1095  EXPECT_EQ(x.m_spectators.size(), 2);
1096  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1097  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1098  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1099  EXPECT_EQ(x.m_target, true);
1100  EXPECT_EQ(x.m_isSignal, true);
1101 
1102  x.loadEvent(3);
1103  EXPECT_EQ(x.m_input.size(), 4);
1104  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1105  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1106  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1107  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1108  EXPECT_EQ(x.m_spectators.size(), 2);
1109  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1110  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1111  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1112  EXPECT_EQ(x.m_target, false);
1113  EXPECT_EQ(x.m_isSignal, false);
1114 
1115  x.loadEvent(8);
1116  EXPECT_EQ(x.m_input.size(), 4);
1117  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1118  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1119  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1120  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1121  EXPECT_EQ(x.m_spectators.size(), 2);
1122  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1123  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1124  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1125  EXPECT_EQ(x.m_target, false);
1126  EXPECT_EQ(x.m_isSignal, false);
1127 
1128  x.loadEvent(4);
1129  EXPECT_EQ(x.m_input.size(), 4);
1130  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1131  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1132  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1133  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1134  EXPECT_EQ(x.m_spectators.size(), 2);
1135  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1136  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1137  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1138  EXPECT_EQ(x.m_target, true);
1139  EXPECT_EQ(x.m_isSignal, true);
1140 
1141  x.loadEvent(9);
1142  EXPECT_EQ(x.m_input.size(), 4);
1143  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1144  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1145  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1146  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1147  EXPECT_EQ(x.m_spectators.size(), 2);
1148  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1149  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1150  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1151  EXPECT_EQ(x.m_target, true);
1152  EXPECT_EQ(x.m_isSignal, true);
1153 
1154  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1155 
1156  auto feature = x.getFeature(1);
1157  EXPECT_EQ(feature.size(), 10);
1158  EXPECT_FLOAT_EQ(feature[0], 1.1);
1159  EXPECT_FLOAT_EQ(feature[1], 2.1);
1160  EXPECT_FLOAT_EQ(feature[2], 3.1);
1161  EXPECT_FLOAT_EQ(feature[3], 4.1);
1162  EXPECT_FLOAT_EQ(feature[4], 5.1);
1163  EXPECT_FLOAT_EQ(feature[5], 1.1);
1164  EXPECT_FLOAT_EQ(feature[6], 2.1);
1165  EXPECT_FLOAT_EQ(feature[7], 3.1);
1166  EXPECT_FLOAT_EQ(feature[8], 4.1);
1167  EXPECT_FLOAT_EQ(feature[9], 5.1);
1168 
1169  // Same result for mother class implementation
1170  feature = x.Dataset::getFeature(1);
1171  EXPECT_EQ(feature.size(), 10);
1172  EXPECT_FLOAT_EQ(feature[0], 1.1);
1173  EXPECT_FLOAT_EQ(feature[1], 2.1);
1174  EXPECT_FLOAT_EQ(feature[2], 3.1);
1175  EXPECT_FLOAT_EQ(feature[3], 4.1);
1176  EXPECT_FLOAT_EQ(feature[4], 5.1);
1177  EXPECT_FLOAT_EQ(feature[5], 1.1);
1178  EXPECT_FLOAT_EQ(feature[6], 2.1);
1179  EXPECT_FLOAT_EQ(feature[7], 3.1);
1180  EXPECT_FLOAT_EQ(feature[8], 4.1);
1181  EXPECT_FLOAT_EQ(feature[9], 5.1);
1182 
1183  auto spectator = x.getSpectator(1);
1184  EXPECT_EQ(spectator.size(), 10);
1185  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1186  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1187  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1188  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1189  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1190  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1191  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1192  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1193  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1194  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1195 
1196  // Same result for mother class implementation
1197  spectator = x.Dataset::getSpectator(1);
1198  EXPECT_EQ(spectator.size(), 10);
1199  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1200  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1201  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1202  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1203  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1204  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1205  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1206  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1207  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1208  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1209 
1210  auto weights = x.getWeights();
1211  EXPECT_EQ(weights.size(), 10);
1212  EXPECT_FLOAT_EQ(weights[0], 1.2);
1213  EXPECT_FLOAT_EQ(weights[1], 2.2);
1214  EXPECT_FLOAT_EQ(weights[2], 3.2);
1215  EXPECT_FLOAT_EQ(weights[3], 4.2);
1216  EXPECT_FLOAT_EQ(weights[4], 5.2);
1217  EXPECT_FLOAT_EQ(weights[5], 1.2);
1218  EXPECT_FLOAT_EQ(weights[6], 2.2);
1219  EXPECT_FLOAT_EQ(weights[7], 3.2);
1220  EXPECT_FLOAT_EQ(weights[8], 4.2);
1221  EXPECT_FLOAT_EQ(weights[9], 5.2);
1222 
1223  // Same result for mother class implementation
1224  weights = x.Dataset::getWeights();
1225  EXPECT_EQ(weights.size(), 10);
1226  EXPECT_FLOAT_EQ(weights[0], 1.2);
1227  EXPECT_FLOAT_EQ(weights[1], 2.2);
1228  EXPECT_FLOAT_EQ(weights[2], 3.2);
1229  EXPECT_FLOAT_EQ(weights[3], 4.2);
1230  EXPECT_FLOAT_EQ(weights[4], 5.2);
1231  EXPECT_FLOAT_EQ(weights[5], 1.2);
1232  EXPECT_FLOAT_EQ(weights[6], 2.2);
1233  EXPECT_FLOAT_EQ(weights[7], 3.2);
1234  EXPECT_FLOAT_EQ(weights[8], 4.2);
1235  EXPECT_FLOAT_EQ(weights[9], 5.2);
1236 
1237  auto targets = x.getTargets();
1238  EXPECT_EQ(targets.size(), 10);
1239  EXPECT_EQ(targets[0], true);
1240  EXPECT_EQ(targets[1], false);
1241  EXPECT_EQ(targets[2], true);
1242  EXPECT_EQ(targets[3], false);
1243  EXPECT_EQ(targets[4], true);
1244  EXPECT_EQ(targets[5], true);
1245  EXPECT_EQ(targets[6], false);
1246  EXPECT_EQ(targets[7], true);
1247  EXPECT_EQ(targets[8], false);
1248  EXPECT_EQ(targets[9], true);
1249 
1250  auto signals = x.getSignals();
1251  EXPECT_EQ(signals.size(), 10);
1252  EXPECT_EQ(signals[0], true);
1253  EXPECT_EQ(signals[1], false);
1254  EXPECT_EQ(signals[2], true);
1255  EXPECT_EQ(signals[3], false);
1256  EXPECT_EQ(signals[4], true);
1257  EXPECT_EQ(signals[5], true);
1258  EXPECT_EQ(signals[6], false);
1259  EXPECT_EQ(signals[7], true);
1260  EXPECT_EQ(signals[8], false);
1261  EXPECT_EQ(signals[9], true);
1262 
1263  // Using __weight__ should work as well,
1264  // the only difference to using _weight__ instead of c is
1265  // in setBranchAddresses which avoids calling makeROOTCompatible
1266  // So we have to check the behaviour using __weight__ as well
1267  general_options.m_weight_variable = "__weight__";
1268  MVA::ROOTDataset y(general_options);
1269 
1270  weights = y.getWeights();
1271  EXPECT_EQ(weights.size(), 10);
1272  EXPECT_FLOAT_EQ(weights[0], 1.2);
1273  EXPECT_FLOAT_EQ(weights[1], 2.2);
1274  EXPECT_FLOAT_EQ(weights[2], 3.2);
1275  EXPECT_FLOAT_EQ(weights[3], 4.2);
1276  EXPECT_FLOAT_EQ(weights[4], 5.2);
1277  EXPECT_FLOAT_EQ(weights[5], 1.2);
1278  EXPECT_FLOAT_EQ(weights[6], 2.2);
1279  EXPECT_FLOAT_EQ(weights[7], 3.2);
1280  EXPECT_FLOAT_EQ(weights[8], 4.2);
1281  EXPECT_FLOAT_EQ(weights[9], 5.2);
1282 
1283  // Check TChain expansion
1284  general_options.m_datafiles = {"datafile*.root"};
1285  {
1286  MVA::ROOTDataset chain_test(general_options);
1287  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1288  }
1289  boost::filesystem::copy_file("datafile.root", "datafile3.root");
1290  {
1291  MVA::ROOTDataset chain_test(general_options);
1292  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1293  }
1294  boost::filesystem::copy_file("datafile.root", "datafile4.root");
1295  {
1296  MVA::ROOTDataset chain_test(general_options);
1297  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1298  }
1299  // Test m_max_events feature
1300  {
1301  general_options.m_max_events = 10;
1302  MVA::ROOTDataset chain_test(general_options);
1303  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1304  general_options.m_max_events = 0;
1305  }
1306 
1307  // If a file exists with the specified expansion
1308  // the file takes precedence over the expansion
1309  boost::filesystem::copy_file("datafile.root", "datafile*.root");
1310  {
1311  general_options.m_max_events = 0;
1312  MVA::ROOTDataset chain_test(general_options);
1313  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1314  }
1315 
1316  }
1317  TEST(DatasetTest, ROOTMultiDatasetDouble)
1318  {
1319 
1321  TFile file("datafile.root", "RECREATE");
1322  file.cd();
1323  TTree tree("tree", "TreeTitle");
1324  double a, b, c, d, e, f, g, v, w = 0;
1325  tree.Branch("a", &a, "a/D");
1326  tree.Branch("b", &b, "b/D");
1327  tree.Branch("c", &c, "c/D");
1328  tree.Branch("d", &d, "d/D");
1329  tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
1330  tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
1331  tree.Branch("g", &g, "g/D");
1332  tree.Branch("__weight__", &c, "__weight__/D");
1333  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
1334  tree.Branch("w", &w, "w/D");
1335 
1336  for (unsigned int i = 0; i < 5; ++i) {
1337  a = i + 1.0;
1338  b = i + 1.1;
1339  c = i + 1.2;
1340  d = i + 1.3;
1341  e = i + 1.4;
1342  f = i + 1.5;
1343  g = float(i % 2 == 0);
1344  w = i + 1.6;
1345  v = i + 1.7;
1346  tree.Fill();
1347  }
1348 
1349  file.Write("tree");
1350 
1351  TFile file2("datafile2.root", "RECREATE");
1352  file2.cd();
1353  TTree tree2("tree", "TreeTitle");
1354  tree2.Branch("a", &a);
1355  tree2.Branch("b", &b);
1356  tree2.Branch("c", &c);
1357  tree2.Branch("d", &d);
1358  tree2.Branch("e__bo__bc", &e);
1359  tree2.Branch("f__bo__bc", &f);
1360  tree2.Branch("g", &g);
1361  tree2.Branch("__weight__", &c);
1362  tree2.Branch("v__bo__bc", &v);
1363  tree2.Branch("w", &w);
1364 
1365  for (unsigned int i = 0; i < 5; ++i) {
1366  a = i + 1.0;
1367  b = i + 1.1;
1368  c = i + 1.2;
1369  d = i + 1.3;
1370  e = i + 1.4;
1371  f = i + 1.5;
1372  g = float(i % 2 == 0);
1373  w = i + 1.6;
1374  v = i + 1.7;
1375  tree2.Fill();
1376  }
1377 
1378  file2.Write("tree");
1379 
1380  MVA::GeneralOptions general_options;
1381  // Both names with and without makeROOTCompatible should work
1382  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1383  general_options.m_spectators = {"w", "v()"};
1384  general_options.m_signal_class = 1;
1385  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1386  general_options.m_treename = "tree";
1387  general_options.m_target_variable = "g";
1388  general_options.m_weight_variable = "c";
1389  MVA::ROOTDataset x(general_options);
1390 
1391  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1392  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1393  EXPECT_EQ(x.getNumberOfEvents(), 10);
1394 
1395  // Should just work
1396  x.loadEvent(0);
1397  EXPECT_EQ(x.m_input.size(), 4);
1398  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1399  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1400  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1401  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1402  EXPECT_EQ(x.m_spectators.size(), 2);
1403  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1404  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1405  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1406  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1407  EXPECT_EQ(x.m_isSignal, true);
1408 
1409  x.loadEvent(5);
1410  EXPECT_EQ(x.m_input.size(), 4);
1411  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1412  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1413  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1414  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1415  EXPECT_EQ(x.m_spectators.size(), 2);
1416  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1417  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1418  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1419  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1420  EXPECT_EQ(x.m_isSignal, true);
1421 
1422  x.loadEvent(1);
1423  EXPECT_EQ(x.m_input.size(), 4);
1424  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1425  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1426  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1427  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1428  EXPECT_EQ(x.m_spectators.size(), 2);
1429  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1430  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1431  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1432  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1433  EXPECT_EQ(x.m_isSignal, false);
1434 
1435  x.loadEvent(6);
1436  EXPECT_EQ(x.m_input.size(), 4);
1437  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1438  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1439  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1440  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1441  EXPECT_EQ(x.m_spectators.size(), 2);
1442  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1443  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1444  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1445  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1446  EXPECT_EQ(x.m_isSignal, false);
1447 
1448  x.loadEvent(2);
1449  EXPECT_EQ(x.m_input.size(), 4);
1450  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1451  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1452  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1453  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1454  EXPECT_EQ(x.m_spectators.size(), 2);
1455  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1456  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1457  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1458  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1459  EXPECT_EQ(x.m_isSignal, true);
1460 
1461  x.loadEvent(7);
1462  EXPECT_EQ(x.m_input.size(), 4);
1463  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1464  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1465  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1466  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1467  EXPECT_EQ(x.m_spectators.size(), 2);
1468  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1469  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1470  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1471  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1472  EXPECT_EQ(x.m_isSignal, true);
1473 
1474  x.loadEvent(3);
1475  EXPECT_EQ(x.m_input.size(), 4);
1476  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1477  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1478  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1479  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1480  EXPECT_EQ(x.m_spectators.size(), 2);
1481  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1482  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1483  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1484  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1485  EXPECT_EQ(x.m_isSignal, false);
1486 
1487  x.loadEvent(8);
1488  EXPECT_EQ(x.m_input.size(), 4);
1489  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1490  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1491  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1492  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1493  EXPECT_EQ(x.m_spectators.size(), 2);
1494  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1495  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1496  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1497  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1498  EXPECT_EQ(x.m_isSignal, false);
1499 
1500  x.loadEvent(4);
1501  EXPECT_EQ(x.m_input.size(), 4);
1502  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1503  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1504  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1505  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1506  EXPECT_EQ(x.m_spectators.size(), 2);
1507  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1508  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1509  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1510  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1511  EXPECT_EQ(x.m_isSignal, true);
1512 
1513  x.loadEvent(9);
1514  EXPECT_EQ(x.m_input.size(), 4);
1515  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1516  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1517  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1518  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1519  EXPECT_EQ(x.m_spectators.size(), 2);
1520  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1521  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1522  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1523  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1524  EXPECT_EQ(x.m_isSignal, true);
1525 
1526  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1527 
1528  auto feature = x.getFeature(1);
1529  EXPECT_EQ(feature.size(), 10);
1530  EXPECT_FLOAT_EQ(feature[0], 1.1);
1531  EXPECT_FLOAT_EQ(feature[1], 2.1);
1532  EXPECT_FLOAT_EQ(feature[2], 3.1);
1533  EXPECT_FLOAT_EQ(feature[3], 4.1);
1534  EXPECT_FLOAT_EQ(feature[4], 5.1);
1535  EXPECT_FLOAT_EQ(feature[5], 1.1);
1536  EXPECT_FLOAT_EQ(feature[6], 2.1);
1537  EXPECT_FLOAT_EQ(feature[7], 3.1);
1538  EXPECT_FLOAT_EQ(feature[8], 4.1);
1539  EXPECT_FLOAT_EQ(feature[9], 5.1);
1540 
1541  // Same result for mother class implementation
1542  feature = x.Dataset::getFeature(1);
1543  EXPECT_EQ(feature.size(), 10);
1544  EXPECT_FLOAT_EQ(feature[0], 1.1);
1545  EXPECT_FLOAT_EQ(feature[1], 2.1);
1546  EXPECT_FLOAT_EQ(feature[2], 3.1);
1547  EXPECT_FLOAT_EQ(feature[3], 4.1);
1548  EXPECT_FLOAT_EQ(feature[4], 5.1);
1549  EXPECT_FLOAT_EQ(feature[5], 1.1);
1550  EXPECT_FLOAT_EQ(feature[6], 2.1);
1551  EXPECT_FLOAT_EQ(feature[7], 3.1);
1552  EXPECT_FLOAT_EQ(feature[8], 4.1);
1553  EXPECT_FLOAT_EQ(feature[9], 5.1);
1554 
1555  auto spectator = x.getSpectator(1);
1556  EXPECT_EQ(spectator.size(), 10);
1557  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1558  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1559  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1560  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1561  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1562  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1563  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1564  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1565  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1566  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1567 
1568  // Same result for mother class implementation
1569  spectator = x.Dataset::getSpectator(1);
1570  EXPECT_EQ(spectator.size(), 10);
1571  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1572  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1573  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1574  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1575  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1576  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1577  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1578  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1579  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1580  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1581 
1582  auto weights = x.getWeights();
1583  EXPECT_EQ(weights.size(), 10);
1584  EXPECT_FLOAT_EQ(weights[0], 1.2);
1585  EXPECT_FLOAT_EQ(weights[1], 2.2);
1586  EXPECT_FLOAT_EQ(weights[2], 3.2);
1587  EXPECT_FLOAT_EQ(weights[3], 4.2);
1588  EXPECT_FLOAT_EQ(weights[4], 5.2);
1589  EXPECT_FLOAT_EQ(weights[5], 1.2);
1590  EXPECT_FLOAT_EQ(weights[6], 2.2);
1591  EXPECT_FLOAT_EQ(weights[7], 3.2);
1592  EXPECT_FLOAT_EQ(weights[8], 4.2);
1593  EXPECT_FLOAT_EQ(weights[9], 5.2);
1594 
1595  // Same result for mother class implementation
1596  weights = x.Dataset::getWeights();
1597  EXPECT_EQ(weights.size(), 10);
1598  EXPECT_FLOAT_EQ(weights[0], 1.2);
1599  EXPECT_FLOAT_EQ(weights[1], 2.2);
1600  EXPECT_FLOAT_EQ(weights[2], 3.2);
1601  EXPECT_FLOAT_EQ(weights[3], 4.2);
1602  EXPECT_FLOAT_EQ(weights[4], 5.2);
1603  EXPECT_FLOAT_EQ(weights[5], 1.2);
1604  EXPECT_FLOAT_EQ(weights[6], 2.2);
1605  EXPECT_FLOAT_EQ(weights[7], 3.2);
1606  EXPECT_FLOAT_EQ(weights[8], 4.2);
1607  EXPECT_FLOAT_EQ(weights[9], 5.2);
1608 
1609  auto targets = x.getTargets();
1610  EXPECT_EQ(targets.size(), 10);
1611  EXPECT_FLOAT_EQ(targets[0], 1.0);
1612  EXPECT_FLOAT_EQ(targets[1], 0.0);
1613  EXPECT_FLOAT_EQ(targets[2], 1.0);
1614  EXPECT_FLOAT_EQ(targets[3], 0.0);
1615  EXPECT_FLOAT_EQ(targets[4], 1.0);
1616  EXPECT_FLOAT_EQ(targets[5], 1.0);
1617  EXPECT_FLOAT_EQ(targets[6], 0.0);
1618  EXPECT_FLOAT_EQ(targets[7], 1.0);
1619  EXPECT_FLOAT_EQ(targets[8], 0.0);
1620  EXPECT_FLOAT_EQ(targets[9], 1.0);
1621 
1622  auto signals = x.getSignals();
1623  EXPECT_EQ(signals.size(), 10);
1624  EXPECT_EQ(signals[0], true);
1625  EXPECT_EQ(signals[1], false);
1626  EXPECT_EQ(signals[2], true);
1627  EXPECT_EQ(signals[3], false);
1628  EXPECT_EQ(signals[4], true);
1629  EXPECT_EQ(signals[5], true);
1630  EXPECT_EQ(signals[6], false);
1631  EXPECT_EQ(signals[7], true);
1632  EXPECT_EQ(signals[8], false);
1633  EXPECT_EQ(signals[9], true);
1634 
1635  // Using __weight__ should work as well,
1636  // the only difference to using _weight__ instead of g is
1637  // in setBranchAddresses which avoids calling makeROOTCompatible
1638  // So we have to check the behaviour using __weight__ as well
1639  general_options.m_weight_variable = "__weight__";
1640  MVA::ROOTDataset y(general_options);
1641 
1642  weights = y.getWeights();
1643  EXPECT_EQ(weights.size(), 10);
1644  EXPECT_FLOAT_EQ(weights[0], 1.2);
1645  EXPECT_FLOAT_EQ(weights[1], 2.2);
1646  EXPECT_FLOAT_EQ(weights[2], 3.2);
1647  EXPECT_FLOAT_EQ(weights[3], 4.2);
1648  EXPECT_FLOAT_EQ(weights[4], 5.2);
1649  EXPECT_FLOAT_EQ(weights[5], 1.2);
1650  EXPECT_FLOAT_EQ(weights[6], 2.2);
1651  EXPECT_FLOAT_EQ(weights[7], 3.2);
1652  EXPECT_FLOAT_EQ(weights[8], 4.2);
1653  EXPECT_FLOAT_EQ(weights[9], 5.2);
1654 
1655  // Check TChain expansion
1656  general_options.m_datafiles = {"datafile*.root"};
1657  {
1658  MVA::ROOTDataset chain_test(general_options);
1659  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1660  }
1661  boost::filesystem::copy_file("datafile.root", "datafile3.root");
1662  {
1663  MVA::ROOTDataset chain_test(general_options);
1664  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1665  }
1666  boost::filesystem::copy_file("datafile.root", "datafile4.root");
1667  {
1668  MVA::ROOTDataset chain_test(general_options);
1669  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1670  }
1671  // Test m_max_events feature
1672  {
1673  general_options.m_max_events = 10;
1674  MVA::ROOTDataset chain_test(general_options);
1675  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1676  general_options.m_max_events = 0;
1677  }
1678 
1679  // If a file exists with the specified expansion
1680  // the file takes precedence over the expansion
1681  boost::filesystem::copy_file("datafile.root", "datafile*.root");
1682  {
1683  general_options.m_max_events = 0;
1684  MVA::ROOTDataset chain_test(general_options);
1685  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1686  }
1687  }
1688 }
Wraps two other Datasets, one containing signal, the other background events. Used by the reweighting ...
Definition: Dataset.h:294
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
Definition: Dataset.h:349
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:135
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:234
changes working directory into a newly created directory, and removes it (and contents) on destructio...
Definition: TestHelpers.h:66
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:23