Belle II Software  light-2212-foldex
test_Dataset.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/interface/Dataset.h>
10 #include <framework/utilities/TestHelpers.h>
11 
12 #include <boost/filesystem/operations.hpp>
13 
14 #include <gtest/gtest.h>
15 
16 #include <fstream>
17 #include <numeric>
18 
19 using namespace Belle2;
20 
21 namespace {
22 
23  class TestDataset : public MVA::Dataset {
24  public:
25  explicit TestDataset(MVA::GeneralOptions& general_options) : MVA::Dataset(general_options)
26  {
27  m_input = {1.0, 2.0, 3.0, 4.0, 5.0};
28  m_target = 3.0;
29  m_isSignal = true;
30  m_weight = -3.0;
31  }
32 
33  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 5; }
34  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 2; }
35  [[nodiscard]] unsigned int getNumberOfEvents() const override { return 20; }
36  void loadEvent(unsigned int iEvent) override
37  {
38  auto f = static_cast<float>(iEvent);
39  m_input = {f + 1, f + 2, f + 3, f + 4, f + 5};
40  m_spectators = {f + 6, f + 7};
41  };
42  float getSignalFraction() override { return 0.1; };
43  std::vector<float> getFeature(unsigned int iFeature) override
44  {
45  std::vector<float> a(20, 0.0);
46  std::iota(a.begin(), a.end(), iFeature + 1);
47  return a;
48  }
49  std::vector<float> getSpectator(unsigned int iSpectator) override
50  {
51  std::vector<float> a(20, 0.0);
52  std::iota(a.begin(), a.end(), iSpectator + 6);
53  return a;
54  }
55 
56  };
57 
58  TEST(DatasetTest, SingleDataset)
59  {
60 
61  MVA::GeneralOptions general_options;
62  general_options.m_variables = {"a", "b", "c"};
63  general_options.m_spectators = {"e", "f"};
64  general_options.m_signal_class = 2;
65  std::vector<float> input = {1.0, 2.0, 3.0};
66  std::vector<float> spectators = {4.0, 5.0};
67 
68  MVA::SingleDataset x(general_options, input, 2.0, spectators);
69 
70  EXPECT_EQ(x.getNumberOfFeatures(), 3);
71  EXPECT_EQ(x.getNumberOfSpectators(), 2);
72  EXPECT_EQ(x.getNumberOfEvents(), 1);
73 
74  EXPECT_EQ(x.getFeatureIndex("a"), 0);
75  EXPECT_EQ(x.getFeatureIndex("b"), 1);
76  EXPECT_EQ(x.getFeatureIndex("c"), 2);
77  EXPECT_B2ERROR(x.getFeatureIndex("bla"));
78 
79  EXPECT_EQ(x.getSpectatorIndex("e"), 0);
80  EXPECT_EQ(x.getSpectatorIndex("f"), 1);
81  EXPECT_B2ERROR(x.getSpectatorIndex("bla"));
82 
83  // Should just work
84  x.loadEvent(0);
85 
86  EXPECT_EQ(x.m_input.size(), 3);
87  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
88  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
89  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
90 
91  EXPECT_EQ(x.m_spectators.size(), 2);
92  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.0);
93  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.0);
94 
95  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
96  EXPECT_FLOAT_EQ(x.m_target, 2.0);
97  EXPECT_EQ(x.m_isSignal, true);
98 
99  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1.0);
100 
101  auto feature = x.getFeature(1);
102  EXPECT_EQ(feature.size(), 1);
103  EXPECT_FLOAT_EQ(feature[0], 2.0);
104 
105  auto spectator = x.getSpectator(1);
106  EXPECT_EQ(spectator.size(), 1);
107  EXPECT_FLOAT_EQ(spectator[0], 5.0);
108 
109  // Same result for mother class implementation
110  feature = x.Dataset::getFeature(1);
111  EXPECT_EQ(feature.size(), 1);
112  EXPECT_FLOAT_EQ(feature[0], 2.0);
113 
114  }
115 
116  TEST(DatasetTest, MultiDataset)
117  {
118 
119  MVA::GeneralOptions general_options;
120  general_options.m_variables = {"a", "b", "c"};
121  general_options.m_spectators = {"e", "f"};
122  general_options.m_signal_class = 2;
123  std::vector<std::vector<float>> matrix = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}};
124  std::vector<std::vector<float>> spectator_matrix = {{12.0, 13.0}, {15.0, 16.0}, {18.0, 19.0}};
125  std::vector<float> targets = {2.0, 0.0, 2.0};
126  std::vector<float> weights = {1.0, 2.0, 3.0};
127 
128  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, {1.0}, weights));
129 
130  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, targets, {1.0}));
131 
132  MVA::MultiDataset x(general_options, matrix, spectator_matrix, targets, weights);
133 
134  EXPECT_EQ(x.getNumberOfFeatures(), 3);
135  EXPECT_EQ(x.getNumberOfEvents(), 3);
136 
137  // Should just work
138  x.loadEvent(0);
139 
140  EXPECT_EQ(x.m_input.size(), 3);
141  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
142  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
143  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
144 
145  EXPECT_EQ(x.m_spectators.size(), 2);
146  EXPECT_FLOAT_EQ(x.m_spectators[0], 12.0);
147  EXPECT_FLOAT_EQ(x.m_spectators[1], 13.0);
148 
149  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
150  EXPECT_FLOAT_EQ(x.m_target, 2.0);
151  EXPECT_EQ(x.m_isSignal, true);
152 
153  // Should just work
154  x.loadEvent(1);
155 
156  EXPECT_EQ(x.m_input.size(), 3);
157  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
158  EXPECT_FLOAT_EQ(x.m_input[1], 5.0);
159  EXPECT_FLOAT_EQ(x.m_input[2], 6.0);
160 
161  EXPECT_EQ(x.m_spectators.size(), 2);
162  EXPECT_FLOAT_EQ(x.m_spectators[0], 15.0);
163  EXPECT_FLOAT_EQ(x.m_spectators[1], 16.0);
164 
165  EXPECT_FLOAT_EQ(x.m_weight, 2.0);
166  EXPECT_FLOAT_EQ(x.m_target, 0.0);
167  EXPECT_EQ(x.m_isSignal, false);
168 
169  // Should just work
170  x.loadEvent(2);
171 
172  EXPECT_EQ(x.m_input.size(), 3);
173  EXPECT_FLOAT_EQ(x.m_input[0], 7.0);
174  EXPECT_FLOAT_EQ(x.m_input[1], 8.0);
175  EXPECT_FLOAT_EQ(x.m_input[2], 9.0);
176 
177  EXPECT_EQ(x.m_spectators.size(), 2);
178  EXPECT_FLOAT_EQ(x.m_spectators[0], 18.0);
179  EXPECT_FLOAT_EQ(x.m_spectators[1], 19.0);
180 
181  EXPECT_FLOAT_EQ(x.m_weight, 3.0);
182  EXPECT_FLOAT_EQ(x.m_target, 2.0);
183  EXPECT_EQ(x.m_isSignal, true);
184 
185  EXPECT_FLOAT_EQ(x.getSignalFraction(), 4.0 / 6.0);
186 
187  auto feature = x.getFeature(1);
188  EXPECT_EQ(feature.size(), 3);
189  EXPECT_FLOAT_EQ(feature[0], 2.0);
190  EXPECT_FLOAT_EQ(feature[1], 5.0);
191  EXPECT_FLOAT_EQ(feature[2], 8.0);
192 
193  auto spectator = x.getSpectator(1);
194  EXPECT_EQ(spectator.size(), 3);
195  EXPECT_FLOAT_EQ(spectator[0], 13.0);
196  EXPECT_FLOAT_EQ(spectator[1], 16.0);
197  EXPECT_FLOAT_EQ(spectator[2], 19.0);
198 
199 
200  }
201 
202  TEST(DatasetTest, SubDataset)
203  {
204 
205  MVA::GeneralOptions general_options;
206  general_options.m_signal_class = 3;
207  general_options.m_variables = {"a", "b", "c", "d", "e"};
208  general_options.m_spectators = {"f", "g"};
209  TestDataset test_dataset(general_options);
210 
211  general_options.m_variables = {"a", "d", "e"};
212  general_options.m_spectators = {"g"};
213  std::vector<bool> events = {true, false, true, false, true, false, true, false, true, false,
214  true, false, true, false, true, false, true, false, true, false
215  };
216  MVA::SubDataset x(general_options, events, test_dataset);
217 
218  EXPECT_EQ(x.getNumberOfFeatures(), 3);
219  EXPECT_EQ(x.getNumberOfEvents(), 10);
220 
221  // Should just work
222  x.loadEvent(0);
223 
224  EXPECT_EQ(x.m_input.size(), 3);
225  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
226  EXPECT_FLOAT_EQ(x.m_input[1], 4.0);
227  EXPECT_FLOAT_EQ(x.m_input[2], 5.0);
228 
229  EXPECT_EQ(x.m_spectators.size(), 1);
230  EXPECT_FLOAT_EQ(x.m_spectators[0], 7.0);
231 
232  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
233  EXPECT_FLOAT_EQ(x.m_target, 3.0);
234  EXPECT_EQ(x.m_isSignal, true);
235 
236  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1);
237 
238  auto feature = x.getFeature(1);
239  EXPECT_EQ(feature.size(), 10);
240  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
241  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
242  };
243 
244  auto spectator = x.getSpectator(0);
245  EXPECT_EQ(spectator.size(), 10);
246  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
247  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
248  };
249 
250  // Same result for mother class implementation
251  feature = x.Dataset::getFeature(1);
252  EXPECT_EQ(feature.size(), 10);
253  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
254  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
255  };
256 
257  spectator = x.Dataset::getSpectator(0);
258  EXPECT_EQ(spectator.size(), 10);
259  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
260  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
261  };
262 
263  // Test without event indices
264  MVA::SubDataset y(general_options, {}, test_dataset);
265  feature = y.getFeature(1);
266  EXPECT_EQ(feature.size(), 20);
267  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
268  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
269  };
270 
271  spectator = y.getSpectator(0);
272  EXPECT_EQ(spectator.size(), 20);
273  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
274  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
275  };
276 
277  // Same result for mother class implementation
278  feature = y.Dataset::getFeature(1);
279  EXPECT_EQ(feature.size(), 20);
280  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
281  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
282  };
283 
284  spectator = y.Dataset::getSpectator(0);
285  EXPECT_EQ(spectator.size(), 20);
286  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
287  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
288  };
289 
290  general_options.m_variables = {"a", "d", "e", "DOESNOTEXIST"};
291  try {
292  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
293  } catch (...) {
294 
295  }
296  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
297 
298  general_options.m_variables = {"a", "d", "e"};
299  general_options.m_spectators = {"DOESNOTEXIST"};
300  try {
301  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
302  } catch (...) {
303 
304  }
305  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
306 
307  }
308 
309  TEST(DatasetTest, CombinedDataset)
310  {
311 
312  MVA::GeneralOptions general_options;
313  general_options.m_signal_class = 1;
314  general_options.m_variables = {"a", "b", "c", "d", "e"};
315  general_options.m_spectators = {"f", "g"};
316  TestDataset signal_dataset(general_options);
317  TestDataset bckgrd_dataset(general_options);
318 
319  MVA::CombinedDataset x(general_options, signal_dataset, bckgrd_dataset);
320 
321  EXPECT_EQ(x.getNumberOfFeatures(), 5);
322  EXPECT_EQ(x.getNumberOfEvents(), 40);
323 
324  // Should just work
325  x.loadEvent(0);
326 
327  EXPECT_EQ(x.m_input.size(), 5);
328  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
329  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
330  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
331  EXPECT_FLOAT_EQ(x.m_input[3], 4.0);
332  EXPECT_FLOAT_EQ(x.m_input[4], 5.0);
333 
334  EXPECT_EQ(x.m_spectators.size(), 2);
335  EXPECT_FLOAT_EQ(x.m_spectators[0], 6.0);
336  EXPECT_FLOAT_EQ(x.m_spectators[1], 7.0);
337 
338  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
339  EXPECT_FLOAT_EQ(x.m_target, 1.0);
340  EXPECT_EQ(x.m_isSignal, true);
341 
342  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.5);
343 
344  auto feature = x.getFeature(1);
345  EXPECT_EQ(feature.size(), 40);
346  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
347  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
348  };
349 
350  auto spectator = x.getSpectator(0);
351  EXPECT_EQ(spectator.size(), 40);
352  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
353  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
354  };
355 
356  // Same result for mother class implementation
357  feature = x.Dataset::getFeature(1);
358  EXPECT_EQ(feature.size(), 40);
359  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
360  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
361  };
362 
363  spectator = x.Dataset::getSpectator(0);
364  EXPECT_EQ(spectator.size(), 40);
365  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
366  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
367  };
368 
369  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
370  x.loadEvent(iEvent);
371  EXPECT_EQ(x.m_isSignal, iEvent < 20);
372  EXPECT_FLOAT_EQ(x.m_target, iEvent < 20 ? 1.0 : 0.0);
373  }
374 
375  }
376 
377  TEST(DatasetTest, ROOTDataset)
378  {
379 
381  TFile file("datafile.root", "RECREATE");
382  file.cd();
383  TTree tree("tree", "TreeTitle");
384  float a, b, c, d, e, f, v, w = 0;
385  bool target;
386  tree.Branch("a", &a);
387  tree.Branch("b", &b);
388  tree.Branch("c", &c);
389  tree.Branch("d", &d);
390  tree.Branch("e__bo__bc", &e);
391  tree.Branch("f__bo__bc", &f);
392  tree.Branch("g", &target);
393  tree.Branch("__weight__", &c);
394  tree.Branch("v__bo__bc", &v);
395  tree.Branch("w", &w);
396 
397  for (unsigned int i = 0; i < 5; ++i) {
398  a = i + 1.0;
399  b = i + 1.1;
400  c = i + 1.2;
401  d = i + 1.3;
402  e = i + 1.4;
403  f = i + 1.5;
404  target = (i % 2 == 0);
405  w = i + 1.6;
406  v = i + 1.7;
407  tree.Fill();
408  }
409 
410  file.Write("tree");
411 
412  MVA::GeneralOptions general_options;
413  // Both names with and without makeROOTCompatible should work
414  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
415  general_options.m_spectators = {"w", "v()"};
416  general_options.m_signal_class = 1;
417  general_options.m_datafiles = {"datafile.root"};
418  general_options.m_treename = "tree";
419  general_options.m_target_variable = "g";
420  general_options.m_weight_variable = "c";
421  MVA::ROOTDataset x(general_options);
422 
423  EXPECT_EQ(x.getNumberOfFeatures(), 4);
424  EXPECT_EQ(x.getNumberOfSpectators(), 2);
425  EXPECT_EQ(x.getNumberOfEvents(), 5);
426 
427  // Should just work
428  x.loadEvent(0);
429  EXPECT_EQ(x.m_input.size(), 4);
430  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
431  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
432  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
433  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
434  EXPECT_EQ(x.m_spectators.size(), 2);
435  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
436  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
437  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
438  EXPECT_EQ(x.m_target, true);
439  EXPECT_EQ(x.m_isSignal, true);
440 
441  x.loadEvent(1);
442  EXPECT_EQ(x.m_input.size(), 4);
443  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
444  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
445  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
446  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
447  EXPECT_EQ(x.m_spectators.size(), 2);
448  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
449  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
450  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
451  EXPECT_EQ(x.m_target, false);
452  EXPECT_EQ(x.m_isSignal, false);
453 
454  x.loadEvent(2);
455  EXPECT_EQ(x.m_input.size(), 4);
456  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
457  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
458  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
459  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
460  EXPECT_EQ(x.m_spectators.size(), 2);
461  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
462  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
463  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
464  EXPECT_EQ(x.m_target, true);
465  EXPECT_EQ(x.m_isSignal, true);
466 
467  x.loadEvent(3);
468  EXPECT_EQ(x.m_input.size(), 4);
469  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
470  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
471  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
472  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
473  EXPECT_EQ(x.m_spectators.size(), 2);
474  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
475  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
476  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
477  EXPECT_EQ(x.m_target, false);
478  EXPECT_EQ(x.m_isSignal, false);
479 
480  x.loadEvent(4);
481  EXPECT_EQ(x.m_input.size(), 4);
482  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
483  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
484  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
485  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
486  EXPECT_EQ(x.m_spectators.size(), 2);
487  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
488  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
489  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
490  EXPECT_EQ(x.m_target, true);
491  EXPECT_EQ(x.m_isSignal, true);
492 
493  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
494 
495  auto feature = x.getFeature(1);
496  EXPECT_EQ(feature.size(), 5);
497  EXPECT_FLOAT_EQ(feature[0], 1.1);
498  EXPECT_FLOAT_EQ(feature[1], 2.1);
499  EXPECT_FLOAT_EQ(feature[2], 3.1);
500  EXPECT_FLOAT_EQ(feature[3], 4.1);
501  EXPECT_FLOAT_EQ(feature[4], 5.1);
502 
503  // Same result for mother class implementation
504  feature = x.Dataset::getFeature(1);
505  EXPECT_EQ(feature.size(), 5);
506  EXPECT_FLOAT_EQ(feature[0], 1.1);
507  EXPECT_FLOAT_EQ(feature[1], 2.1);
508  EXPECT_FLOAT_EQ(feature[2], 3.1);
509  EXPECT_FLOAT_EQ(feature[3], 4.1);
510  EXPECT_FLOAT_EQ(feature[4], 5.1);
511 
512  auto spectator = x.getSpectator(1);
513  EXPECT_EQ(spectator.size(), 5);
514  EXPECT_FLOAT_EQ(spectator[0], 1.7);
515  EXPECT_FLOAT_EQ(spectator[1], 2.7);
516  EXPECT_FLOAT_EQ(spectator[2], 3.7);
517  EXPECT_FLOAT_EQ(spectator[3], 4.7);
518  EXPECT_FLOAT_EQ(spectator[4], 5.7);
519 
520  // Same result for mother class implementation
521  spectator = x.Dataset::getSpectator(1);
522  EXPECT_EQ(spectator.size(), 5);
523  EXPECT_FLOAT_EQ(spectator[0], 1.7);
524  EXPECT_FLOAT_EQ(spectator[1], 2.7);
525  EXPECT_FLOAT_EQ(spectator[2], 3.7);
526  EXPECT_FLOAT_EQ(spectator[3], 4.7);
527  EXPECT_FLOAT_EQ(spectator[4], 5.7);
528 
529  auto weights = x.getWeights();
530  EXPECT_EQ(weights.size(), 5);
531  EXPECT_FLOAT_EQ(weights[0], 1.2);
532  EXPECT_FLOAT_EQ(weights[1], 2.2);
533  EXPECT_FLOAT_EQ(weights[2], 3.2);
534  EXPECT_FLOAT_EQ(weights[3], 4.2);
535  EXPECT_FLOAT_EQ(weights[4], 5.2);
536 
537  // Same result for mother class implementation
538  weights = x.Dataset::getWeights();
539  EXPECT_EQ(weights.size(), 5);
540  EXPECT_FLOAT_EQ(weights[0], 1.2);
541  EXPECT_FLOAT_EQ(weights[1], 2.2);
542  EXPECT_FLOAT_EQ(weights[2], 3.2);
543  EXPECT_FLOAT_EQ(weights[3], 4.2);
544  EXPECT_FLOAT_EQ(weights[4], 5.2);
545 
546  auto targets = x.getTargets();
547  EXPECT_EQ(targets.size(), 5);
548  EXPECT_EQ(targets[0], true);
549  EXPECT_EQ(targets[1], false);
550  EXPECT_EQ(targets[2], true);
551  EXPECT_EQ(targets[3], false);
552  EXPECT_EQ(targets[4], true);
553 
554  auto signals = x.getSignals();
555  EXPECT_EQ(signals.size(), 5);
556  EXPECT_EQ(signals[0], true);
557  EXPECT_EQ(signals[1], false);
558  EXPECT_EQ(signals[2], true);
559  EXPECT_EQ(signals[3], false);
560  EXPECT_EQ(signals[4], true);
561 
562  // Using __weight__ should work as well,
563  // the only difference to using _weight__ instead of c is
564  // in setBranchAddresses which avoids calling makeROOTCompatible
565  // So we have to check the behaviour using __weight__ as well
566  general_options.m_weight_variable = "__weight__";
567  MVA::ROOTDataset y(general_options);
568 
569  weights = y.getWeights();
570  EXPECT_EQ(weights.size(), 5);
571  EXPECT_FLOAT_EQ(weights[0], 1.2);
572  EXPECT_FLOAT_EQ(weights[1], 2.2);
573  EXPECT_FLOAT_EQ(weights[2], 3.2);
574  EXPECT_FLOAT_EQ(weights[3], 4.2);
575  EXPECT_FLOAT_EQ(weights[4], 5.2);
576 
577  // Check TChain expansion
578  general_options.m_datafiles = {"datafile*.root"};
579  {
580  MVA::ROOTDataset chain_test(general_options);
581  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
582  }
583  boost::filesystem::copy_file("datafile.root", "datafile2.root");
584  {
585  MVA::ROOTDataset chain_test(general_options);
586  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
587  }
588  boost::filesystem::copy_file("datafile.root", "datafile3.root");
589  {
590  MVA::ROOTDataset chain_test(general_options);
591  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
592  }
593  // Test m_max_events feature
594  {
595  general_options.m_max_events = 10;
596  MVA::ROOTDataset chain_test(general_options);
597  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
598  general_options.m_max_events = 0;
599  }
600 
601  // Check for missing tree
602  general_options.m_treename = "missing tree";
603  try {
604  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
605  } catch (...) {
606 
607  }
608  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
609 
610  // Check for missing branch
611  general_options.m_treename = "tree";
612  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
613  try {
614  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
615  } catch (...) {
616 
617  }
618  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
619 
620  // Check for missing branch
621  general_options.m_treename = "tree";
622  general_options.m_variables = {"a", "b", "e", "f"};
623  general_options.m_spectators = {"missing branch"};
624  try {
625  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
626  } catch (...) {
627 
628  }
629  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
630 
631  // Check for missing file
632  general_options.m_spectators = {};
633  general_options.m_datafiles = {"DOESNOTEXIST.root"};
634  general_options.m_treename = "tree";
635  try {
636  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
637  } catch (...) {
638 
639  }
640  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
641 
642  // Check for invalid file
643  general_options.m_datafiles = {"ISNotAValidROOTFile"};
644  general_options.m_treename = "tree";
645 
646  {
647  std::ofstream(general_options.m_datafiles[0]);
648  }
649  EXPECT_TRUE(boost::filesystem::exists(general_options.m_datafiles[0]));
650 
651  try {
652  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
653  } catch (...) {
654 
655  }
656  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
657  }
658 
659 
660  TEST(DatasetTest, ROOTDatasetDouble)
661  {
662 
664  TFile file("datafile.root", "RECREATE");
665  file.cd();
666  TTree tree("tree", "TreeTitle");
667  double a, b, c, d, e, f, g, v, w = 0;
668  tree.Branch("a", &a, "a/D");
669  tree.Branch("b", &b, "b/D");
670  tree.Branch("c", &c, "c/D");
671  tree.Branch("d", &d, "d/D");
672  tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
673  tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
674  tree.Branch("g", &g, "g/D");
675  tree.Branch("__weight__", &c, "__weight__/D");
676  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
677  tree.Branch("w", &w, "w/D");
678 
679  for (unsigned int i = 0; i < 5; ++i) {
680  a = i + 1.0;
681  b = i + 1.1;
682  c = i + 1.2;
683  d = i + 1.3;
684  e = i + 1.4;
685  f = i + 1.5;
686  g = float(i % 2 == 0);
687  w = i + 1.6;
688  v = i + 1.7;
689  tree.Fill();
690  }
691 
692  file.Write("tree");
693 
694  MVA::GeneralOptions general_options;
695  // Both names with and without makeROOTCompatible should work
696  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
697  general_options.m_spectators = {"w", "v()"};
698  general_options.m_signal_class = 1;
699  general_options.m_datafiles = {"datafile.root"};
700  general_options.m_treename = "tree";
701  general_options.m_target_variable = "g";
702  general_options.m_weight_variable = "c";
703  MVA::ROOTDataset x(general_options);
704 
705  EXPECT_EQ(x.getNumberOfFeatures(), 4);
706  EXPECT_EQ(x.getNumberOfSpectators(), 2);
707  EXPECT_EQ(x.getNumberOfEvents(), 5);
708 
709  // Should just work
710  x.loadEvent(0);
711  EXPECT_EQ(x.m_input.size(), 4);
712  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
713  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
714  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
715  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
716  EXPECT_EQ(x.m_spectators.size(), 2);
717  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
718  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
719  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
720  EXPECT_FLOAT_EQ(x.m_target, 1.0);
721  EXPECT_EQ(x.m_isSignal, true);
722 
723  x.loadEvent(1);
724  EXPECT_EQ(x.m_input.size(), 4);
725  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
726  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
727  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
728  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
729  EXPECT_EQ(x.m_spectators.size(), 2);
730  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
731  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
732  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
733  EXPECT_FLOAT_EQ(x.m_target, 0.0);
734  EXPECT_EQ(x.m_isSignal, false);
735 
736  x.loadEvent(2);
737  EXPECT_EQ(x.m_input.size(), 4);
738  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
739  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
740  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
741  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
742  EXPECT_EQ(x.m_spectators.size(), 2);
743  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
744  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
745  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
746  EXPECT_FLOAT_EQ(x.m_target, 1.0);
747  EXPECT_EQ(x.m_isSignal, true);
748 
749  x.loadEvent(3);
750  EXPECT_EQ(x.m_input.size(), 4);
751  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
752  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
753  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
754  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
755  EXPECT_EQ(x.m_spectators.size(), 2);
756  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
757  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
758  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
759  EXPECT_FLOAT_EQ(x.m_target, 0.0);
760  EXPECT_EQ(x.m_isSignal, false);
761 
762  x.loadEvent(4);
763  EXPECT_EQ(x.m_input.size(), 4);
764  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
765  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
766  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
767  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
768  EXPECT_EQ(x.m_spectators.size(), 2);
769  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
770  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
771  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
772  EXPECT_FLOAT_EQ(x.m_target, 1.0);
773  EXPECT_EQ(x.m_isSignal, true);
774 
775  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
776 
777  auto feature = x.getFeature(1);
778  EXPECT_EQ(feature.size(), 5);
779  EXPECT_FLOAT_EQ(feature[0], 1.1);
780  EXPECT_FLOAT_EQ(feature[1], 2.1);
781  EXPECT_FLOAT_EQ(feature[2], 3.1);
782  EXPECT_FLOAT_EQ(feature[3], 4.1);
783  EXPECT_FLOAT_EQ(feature[4], 5.1);
784 
785  // Same result for mother class implementation
786  feature = x.Dataset::getFeature(1);
787  EXPECT_EQ(feature.size(), 5);
788  EXPECT_FLOAT_EQ(feature[0], 1.1);
789  EXPECT_FLOAT_EQ(feature[1], 2.1);
790  EXPECT_FLOAT_EQ(feature[2], 3.1);
791  EXPECT_FLOAT_EQ(feature[3], 4.1);
792  EXPECT_FLOAT_EQ(feature[4], 5.1);
793 
794  auto spectator = x.getSpectator(1);
795  EXPECT_EQ(spectator.size(), 5);
796  EXPECT_FLOAT_EQ(spectator[0], 1.7);
797  EXPECT_FLOAT_EQ(spectator[1], 2.7);
798  EXPECT_FLOAT_EQ(spectator[2], 3.7);
799  EXPECT_FLOAT_EQ(spectator[3], 4.7);
800  EXPECT_FLOAT_EQ(spectator[4], 5.7);
801 
802  // Same result for mother class implementation
803  spectator = x.Dataset::getSpectator(1);
804  EXPECT_EQ(spectator.size(), 5);
805  EXPECT_FLOAT_EQ(spectator[0], 1.7);
806  EXPECT_FLOAT_EQ(spectator[1], 2.7);
807  EXPECT_FLOAT_EQ(spectator[2], 3.7);
808  EXPECT_FLOAT_EQ(spectator[3], 4.7);
809  EXPECT_FLOAT_EQ(spectator[4], 5.7);
810 
811  auto weights = x.getWeights();
812  EXPECT_EQ(weights.size(), 5);
813  EXPECT_FLOAT_EQ(weights[0], 1.2);
814  EXPECT_FLOAT_EQ(weights[1], 2.2);
815  EXPECT_FLOAT_EQ(weights[2], 3.2);
816  EXPECT_FLOAT_EQ(weights[3], 4.2);
817  EXPECT_FLOAT_EQ(weights[4], 5.2);
818 
819  // Same result for mother class implementation
820  weights = x.Dataset::getWeights();
821  EXPECT_EQ(weights.size(), 5);
822  EXPECT_FLOAT_EQ(weights[0], 1.2);
823  EXPECT_FLOAT_EQ(weights[1], 2.2);
824  EXPECT_FLOAT_EQ(weights[2], 3.2);
825  EXPECT_FLOAT_EQ(weights[3], 4.2);
826  EXPECT_FLOAT_EQ(weights[4], 5.2);
827 
828  auto targets = x.getTargets();
829  EXPECT_EQ(targets.size(), 5);
830  EXPECT_FLOAT_EQ(targets[0], 1.0);
831  EXPECT_FLOAT_EQ(targets[1], 0.0);
832  EXPECT_FLOAT_EQ(targets[2], 1.0);
833  EXPECT_FLOAT_EQ(targets[3], 0.0);
834  EXPECT_FLOAT_EQ(targets[4], 1.0);
835 
836  auto signals = x.getSignals();
837  EXPECT_EQ(signals.size(), 5);
838  EXPECT_EQ(signals[0], true);
839  EXPECT_EQ(signals[1], false);
840  EXPECT_EQ(signals[2], true);
841  EXPECT_EQ(signals[3], false);
842  EXPECT_EQ(signals[4], true);
843 
844  // Using __weight__ should work as well,
845  // the only difference to using _weight__ instead of g is
846  // in setBranchAddresses which avoids calling makeROOTCompatible
847  // So we have to check the behaviour using __weight__ as well
848  general_options.m_weight_variable = "__weight__";
849  MVA::ROOTDataset y(general_options);
850 
851  weights = y.getWeights();
852  EXPECT_EQ(weights.size(), 5);
853  EXPECT_FLOAT_EQ(weights[0], 1.2);
854  EXPECT_FLOAT_EQ(weights[1], 2.2);
855  EXPECT_FLOAT_EQ(weights[2], 3.2);
856  EXPECT_FLOAT_EQ(weights[3], 4.2);
857  EXPECT_FLOAT_EQ(weights[4], 5.2);
858 
859  // Check TChain expansion
860  general_options.m_datafiles = {"datafile*.root"};
861  {
862  MVA::ROOTDataset chain_test(general_options);
863  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
864  }
865  boost::filesystem::copy_file("datafile.root", "datafile2.root");
866  {
867  MVA::ROOTDataset chain_test(general_options);
868  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
869  }
870  boost::filesystem::copy_file("datafile.root", "datafile3.root");
871  {
872  MVA::ROOTDataset chain_test(general_options);
873  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
874  }
875  // Test m_max_events feature
876  {
877  general_options.m_max_events = 10;
878  MVA::ROOTDataset chain_test(general_options);
879  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
880  general_options.m_max_events = 0;
881  }
882 
883  // Check for missing tree
884  general_options.m_treename = "missing tree";
885  try {
886  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
887  } catch (...) {
888 
889  }
890  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
891 
892  // Check for missing branch
893  general_options.m_treename = "tree";
894  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
895  try {
896  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
897  } catch (...) {
898 
899  }
900  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
901 
902  // Check for missing branch
903  general_options.m_treename = "tree";
904  general_options.m_variables = {"a", "b", "e", "f"};
905  general_options.m_spectators = {"missing branch"};
906  try {
907  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
908  } catch (...) {
909 
910  }
911  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
912 
913  // Check for missing file
914  general_options.m_spectators = {};
915  general_options.m_datafiles = {"DOESNOTEXIST.root"};
916  general_options.m_treename = "tree";
917  try {
918  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
919  } catch (...) {
920 
921  }
922  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
923 
924  // Check for invalid file
925  general_options.m_datafiles = {"ISNotAValidROOTFile"};
926  general_options.m_treename = "tree";
927 
928  {
929  std::ofstream(general_options.m_datafiles[0]);
930  }
931  EXPECT_TRUE(boost::filesystem::exists(general_options.m_datafiles[0]));
932 
933  try {
934  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
935  } catch (...) {
936 
937  }
938  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
939  }
940 
941  TEST(DatasetTest, ROOTMultiDataset)
942  {
943 
945  TFile file("datafile.root", "RECREATE");
946  file.cd();
947  TTree tree("tree", "TreeTitle");
948  double a, b, c, d, e, f, v, w = 0;
949  bool target;
950  tree.Branch("a", &a);
951  tree.Branch("b", &b);
952  tree.Branch("c", &c);
953  tree.Branch("d", &d);
954  tree.Branch("e__bo__bc", &e);
955  tree.Branch("f__bo__bc", &f);
956  tree.Branch("g", &target);
957  tree.Branch("__weight__", &c);
958  tree.Branch("v__bo__bc", &v);
959  tree.Branch("w", &w);
960 
961  for (unsigned int i = 0; i < 5; ++i) {
962  a = i + 1.0;
963  b = i + 1.1;
964  c = i + 1.2;
965  d = i + 1.3;
966  e = i + 1.4;
967  f = i + 1.5;
968  target = (i % 2 == 0);
969  w = i + 1.6;
970  v = i + 1.7;
971  tree.Fill();
972  }
973 
974  file.Write("tree");
975 
976  TFile file2("datafile2.root", "RECREATE");
977  file2.cd();
978  TTree tree2("tree", "TreeTitle");
979  tree2.Branch("a", &a);
980  tree2.Branch("b", &b);
981  tree2.Branch("c", &c);
982  tree2.Branch("d", &d);
983  tree2.Branch("e__bo__bc", &e);
984  tree2.Branch("f__bo__bc", &f);
985  tree2.Branch("g", &target);
986  tree2.Branch("__weight__", &c);
987  tree2.Branch("v__bo__bc", &v);
988  tree2.Branch("w", &w);
989 
990  for (unsigned int i = 0; i < 5; ++i) {
991  a = i + 1.0;
992  b = i + 1.1;
993  c = i + 1.2;
994  d = i + 1.3;
995  e = i + 1.4;
996  f = i + 1.5;
997  target = (i % 2 == 0);
998  w = i + 1.6;
999  v = i + 1.7;
1000  tree2.Fill();
1001  }
1002 
1003  file2.Write("tree");
1004 
1005  MVA::GeneralOptions general_options;
1006  // Both names with and without makeROOTCompatible should work
1007  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1008  general_options.m_spectators = {"w", "v()"};
1009  general_options.m_signal_class = 1;
1010  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1011  general_options.m_treename = "tree";
1012  general_options.m_target_variable = "g";
1013  general_options.m_weight_variable = "c";
1014  MVA::ROOTDataset x(general_options);
1015 
1016  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1017  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1018  EXPECT_EQ(x.getNumberOfEvents(), 10);
1019 
1020  // Should just work
1021  x.loadEvent(0);
1022  EXPECT_EQ(x.m_input.size(), 4);
1023  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1024  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1025  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1026  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1027  EXPECT_EQ(x.m_spectators.size(), 2);
1028  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1029  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1030  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1031  EXPECT_EQ(x.m_target, true);
1032  EXPECT_EQ(x.m_isSignal, true);
1033 
1034  x.loadEvent(5);
1035  EXPECT_EQ(x.m_input.size(), 4);
1036  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1037  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1038  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1039  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1040  EXPECT_EQ(x.m_spectators.size(), 2);
1041  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1042  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1043  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1044  EXPECT_EQ(x.m_target, true);
1045  EXPECT_EQ(x.m_isSignal, true);
1046 
1047  x.loadEvent(1);
1048  EXPECT_EQ(x.m_input.size(), 4);
1049  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1050  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1051  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1052  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1053  EXPECT_EQ(x.m_spectators.size(), 2);
1054  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1055  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1056  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1057  EXPECT_EQ(x.m_target, false);
1058  EXPECT_EQ(x.m_isSignal, false);
1059 
1060  x.loadEvent(6);
1061  EXPECT_EQ(x.m_input.size(), 4);
1062  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1063  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1064  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1065  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1066  EXPECT_EQ(x.m_spectators.size(), 2);
1067  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1068  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1069  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1070  EXPECT_EQ(x.m_target, false);
1071  EXPECT_EQ(x.m_isSignal, false);
1072 
1073  x.loadEvent(2);
1074  EXPECT_EQ(x.m_input.size(), 4);
1075  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1076  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1077  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1078  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1079  EXPECT_EQ(x.m_spectators.size(), 2);
1080  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1081  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1082  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1083  EXPECT_EQ(x.m_target, true);
1084  EXPECT_EQ(x.m_isSignal, true);
1085 
1086  x.loadEvent(7);
1087  EXPECT_EQ(x.m_input.size(), 4);
1088  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1089  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1090  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1091  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1092  EXPECT_EQ(x.m_spectators.size(), 2);
1093  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1094  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1095  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1096  EXPECT_EQ(x.m_target, true);
1097  EXPECT_EQ(x.m_isSignal, true);
1098 
1099  x.loadEvent(3);
1100  EXPECT_EQ(x.m_input.size(), 4);
1101  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1102  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1103  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1104  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1105  EXPECT_EQ(x.m_spectators.size(), 2);
1106  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1107  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1108  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1109  EXPECT_EQ(x.m_target, false);
1110  EXPECT_EQ(x.m_isSignal, false);
1111 
1112  x.loadEvent(8);
1113  EXPECT_EQ(x.m_input.size(), 4);
1114  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1115  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1116  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1117  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1118  EXPECT_EQ(x.m_spectators.size(), 2);
1119  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1120  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1121  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1122  EXPECT_EQ(x.m_target, false);
1123  EXPECT_EQ(x.m_isSignal, false);
1124 
1125  x.loadEvent(4);
1126  EXPECT_EQ(x.m_input.size(), 4);
1127  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1128  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1129  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1130  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1131  EXPECT_EQ(x.m_spectators.size(), 2);
1132  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1133  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1134  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1135  EXPECT_EQ(x.m_target, true);
1136  EXPECT_EQ(x.m_isSignal, true);
1137 
1138  x.loadEvent(9);
1139  EXPECT_EQ(x.m_input.size(), 4);
1140  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1141  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1142  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1143  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1144  EXPECT_EQ(x.m_spectators.size(), 2);
1145  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1146  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1147  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1148  EXPECT_EQ(x.m_target, true);
1149  EXPECT_EQ(x.m_isSignal, true);
1150 
1151  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1152 
1153  auto feature = x.getFeature(1);
1154  EXPECT_EQ(feature.size(), 10);
1155  EXPECT_FLOAT_EQ(feature[0], 1.1);
1156  EXPECT_FLOAT_EQ(feature[1], 2.1);
1157  EXPECT_FLOAT_EQ(feature[2], 3.1);
1158  EXPECT_FLOAT_EQ(feature[3], 4.1);
1159  EXPECT_FLOAT_EQ(feature[4], 5.1);
1160  EXPECT_FLOAT_EQ(feature[5], 1.1);
1161  EXPECT_FLOAT_EQ(feature[6], 2.1);
1162  EXPECT_FLOAT_EQ(feature[7], 3.1);
1163  EXPECT_FLOAT_EQ(feature[8], 4.1);
1164  EXPECT_FLOAT_EQ(feature[9], 5.1);
1165 
1166  // Same result for mother class implementation
1167  feature = x.Dataset::getFeature(1);
1168  EXPECT_EQ(feature.size(), 10);
1169  EXPECT_FLOAT_EQ(feature[0], 1.1);
1170  EXPECT_FLOAT_EQ(feature[1], 2.1);
1171  EXPECT_FLOAT_EQ(feature[2], 3.1);
1172  EXPECT_FLOAT_EQ(feature[3], 4.1);
1173  EXPECT_FLOAT_EQ(feature[4], 5.1);
1174  EXPECT_FLOAT_EQ(feature[5], 1.1);
1175  EXPECT_FLOAT_EQ(feature[6], 2.1);
1176  EXPECT_FLOAT_EQ(feature[7], 3.1);
1177  EXPECT_FLOAT_EQ(feature[8], 4.1);
1178  EXPECT_FLOAT_EQ(feature[9], 5.1);
1179 
1180  auto spectator = x.getSpectator(1);
1181  EXPECT_EQ(spectator.size(), 10);
1182  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1183  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1184  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1185  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1186  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1187  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1188  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1189  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1190  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1191  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1192 
1193  // Same result for mother class implementation
1194  spectator = x.Dataset::getSpectator(1);
1195  EXPECT_EQ(spectator.size(), 10);
1196  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1197  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1198  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1199  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1200  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1201  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1202  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1203  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1204  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1205  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1206 
1207  auto weights = x.getWeights();
1208  EXPECT_EQ(weights.size(), 10);
1209  EXPECT_FLOAT_EQ(weights[0], 1.2);
1210  EXPECT_FLOAT_EQ(weights[1], 2.2);
1211  EXPECT_FLOAT_EQ(weights[2], 3.2);
1212  EXPECT_FLOAT_EQ(weights[3], 4.2);
1213  EXPECT_FLOAT_EQ(weights[4], 5.2);
1214  EXPECT_FLOAT_EQ(weights[5], 1.2);
1215  EXPECT_FLOAT_EQ(weights[6], 2.2);
1216  EXPECT_FLOAT_EQ(weights[7], 3.2);
1217  EXPECT_FLOAT_EQ(weights[8], 4.2);
1218  EXPECT_FLOAT_EQ(weights[9], 5.2);
1219 
1220  // Same result for mother class implementation
1221  weights = x.Dataset::getWeights();
1222  EXPECT_EQ(weights.size(), 10);
1223  EXPECT_FLOAT_EQ(weights[0], 1.2);
1224  EXPECT_FLOAT_EQ(weights[1], 2.2);
1225  EXPECT_FLOAT_EQ(weights[2], 3.2);
1226  EXPECT_FLOAT_EQ(weights[3], 4.2);
1227  EXPECT_FLOAT_EQ(weights[4], 5.2);
1228  EXPECT_FLOAT_EQ(weights[5], 1.2);
1229  EXPECT_FLOAT_EQ(weights[6], 2.2);
1230  EXPECT_FLOAT_EQ(weights[7], 3.2);
1231  EXPECT_FLOAT_EQ(weights[8], 4.2);
1232  EXPECT_FLOAT_EQ(weights[9], 5.2);
1233 
1234  auto targets = x.getTargets();
1235  EXPECT_EQ(targets.size(), 10);
1236  EXPECT_EQ(targets[0], true);
1237  EXPECT_EQ(targets[1], false);
1238  EXPECT_EQ(targets[2], true);
1239  EXPECT_EQ(targets[3], false);
1240  EXPECT_EQ(targets[4], true);
1241  EXPECT_EQ(targets[5], true);
1242  EXPECT_EQ(targets[6], false);
1243  EXPECT_EQ(targets[7], true);
1244  EXPECT_EQ(targets[8], false);
1245  EXPECT_EQ(targets[9], true);
1246 
1247  auto signals = x.getSignals();
1248  EXPECT_EQ(signals.size(), 10);
1249  EXPECT_EQ(signals[0], true);
1250  EXPECT_EQ(signals[1], false);
1251  EXPECT_EQ(signals[2], true);
1252  EXPECT_EQ(signals[3], false);
1253  EXPECT_EQ(signals[4], true);
1254  EXPECT_EQ(signals[5], true);
1255  EXPECT_EQ(signals[6], false);
1256  EXPECT_EQ(signals[7], true);
1257  EXPECT_EQ(signals[8], false);
1258  EXPECT_EQ(signals[9], true);
1259 
1260  // Using __weight__ should work as well,
1261  // the only difference to using _weight__ instead of c is
1262  // in setBranchAddresses which avoids calling makeROOTCompatible
1263  // So we have to check the behaviour using __weight__ as well
1264  general_options.m_weight_variable = "__weight__";
1265  MVA::ROOTDataset y(general_options);
1266 
1267  weights = y.getWeights();
1268  EXPECT_EQ(weights.size(), 10);
1269  EXPECT_FLOAT_EQ(weights[0], 1.2);
1270  EXPECT_FLOAT_EQ(weights[1], 2.2);
1271  EXPECT_FLOAT_EQ(weights[2], 3.2);
1272  EXPECT_FLOAT_EQ(weights[3], 4.2);
1273  EXPECT_FLOAT_EQ(weights[4], 5.2);
1274  EXPECT_FLOAT_EQ(weights[5], 1.2);
1275  EXPECT_FLOAT_EQ(weights[6], 2.2);
1276  EXPECT_FLOAT_EQ(weights[7], 3.2);
1277  EXPECT_FLOAT_EQ(weights[8], 4.2);
1278  EXPECT_FLOAT_EQ(weights[9], 5.2);
1279 
1280  // Check TChain expansion
1281  general_options.m_datafiles = {"datafile*.root"};
1282  {
1283  MVA::ROOTDataset chain_test(general_options);
1284  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1285  }
1286  boost::filesystem::copy_file("datafile.root", "datafile3.root");
1287  {
1288  MVA::ROOTDataset chain_test(general_options);
1289  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1290  }
1291  boost::filesystem::copy_file("datafile.root", "datafile4.root");
1292  {
1293  MVA::ROOTDataset chain_test(general_options);
1294  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1295  }
1296  // Test m_max_events feature
1297  {
1298  general_options.m_max_events = 10;
1299  MVA::ROOTDataset chain_test(general_options);
1300  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1301  general_options.m_max_events = 0;
1302  }
1303 
1304  // If a file exists with the specified expansion
1305  // the file takes precedence over the expansion
1306  boost::filesystem::copy_file("datafile.root", "datafile*.root");
1307  {
1308  general_options.m_max_events = 0;
1309  MVA::ROOTDataset chain_test(general_options);
1310  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1311  }
1312 
1313  }
1314  TEST(DatasetTest, ROOTMultiDatasetDouble)
1315  {
1316 
1318  TFile file("datafile.root", "RECREATE");
1319  file.cd();
1320  TTree tree("tree", "TreeTitle");
1321  double a, b, c, d, e, f, g, v, w = 0;
1322  tree.Branch("a", &a, "a/D");
1323  tree.Branch("b", &b, "b/D");
1324  tree.Branch("c", &c, "c/D");
1325  tree.Branch("d", &d, "d/D");
1326  tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
1327  tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
1328  tree.Branch("g", &g, "g/D");
1329  tree.Branch("__weight__", &c, "__weight__/D");
1330  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
1331  tree.Branch("w", &w, "w/D");
1332 
1333  for (unsigned int i = 0; i < 5; ++i) {
1334  a = i + 1.0;
1335  b = i + 1.1;
1336  c = i + 1.2;
1337  d = i + 1.3;
1338  e = i + 1.4;
1339  f = i + 1.5;
1340  g = float(i % 2 == 0);
1341  w = i + 1.6;
1342  v = i + 1.7;
1343  tree.Fill();
1344  }
1345 
1346  file.Write("tree");
1347 
1348  TFile file2("datafile2.root", "RECREATE");
1349  file2.cd();
1350  TTree tree2("tree", "TreeTitle");
1351  tree2.Branch("a", &a);
1352  tree2.Branch("b", &b);
1353  tree2.Branch("c", &c);
1354  tree2.Branch("d", &d);
1355  tree2.Branch("e__bo__bc", &e);
1356  tree2.Branch("f__bo__bc", &f);
1357  tree2.Branch("g", &g);
1358  tree2.Branch("__weight__", &c);
1359  tree2.Branch("v__bo__bc", &v);
1360  tree2.Branch("w", &w);
1361 
1362  for (unsigned int i = 0; i < 5; ++i) {
1363  a = i + 1.0;
1364  b = i + 1.1;
1365  c = i + 1.2;
1366  d = i + 1.3;
1367  e = i + 1.4;
1368  f = i + 1.5;
1369  g = float(i % 2 == 0);
1370  w = i + 1.6;
1371  v = i + 1.7;
1372  tree2.Fill();
1373  }
1374 
1375  file2.Write("tree");
1376 
1377  MVA::GeneralOptions general_options;
1378  // Both names with and without makeROOTCompatible should work
1379  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1380  general_options.m_spectators = {"w", "v()"};
1381  general_options.m_signal_class = 1;
1382  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1383  general_options.m_treename = "tree";
1384  general_options.m_target_variable = "g";
1385  general_options.m_weight_variable = "c";
1386  MVA::ROOTDataset x(general_options);
1387 
1388  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1389  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1390  EXPECT_EQ(x.getNumberOfEvents(), 10);
1391 
1392  // Should just work
1393  x.loadEvent(0);
1394  EXPECT_EQ(x.m_input.size(), 4);
1395  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1396  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1397  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1398  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1399  EXPECT_EQ(x.m_spectators.size(), 2);
1400  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1401  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1402  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1403  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1404  EXPECT_EQ(x.m_isSignal, true);
1405 
1406  x.loadEvent(5);
1407  EXPECT_EQ(x.m_input.size(), 4);
1408  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1409  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1410  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1411  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1412  EXPECT_EQ(x.m_spectators.size(), 2);
1413  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1414  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1415  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1416  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1417  EXPECT_EQ(x.m_isSignal, true);
1418 
1419  x.loadEvent(1);
1420  EXPECT_EQ(x.m_input.size(), 4);
1421  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1422  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1423  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1424  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1425  EXPECT_EQ(x.m_spectators.size(), 2);
1426  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1427  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1428  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1429  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1430  EXPECT_EQ(x.m_isSignal, false);
1431 
1432  x.loadEvent(6);
1433  EXPECT_EQ(x.m_input.size(), 4);
1434  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1435  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1436  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1437  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1438  EXPECT_EQ(x.m_spectators.size(), 2);
1439  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1440  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1441  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1442  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1443  EXPECT_EQ(x.m_isSignal, false);
1444 
1445  x.loadEvent(2);
1446  EXPECT_EQ(x.m_input.size(), 4);
1447  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1448  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1449  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1450  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1451  EXPECT_EQ(x.m_spectators.size(), 2);
1452  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1453  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1454  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1455  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1456  EXPECT_EQ(x.m_isSignal, true);
1457 
1458  x.loadEvent(7);
1459  EXPECT_EQ(x.m_input.size(), 4);
1460  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1461  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1462  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1463  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1464  EXPECT_EQ(x.m_spectators.size(), 2);
1465  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1466  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1467  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1468  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1469  EXPECT_EQ(x.m_isSignal, true);
1470 
1471  x.loadEvent(3);
1472  EXPECT_EQ(x.m_input.size(), 4);
1473  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1474  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1475  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1476  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1477  EXPECT_EQ(x.m_spectators.size(), 2);
1478  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1479  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1480  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1481  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1482  EXPECT_EQ(x.m_isSignal, false);
1483 
1484  x.loadEvent(8);
1485  EXPECT_EQ(x.m_input.size(), 4);
1486  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1487  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1488  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1489  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1490  EXPECT_EQ(x.m_spectators.size(), 2);
1491  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1492  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1493  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1494  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1495  EXPECT_EQ(x.m_isSignal, false);
1496 
1497  x.loadEvent(4);
1498  EXPECT_EQ(x.m_input.size(), 4);
1499  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1500  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1501  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1502  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1503  EXPECT_EQ(x.m_spectators.size(), 2);
1504  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1505  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1506  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1507  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1508  EXPECT_EQ(x.m_isSignal, true);
1509 
1510  x.loadEvent(9);
1511  EXPECT_EQ(x.m_input.size(), 4);
1512  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1513  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1514  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1515  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1516  EXPECT_EQ(x.m_spectators.size(), 2);
1517  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1518  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1519  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1520  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1521  EXPECT_EQ(x.m_isSignal, true);
1522 
1523  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1524 
1525  auto feature = x.getFeature(1);
1526  EXPECT_EQ(feature.size(), 10);
1527  EXPECT_FLOAT_EQ(feature[0], 1.1);
1528  EXPECT_FLOAT_EQ(feature[1], 2.1);
1529  EXPECT_FLOAT_EQ(feature[2], 3.1);
1530  EXPECT_FLOAT_EQ(feature[3], 4.1);
1531  EXPECT_FLOAT_EQ(feature[4], 5.1);
1532  EXPECT_FLOAT_EQ(feature[5], 1.1);
1533  EXPECT_FLOAT_EQ(feature[6], 2.1);
1534  EXPECT_FLOAT_EQ(feature[7], 3.1);
1535  EXPECT_FLOAT_EQ(feature[8], 4.1);
1536  EXPECT_FLOAT_EQ(feature[9], 5.1);
1537 
1538  // Same result for mother class implementation
1539  feature = x.Dataset::getFeature(1);
1540  EXPECT_EQ(feature.size(), 10);
1541  EXPECT_FLOAT_EQ(feature[0], 1.1);
1542  EXPECT_FLOAT_EQ(feature[1], 2.1);
1543  EXPECT_FLOAT_EQ(feature[2], 3.1);
1544  EXPECT_FLOAT_EQ(feature[3], 4.1);
1545  EXPECT_FLOAT_EQ(feature[4], 5.1);
1546  EXPECT_FLOAT_EQ(feature[5], 1.1);
1547  EXPECT_FLOAT_EQ(feature[6], 2.1);
1548  EXPECT_FLOAT_EQ(feature[7], 3.1);
1549  EXPECT_FLOAT_EQ(feature[8], 4.1);
1550  EXPECT_FLOAT_EQ(feature[9], 5.1);
1551 
1552  auto spectator = x.getSpectator(1);
1553  EXPECT_EQ(spectator.size(), 10);
1554  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1555  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1556  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1557  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1558  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1559  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1560  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1561  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1562  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1563  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1564 
1565  // Same result for mother class implementation
1566  spectator = x.Dataset::getSpectator(1);
1567  EXPECT_EQ(spectator.size(), 10);
1568  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1569  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1570  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1571  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1572  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1573  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1574  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1575  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1576  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1577  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1578 
1579  auto weights = x.getWeights();
1580  EXPECT_EQ(weights.size(), 10);
1581  EXPECT_FLOAT_EQ(weights[0], 1.2);
1582  EXPECT_FLOAT_EQ(weights[1], 2.2);
1583  EXPECT_FLOAT_EQ(weights[2], 3.2);
1584  EXPECT_FLOAT_EQ(weights[3], 4.2);
1585  EXPECT_FLOAT_EQ(weights[4], 5.2);
1586  EXPECT_FLOAT_EQ(weights[5], 1.2);
1587  EXPECT_FLOAT_EQ(weights[6], 2.2);
1588  EXPECT_FLOAT_EQ(weights[7], 3.2);
1589  EXPECT_FLOAT_EQ(weights[8], 4.2);
1590  EXPECT_FLOAT_EQ(weights[9], 5.2);
1591 
1592  // Same result for mother class implementation
1593  weights = x.Dataset::getWeights();
1594  EXPECT_EQ(weights.size(), 10);
1595  EXPECT_FLOAT_EQ(weights[0], 1.2);
1596  EXPECT_FLOAT_EQ(weights[1], 2.2);
1597  EXPECT_FLOAT_EQ(weights[2], 3.2);
1598  EXPECT_FLOAT_EQ(weights[3], 4.2);
1599  EXPECT_FLOAT_EQ(weights[4], 5.2);
1600  EXPECT_FLOAT_EQ(weights[5], 1.2);
1601  EXPECT_FLOAT_EQ(weights[6], 2.2);
1602  EXPECT_FLOAT_EQ(weights[7], 3.2);
1603  EXPECT_FLOAT_EQ(weights[8], 4.2);
1604  EXPECT_FLOAT_EQ(weights[9], 5.2);
1605 
1606  auto targets = x.getTargets();
1607  EXPECT_EQ(targets.size(), 10);
1608  EXPECT_FLOAT_EQ(targets[0], 1.0);
1609  EXPECT_FLOAT_EQ(targets[1], 0.0);
1610  EXPECT_FLOAT_EQ(targets[2], 1.0);
1611  EXPECT_FLOAT_EQ(targets[3], 0.0);
1612  EXPECT_FLOAT_EQ(targets[4], 1.0);
1613  EXPECT_FLOAT_EQ(targets[5], 1.0);
1614  EXPECT_FLOAT_EQ(targets[6], 0.0);
1615  EXPECT_FLOAT_EQ(targets[7], 1.0);
1616  EXPECT_FLOAT_EQ(targets[8], 0.0);
1617  EXPECT_FLOAT_EQ(targets[9], 1.0);
1618 
1619  auto signals = x.getSignals();
1620  EXPECT_EQ(signals.size(), 10);
1621  EXPECT_EQ(signals[0], true);
1622  EXPECT_EQ(signals[1], false);
1623  EXPECT_EQ(signals[2], true);
1624  EXPECT_EQ(signals[3], false);
1625  EXPECT_EQ(signals[4], true);
1626  EXPECT_EQ(signals[5], true);
1627  EXPECT_EQ(signals[6], false);
1628  EXPECT_EQ(signals[7], true);
1629  EXPECT_EQ(signals[8], false);
1630  EXPECT_EQ(signals[9], true);
1631 
1632  // Using __weight__ should work as well,
1633  // the only difference to using _weight__ instead of g is
1634  // in setBranchAddresses which avoids calling makeROOTCompatible
1635  // So we have to check the behaviour using __weight__ as well
1636  general_options.m_weight_variable = "__weight__";
1637  MVA::ROOTDataset y(general_options);
1638 
1639  weights = y.getWeights();
1640  EXPECT_EQ(weights.size(), 10);
1641  EXPECT_FLOAT_EQ(weights[0], 1.2);
1642  EXPECT_FLOAT_EQ(weights[1], 2.2);
1643  EXPECT_FLOAT_EQ(weights[2], 3.2);
1644  EXPECT_FLOAT_EQ(weights[3], 4.2);
1645  EXPECT_FLOAT_EQ(weights[4], 5.2);
1646  EXPECT_FLOAT_EQ(weights[5], 1.2);
1647  EXPECT_FLOAT_EQ(weights[6], 2.2);
1648  EXPECT_FLOAT_EQ(weights[7], 3.2);
1649  EXPECT_FLOAT_EQ(weights[8], 4.2);
1650  EXPECT_FLOAT_EQ(weights[9], 5.2);
1651 
1652  // Check TChain expansion
1653  general_options.m_datafiles = {"datafile*.root"};
1654  {
1655  MVA::ROOTDataset chain_test(general_options);
1656  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1657  }
1658  boost::filesystem::copy_file("datafile.root", "datafile3.root");
1659  {
1660  MVA::ROOTDataset chain_test(general_options);
1661  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1662  }
1663  boost::filesystem::copy_file("datafile.root", "datafile4.root");
1664  {
1665  MVA::ROOTDataset chain_test(general_options);
1666  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1667  }
1668  // Test m_max_events feature
1669  {
1670  general_options.m_max_events = 10;
1671  MVA::ROOTDataset chain_test(general_options);
1672  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1673  general_options.m_max_events = 0;
1674  }
1675 
1676  // If a file exists with the specified expansion
1677  // the file takes precedence over the expansion
1678  boost::filesystem::copy_file("datafile.root", "datafile*.root");
1679  {
1680  general_options.m_max_events = 0;
1681  MVA::ROOTDataset chain_test(general_options);
1682  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1683  }
1684  }
1685 }
Wraps two other Datasets, one containing signal, the other background events. Used by the reweighting ...
Definition: Dataset.h:294
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
Definition: Dataset.h:349
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:135
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:234
changes working directory into a newly created directory, and removes it (and contents) on destructio...
Definition: TestHelpers.h:66
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:23