Belle II Software  release-05-01-25
test_Dataset.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 #include <mva/interface/Dataset.h>
12 #include <framework/utilities/TestHelpers.h>
13 
14 #include <boost/filesystem/operations.hpp>
15 
16 #include <gtest/gtest.h>
17 
18 #include <fstream>
19 #include <numeric>
20 
21 using namespace Belle2;
22 
23 namespace {
24 
25  class TestDataset : public MVA::Dataset {
26  public:
27  explicit TestDataset(MVA::GeneralOptions& general_options) : MVA::Dataset(general_options)
28  {
29  m_input = {1.0, 2.0, 3.0, 4.0, 5.0};
30  m_target = 3.0;
31  m_isSignal = true;
32  m_weight = -3.0;
33  }
34 
35  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 5; }
36  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 2; }
37  [[nodiscard]] unsigned int getNumberOfEvents() const override { return 20; }
38  void loadEvent(unsigned int iEvent) override
39  {
40  auto f = static_cast<float>(iEvent);
41  m_input = {f + 1, f + 2, f + 3, f + 4, f + 5};
42  m_spectators = {f + 6, f + 7};
43  };
44  float getSignalFraction() override { return 0.1; };
45  std::vector<float> getFeature(unsigned int iFeature) override
46  {
47  std::vector<float> a(20, 0.0);
48  std::iota(a.begin(), a.end(), iFeature + 1);
49  return a;
50  }
51  std::vector<float> getSpectator(unsigned int iSpectator) override
52  {
53  std::vector<float> a(20, 0.0);
54  std::iota(a.begin(), a.end(), iSpectator + 6);
55  return a;
56  }
57 
58  };
59 
60  TEST(DatasetTest, SingleDataset)
61  {
62 
63  MVA::GeneralOptions general_options;
64  general_options.m_variables = {"a", "b", "c"};
65  general_options.m_spectators = {"e", "f"};
66  general_options.m_signal_class = 2;
67  std::vector<float> input = {1.0, 2.0, 3.0};
68  std::vector<float> spectators = {4.0, 5.0};
69 
70  MVA::SingleDataset x(general_options, input, 2.0, spectators);
71 
72  EXPECT_EQ(x.getNumberOfFeatures(), 3);
73  EXPECT_EQ(x.getNumberOfSpectators(), 2);
74  EXPECT_EQ(x.getNumberOfEvents(), 1);
75 
76  EXPECT_EQ(x.getFeatureIndex("a"), 0);
77  EXPECT_EQ(x.getFeatureIndex("b"), 1);
78  EXPECT_EQ(x.getFeatureIndex("c"), 2);
79  EXPECT_B2ERROR(x.getFeatureIndex("bla"));
80 
81  EXPECT_EQ(x.getSpectatorIndex("e"), 0);
82  EXPECT_EQ(x.getSpectatorIndex("f"), 1);
83  EXPECT_B2ERROR(x.getSpectatorIndex("bla"));
84 
85  // Should just work
86  x.loadEvent(0);
87 
88  EXPECT_EQ(x.m_input.size(), 3);
89  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
90  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
91  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
92 
93  EXPECT_EQ(x.m_spectators.size(), 2);
94  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.0);
95  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.0);
96 
97  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
98  EXPECT_FLOAT_EQ(x.m_target, 2.0);
99  EXPECT_EQ(x.m_isSignal, true);
100 
101  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1.0);
102 
103  auto feature = x.getFeature(1);
104  EXPECT_EQ(feature.size(), 1);
105  EXPECT_FLOAT_EQ(feature[0], 2.0);
106 
107  auto spectator = x.getSpectator(1);
108  EXPECT_EQ(spectator.size(), 1);
109  EXPECT_FLOAT_EQ(spectator[0], 5.0);
110 
111  // Same result for mother class implementation
112  feature = x.Dataset::getFeature(1);
113  EXPECT_EQ(feature.size(), 1);
114  EXPECT_FLOAT_EQ(feature[0], 2.0);
115 
116  }
117 
118  TEST(DatasetTest, MultiDataset)
119  {
120 
121  MVA::GeneralOptions general_options;
122  general_options.m_variables = {"a", "b", "c"};
123  general_options.m_spectators = {"e", "f"};
124  general_options.m_signal_class = 2;
125  std::vector<std::vector<float>> matrix = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}};
126  std::vector<std::vector<float>> spectator_matrix = {{12.0, 13.0}, {15.0, 16.0}, {18.0, 19.0}};
127  std::vector<float> targets = {2.0, 0.0, 2.0};
128  std::vector<float> weights = {1.0, 2.0, 3.0};
129 
130  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, {1.0}, weights));
131 
132  EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, targets, {1.0}));
133 
134  MVA::MultiDataset x(general_options, matrix, spectator_matrix, targets, weights);
135 
136  EXPECT_EQ(x.getNumberOfFeatures(), 3);
137  EXPECT_EQ(x.getNumberOfEvents(), 3);
138 
139  // Should just work
140  x.loadEvent(0);
141 
142  EXPECT_EQ(x.m_input.size(), 3);
143  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
144  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
145  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
146 
147  EXPECT_EQ(x.m_spectators.size(), 2);
148  EXPECT_FLOAT_EQ(x.m_spectators[0], 12.0);
149  EXPECT_FLOAT_EQ(x.m_spectators[1], 13.0);
150 
151  EXPECT_FLOAT_EQ(x.m_weight, 1.0);
152  EXPECT_FLOAT_EQ(x.m_target, 2.0);
153  EXPECT_EQ(x.m_isSignal, true);
154 
155  // Should just work
156  x.loadEvent(1);
157 
158  EXPECT_EQ(x.m_input.size(), 3);
159  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
160  EXPECT_FLOAT_EQ(x.m_input[1], 5.0);
161  EXPECT_FLOAT_EQ(x.m_input[2], 6.0);
162 
163  EXPECT_EQ(x.m_spectators.size(), 2);
164  EXPECT_FLOAT_EQ(x.m_spectators[0], 15.0);
165  EXPECT_FLOAT_EQ(x.m_spectators[1], 16.0);
166 
167  EXPECT_FLOAT_EQ(x.m_weight, 2.0);
168  EXPECT_FLOAT_EQ(x.m_target, 0.0);
169  EXPECT_EQ(x.m_isSignal, false);
170 
171  // Should just work
172  x.loadEvent(2);
173 
174  EXPECT_EQ(x.m_input.size(), 3);
175  EXPECT_FLOAT_EQ(x.m_input[0], 7.0);
176  EXPECT_FLOAT_EQ(x.m_input[1], 8.0);
177  EXPECT_FLOAT_EQ(x.m_input[2], 9.0);
178 
179  EXPECT_EQ(x.m_spectators.size(), 2);
180  EXPECT_FLOAT_EQ(x.m_spectators[0], 18.0);
181  EXPECT_FLOAT_EQ(x.m_spectators[1], 19.0);
182 
183  EXPECT_FLOAT_EQ(x.m_weight, 3.0);
184  EXPECT_FLOAT_EQ(x.m_target, 2.0);
185  EXPECT_EQ(x.m_isSignal, true);
186 
187  EXPECT_FLOAT_EQ(x.getSignalFraction(), 4.0 / 6.0);
188 
189  auto feature = x.getFeature(1);
190  EXPECT_EQ(feature.size(), 3);
191  EXPECT_FLOAT_EQ(feature[0], 2.0);
192  EXPECT_FLOAT_EQ(feature[1], 5.0);
193  EXPECT_FLOAT_EQ(feature[2], 8.0);
194 
195  auto spectator = x.getSpectator(1);
196  EXPECT_EQ(spectator.size(), 3);
197  EXPECT_FLOAT_EQ(spectator[0], 13.0);
198  EXPECT_FLOAT_EQ(spectator[1], 16.0);
199  EXPECT_FLOAT_EQ(spectator[2], 19.0);
200 
201 
202  }
203 
204  TEST(DatasetTest, SubDataset)
205  {
206 
207  MVA::GeneralOptions general_options;
208  general_options.m_signal_class = 3;
209  general_options.m_variables = {"a", "b", "c", "d", "e"};
210  general_options.m_spectators = {"f", "g"};
211  TestDataset test_dataset(general_options);
212 
213  general_options.m_variables = {"a", "d", "e"};
214  general_options.m_spectators = {"g"};
215  std::vector<bool> events = {true, false, true, false, true, false, true, false, true, false,
216  true, false, true, false, true, false, true, false, true, false
217  };
218  MVA::SubDataset x(general_options, events, test_dataset);
219 
220  EXPECT_EQ(x.getNumberOfFeatures(), 3);
221  EXPECT_EQ(x.getNumberOfEvents(), 10);
222 
223  // Should just work
224  x.loadEvent(0);
225 
226  EXPECT_EQ(x.m_input.size(), 3);
227  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
228  EXPECT_FLOAT_EQ(x.m_input[1], 4.0);
229  EXPECT_FLOAT_EQ(x.m_input[2], 5.0);
230 
231  EXPECT_EQ(x.m_spectators.size(), 1);
232  EXPECT_FLOAT_EQ(x.m_spectators[0], 7.0);
233 
234  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
235  EXPECT_FLOAT_EQ(x.m_target, 3.0);
236  EXPECT_EQ(x.m_isSignal, true);
237 
238  EXPECT_FLOAT_EQ(x.getSignalFraction(), 1);
239 
240  auto feature = x.getFeature(1);
241  EXPECT_EQ(feature.size(), 10);
242  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
243  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
244  };
245 
246  auto spectator = x.getSpectator(0);
247  EXPECT_EQ(spectator.size(), 10);
248  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
249  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
250  };
251 
252  // Same result for mother class implementation
253  feature = x.Dataset::getFeature(1);
254  EXPECT_EQ(feature.size(), 10);
255  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
256  EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
257  };
258 
259  spectator = x.Dataset::getSpectator(0);
260  EXPECT_EQ(spectator.size(), 10);
261  for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
262  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
263  };
264 
265  // Test without event indices
266  MVA::SubDataset y(general_options, {}, test_dataset);
267  feature = y.getFeature(1);
268  EXPECT_EQ(feature.size(), 20);
269  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
270  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
271  };
272 
273  spectator = y.getSpectator(0);
274  EXPECT_EQ(spectator.size(), 20);
275  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
276  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
277  };
278 
279  // Same result for mother class implementation
280  feature = y.Dataset::getFeature(1);
281  EXPECT_EQ(feature.size(), 20);
282  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
283  EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
284  };
285 
286  spectator = y.Dataset::getSpectator(0);
287  EXPECT_EQ(spectator.size(), 20);
288  for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
289  EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
290  };
291 
292  general_options.m_variables = {"a", "d", "e", "DOESNOTEXIST"};
293  try {
294  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
295  } catch (...) {
296 
297  }
298  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
299 
300  general_options.m_variables = {"a", "d", "e"};
301  general_options.m_spectators = {"DOESNOTEXIST"};
302  try {
303  EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
304  } catch (...) {
305 
306  }
307  EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
308 
309  }
310 
311  TEST(DatasetTest, CombinedDataset)
312  {
313 
314  MVA::GeneralOptions general_options;
315  general_options.m_signal_class = 1;
316  general_options.m_variables = {"a", "b", "c", "d", "e"};
317  general_options.m_spectators = {"f", "g"};
318  TestDataset signal_dataset(general_options);
319  TestDataset bckgrd_dataset(general_options);
320 
321  MVA::CombinedDataset x(general_options, signal_dataset, bckgrd_dataset);
322 
323  EXPECT_EQ(x.getNumberOfFeatures(), 5);
324  EXPECT_EQ(x.getNumberOfEvents(), 40);
325 
326  // Should just work
327  x.loadEvent(0);
328 
329  EXPECT_EQ(x.m_input.size(), 5);
330  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
331  EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
332  EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
333  EXPECT_FLOAT_EQ(x.m_input[3], 4.0);
334  EXPECT_FLOAT_EQ(x.m_input[4], 5.0);
335 
336  EXPECT_EQ(x.m_spectators.size(), 2);
337  EXPECT_FLOAT_EQ(x.m_spectators[0], 6.0);
338  EXPECT_FLOAT_EQ(x.m_spectators[1], 7.0);
339 
340  EXPECT_FLOAT_EQ(x.m_weight, -3.0);
341  EXPECT_FLOAT_EQ(x.m_target, 1.0);
342  EXPECT_EQ(x.m_isSignal, true);
343 
344  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.5);
345 
346  auto feature = x.getFeature(1);
347  EXPECT_EQ(feature.size(), 40);
348  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
349  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
350  };
351 
352  auto spectator = x.getSpectator(0);
353  EXPECT_EQ(spectator.size(), 40);
354  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
355  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
356  };
357 
358  // Same result for mother class implementation
359  feature = x.Dataset::getFeature(1);
360  EXPECT_EQ(feature.size(), 40);
361  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
362  EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
363  };
364 
365  spectator = x.Dataset::getSpectator(0);
366  EXPECT_EQ(spectator.size(), 40);
367  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
368  EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
369  };
370 
371  for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
372  x.loadEvent(iEvent);
373  EXPECT_EQ(x.m_isSignal, iEvent < 20);
374  EXPECT_FLOAT_EQ(x.m_target, iEvent < 20 ? 1.0 : 0.0);
375  }
376 
377  }
378 
379  TEST(DatasetTest, ROOTDataset)
380  {
381 
383  TFile file("datafile.root", "RECREATE");
384  file.cd();
385  TTree tree("tree", "TreeTitle");
386  float a, b, c, d, e, f, g, v, w = 0;
387  tree.Branch("a", &a);
388  tree.Branch("b", &b);
389  tree.Branch("c", &c);
390  tree.Branch("d", &d);
391  tree.Branch("e__bo__bc", &e);
392  tree.Branch("f__bo__bc", &f);
393  tree.Branch("g", &g);
394  tree.Branch("__weight__", &c);
395  tree.Branch("v__bo__bc", &v);
396  tree.Branch("w", &w);
397 
398  for (unsigned int i = 0; i < 5; ++i) {
399  a = i + 1.0;
400  b = i + 1.1;
401  c = i + 1.2;
402  d = i + 1.3;
403  e = i + 1.4;
404  f = i + 1.5;
405  g = float(i % 2 == 0);
406  w = i + 1.6;
407  v = i + 1.7;
408  tree.Fill();
409  }
410 
411  file.Write("tree");
412 
413  MVA::GeneralOptions general_options;
414  // Both names with and without makeROOTCompatible should work
415  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
416  general_options.m_spectators = {"w", "v()"};
417  general_options.m_signal_class = 1;
418  general_options.m_datafiles = {"datafile.root"};
419  general_options.m_treename = "tree";
420  general_options.m_target_variable = "g";
421  general_options.m_weight_variable = "c";
422  MVA::ROOTDataset x(general_options);
423 
424  EXPECT_EQ(x.getNumberOfFeatures(), 4);
425  EXPECT_EQ(x.getNumberOfSpectators(), 2);
426  EXPECT_EQ(x.getNumberOfEvents(), 5);
427 
428  // Should just work
429  x.loadEvent(0);
430  EXPECT_EQ(x.m_input.size(), 4);
431  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
432  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
433  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
434  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
435  EXPECT_EQ(x.m_spectators.size(), 2);
436  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
437  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
438  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
439  EXPECT_FLOAT_EQ(x.m_target, 1.0);
440  EXPECT_EQ(x.m_isSignal, true);
441 
442  x.loadEvent(1);
443  EXPECT_EQ(x.m_input.size(), 4);
444  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
445  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
446  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
447  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
448  EXPECT_EQ(x.m_spectators.size(), 2);
449  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
450  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
451  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
452  EXPECT_FLOAT_EQ(x.m_target, 0.0);
453  EXPECT_EQ(x.m_isSignal, false);
454 
455  x.loadEvent(2);
456  EXPECT_EQ(x.m_input.size(), 4);
457  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
458  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
459  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
460  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
461  EXPECT_EQ(x.m_spectators.size(), 2);
462  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
463  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
464  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
465  EXPECT_FLOAT_EQ(x.m_target, 1.0);
466  EXPECT_EQ(x.m_isSignal, true);
467 
468  x.loadEvent(3);
469  EXPECT_EQ(x.m_input.size(), 4);
470  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
471  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
472  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
473  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
474  EXPECT_EQ(x.m_spectators.size(), 2);
475  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
476  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
477  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
478  EXPECT_FLOAT_EQ(x.m_target, 0.0);
479  EXPECT_EQ(x.m_isSignal, false);
480 
481  x.loadEvent(4);
482  EXPECT_EQ(x.m_input.size(), 4);
483  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
484  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
485  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
486  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
487  EXPECT_EQ(x.m_spectators.size(), 2);
488  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
489  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
490  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
491  EXPECT_FLOAT_EQ(x.m_target, 1.0);
492  EXPECT_EQ(x.m_isSignal, true);
493 
494  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
495 
496  auto feature = x.getFeature(1);
497  EXPECT_EQ(feature.size(), 5);
498  EXPECT_FLOAT_EQ(feature[0], 1.1);
499  EXPECT_FLOAT_EQ(feature[1], 2.1);
500  EXPECT_FLOAT_EQ(feature[2], 3.1);
501  EXPECT_FLOAT_EQ(feature[3], 4.1);
502  EXPECT_FLOAT_EQ(feature[4], 5.1);
503 
504  // Same result for mother class implementation
505  feature = x.Dataset::getFeature(1);
506  EXPECT_EQ(feature.size(), 5);
507  EXPECT_FLOAT_EQ(feature[0], 1.1);
508  EXPECT_FLOAT_EQ(feature[1], 2.1);
509  EXPECT_FLOAT_EQ(feature[2], 3.1);
510  EXPECT_FLOAT_EQ(feature[3], 4.1);
511  EXPECT_FLOAT_EQ(feature[4], 5.1);
512 
513  auto spectator = x.getSpectator(1);
514  EXPECT_EQ(spectator.size(), 5);
515  EXPECT_FLOAT_EQ(spectator[0], 1.7);
516  EXPECT_FLOAT_EQ(spectator[1], 2.7);
517  EXPECT_FLOAT_EQ(spectator[2], 3.7);
518  EXPECT_FLOAT_EQ(spectator[3], 4.7);
519  EXPECT_FLOAT_EQ(spectator[4], 5.7);
520 
521  // Same result for mother class implementation
522  spectator = x.Dataset::getSpectator(1);
523  EXPECT_EQ(spectator.size(), 5);
524  EXPECT_FLOAT_EQ(spectator[0], 1.7);
525  EXPECT_FLOAT_EQ(spectator[1], 2.7);
526  EXPECT_FLOAT_EQ(spectator[2], 3.7);
527  EXPECT_FLOAT_EQ(spectator[3], 4.7);
528  EXPECT_FLOAT_EQ(spectator[4], 5.7);
529 
530  auto weights = x.getWeights();
531  EXPECT_EQ(weights.size(), 5);
532  EXPECT_FLOAT_EQ(weights[0], 1.2);
533  EXPECT_FLOAT_EQ(weights[1], 2.2);
534  EXPECT_FLOAT_EQ(weights[2], 3.2);
535  EXPECT_FLOAT_EQ(weights[3], 4.2);
536  EXPECT_FLOAT_EQ(weights[4], 5.2);
537 
538  // Same result for mother class implementation
539  weights = x.Dataset::getWeights();
540  EXPECT_EQ(weights.size(), 5);
541  EXPECT_FLOAT_EQ(weights[0], 1.2);
542  EXPECT_FLOAT_EQ(weights[1], 2.2);
543  EXPECT_FLOAT_EQ(weights[2], 3.2);
544  EXPECT_FLOAT_EQ(weights[3], 4.2);
545  EXPECT_FLOAT_EQ(weights[4], 5.2);
546 
547  auto targets = x.getTargets();
548  EXPECT_EQ(targets.size(), 5);
549  EXPECT_FLOAT_EQ(targets[0], 1.0);
550  EXPECT_FLOAT_EQ(targets[1], 0.0);
551  EXPECT_FLOAT_EQ(targets[2], 1.0);
552  EXPECT_FLOAT_EQ(targets[3], 0.0);
553  EXPECT_FLOAT_EQ(targets[4], 1.0);
554 
555  auto signals = x.getSignals();
556  EXPECT_EQ(signals.size(), 5);
557  EXPECT_EQ(signals[0], true);
558  EXPECT_EQ(signals[1], false);
559  EXPECT_EQ(signals[2], true);
560  EXPECT_EQ(signals[3], false);
561  EXPECT_EQ(signals[4], true);
562 
563  // Using __weight__ should work as well,
564  // the only difference to using _weight__ instead of g is
565  // in setBranchAddresses which avoids calling makeROOTCompatible
566  // So we have to check the behaviour using __weight__ as well
567  general_options.m_weight_variable = "__weight__";
568  MVA::ROOTDataset y(general_options);
569 
570  weights = y.getWeights();
571  EXPECT_EQ(weights.size(), 5);
572  EXPECT_FLOAT_EQ(weights[0], 1.2);
573  EXPECT_FLOAT_EQ(weights[1], 2.2);
574  EXPECT_FLOAT_EQ(weights[2], 3.2);
575  EXPECT_FLOAT_EQ(weights[3], 4.2);
576  EXPECT_FLOAT_EQ(weights[4], 5.2);
577 
578  // Check TChain expansion
579  general_options.m_datafiles = {"datafile*.root"};
580  {
581  MVA::ROOTDataset chain_test(general_options);
582  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
583  }
584  boost::filesystem::copy_file("datafile.root", "datafile2.root");
585  {
586  MVA::ROOTDataset chain_test(general_options);
587  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
588  }
589  boost::filesystem::copy_file("datafile.root", "datafile3.root");
590  {
591  MVA::ROOTDataset chain_test(general_options);
592  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
593  }
594  // Test m_max_events feature
595  {
596  general_options.m_max_events = 10;
597  MVA::ROOTDataset chain_test(general_options);
598  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
599  general_options.m_max_events = 0;
600  }
601 
602  // Check for missing tree
603  general_options.m_treename = "missing tree";
604  try {
605  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
606  } catch (...) {
607 
608  }
609  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
610 
611  // Check for missing branch
612  general_options.m_treename = "tree";
613  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
614  try {
615  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
616  } catch (...) {
617 
618  }
619  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
620 
621  // Check for missing branch
622  general_options.m_treename = "tree";
623  general_options.m_variables = {"a", "b", "e", "f"};
624  general_options.m_spectators = {"missing branch"};
625  try {
626  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
627  } catch (...) {
628 
629  }
630  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
631 
632  // Check for missing file
633  general_options.m_spectators = {};
634  general_options.m_datafiles = {"DOESNOTEXIST.root"};
635  general_options.m_treename = "tree";
636  try {
637  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
638  } catch (...) {
639 
640  }
641  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
642 
643  // Check for invalid file
644  general_options.m_datafiles = {"ISNotAValidROOTFile"};
645  general_options.m_treename = "tree";
646 
647  {
648  std::ofstream(general_options.m_datafiles[0]);
649  }
650  EXPECT_TRUE(boost::filesystem::exists(general_options.m_datafiles[0]));
651 
652  try {
653  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
654  } catch (...) {
655 
656  }
657  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
658  }
659 
660 
661  TEST(DatasetTest, ROOTDatasetDouble)
662  {
663 
665  TFile file("datafile.root", "RECREATE");
666  file.cd();
667  TTree tree("tree", "TreeTitle");
668  double a, b, c, d, e, f, g, v, w = 0;
669  tree.Branch("a", &a, "a/D");
670  tree.Branch("b", &b, "b/D");
671  tree.Branch("c", &c, "c/D");
672  tree.Branch("d", &d, "d/D");
673  tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
674  tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
675  tree.Branch("g", &g, "g/D");
676  tree.Branch("__weight__", &c, "__weight__/D");
677  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
678  tree.Branch("w", &w, "w/D");
679 
680  for (unsigned int i = 0; i < 5; ++i) {
681  a = i + 1.0;
682  b = i + 1.1;
683  c = i + 1.2;
684  d = i + 1.3;
685  e = i + 1.4;
686  f = i + 1.5;
687  g = float(i % 2 == 0);
688  w = i + 1.6;
689  v = i + 1.7;
690  tree.Fill();
691  }
692 
693  file.Write("tree");
694 
695  MVA::GeneralOptions general_options;
696  // Both names with and without makeROOTCompatible should work
697  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
698  general_options.m_spectators = {"w", "v()"};
699  general_options.m_signal_class = 1;
700  general_options.m_datafiles = {"datafile.root"};
701  general_options.m_treename = "tree";
702  general_options.m_target_variable = "g";
703  general_options.m_weight_variable = "c";
704  MVA::ROOTDataset x(general_options);
705 
706  EXPECT_EQ(x.getNumberOfFeatures(), 4);
707  EXPECT_EQ(x.getNumberOfSpectators(), 2);
708  EXPECT_EQ(x.getNumberOfEvents(), 5);
709 
710  // Should just work
711  x.loadEvent(0);
712  EXPECT_EQ(x.m_input.size(), 4);
713  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
714  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
715  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
716  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
717  EXPECT_EQ(x.m_spectators.size(), 2);
718  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
719  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
720  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
721  EXPECT_FLOAT_EQ(x.m_target, 1.0);
722  EXPECT_EQ(x.m_isSignal, true);
723 
724  x.loadEvent(1);
725  EXPECT_EQ(x.m_input.size(), 4);
726  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
727  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
728  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
729  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
730  EXPECT_EQ(x.m_spectators.size(), 2);
731  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
732  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
733  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
734  EXPECT_FLOAT_EQ(x.m_target, 0.0);
735  EXPECT_EQ(x.m_isSignal, false);
736 
737  x.loadEvent(2);
738  EXPECT_EQ(x.m_input.size(), 4);
739  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
740  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
741  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
742  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
743  EXPECT_EQ(x.m_spectators.size(), 2);
744  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
745  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
746  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
747  EXPECT_FLOAT_EQ(x.m_target, 1.0);
748  EXPECT_EQ(x.m_isSignal, true);
749 
750  x.loadEvent(3);
751  EXPECT_EQ(x.m_input.size(), 4);
752  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
753  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
754  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
755  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
756  EXPECT_EQ(x.m_spectators.size(), 2);
757  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
758  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
759  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
760  EXPECT_FLOAT_EQ(x.m_target, 0.0);
761  EXPECT_EQ(x.m_isSignal, false);
762 
763  x.loadEvent(4);
764  EXPECT_EQ(x.m_input.size(), 4);
765  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
766  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
767  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
768  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
769  EXPECT_EQ(x.m_spectators.size(), 2);
770  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
771  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
772  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
773  EXPECT_FLOAT_EQ(x.m_target, 1.0);
774  EXPECT_EQ(x.m_isSignal, true);
775 
776  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
777 
778  auto feature = x.getFeature(1);
779  EXPECT_EQ(feature.size(), 5);
780  EXPECT_FLOAT_EQ(feature[0], 1.1);
781  EXPECT_FLOAT_EQ(feature[1], 2.1);
782  EXPECT_FLOAT_EQ(feature[2], 3.1);
783  EXPECT_FLOAT_EQ(feature[3], 4.1);
784  EXPECT_FLOAT_EQ(feature[4], 5.1);
785 
786  // Same result for mother class implementation
787  feature = x.Dataset::getFeature(1);
788  EXPECT_EQ(feature.size(), 5);
789  EXPECT_FLOAT_EQ(feature[0], 1.1);
790  EXPECT_FLOAT_EQ(feature[1], 2.1);
791  EXPECT_FLOAT_EQ(feature[2], 3.1);
792  EXPECT_FLOAT_EQ(feature[3], 4.1);
793  EXPECT_FLOAT_EQ(feature[4], 5.1);
794 
795  auto spectator = x.getSpectator(1);
796  EXPECT_EQ(spectator.size(), 5);
797  EXPECT_FLOAT_EQ(spectator[0], 1.7);
798  EXPECT_FLOAT_EQ(spectator[1], 2.7);
799  EXPECT_FLOAT_EQ(spectator[2], 3.7);
800  EXPECT_FLOAT_EQ(spectator[3], 4.7);
801  EXPECT_FLOAT_EQ(spectator[4], 5.7);
802 
803  // Same result for mother class implementation
804  spectator = x.Dataset::getSpectator(1);
805  EXPECT_EQ(spectator.size(), 5);
806  EXPECT_FLOAT_EQ(spectator[0], 1.7);
807  EXPECT_FLOAT_EQ(spectator[1], 2.7);
808  EXPECT_FLOAT_EQ(spectator[2], 3.7);
809  EXPECT_FLOAT_EQ(spectator[3], 4.7);
810  EXPECT_FLOAT_EQ(spectator[4], 5.7);
811 
812  auto weights = x.getWeights();
813  EXPECT_EQ(weights.size(), 5);
814  EXPECT_FLOAT_EQ(weights[0], 1.2);
815  EXPECT_FLOAT_EQ(weights[1], 2.2);
816  EXPECT_FLOAT_EQ(weights[2], 3.2);
817  EXPECT_FLOAT_EQ(weights[3], 4.2);
818  EXPECT_FLOAT_EQ(weights[4], 5.2);
819 
820  // Same result for mother class implementation
821  weights = x.Dataset::getWeights();
822  EXPECT_EQ(weights.size(), 5);
823  EXPECT_FLOAT_EQ(weights[0], 1.2);
824  EXPECT_FLOAT_EQ(weights[1], 2.2);
825  EXPECT_FLOAT_EQ(weights[2], 3.2);
826  EXPECT_FLOAT_EQ(weights[3], 4.2);
827  EXPECT_FLOAT_EQ(weights[4], 5.2);
828 
829  auto targets = x.getTargets();
830  EXPECT_EQ(targets.size(), 5);
831  EXPECT_FLOAT_EQ(targets[0], 1.0);
832  EXPECT_FLOAT_EQ(targets[1], 0.0);
833  EXPECT_FLOAT_EQ(targets[2], 1.0);
834  EXPECT_FLOAT_EQ(targets[3], 0.0);
835  EXPECT_FLOAT_EQ(targets[4], 1.0);
836 
837  auto signals = x.getSignals();
838  EXPECT_EQ(signals.size(), 5);
839  EXPECT_EQ(signals[0], true);
840  EXPECT_EQ(signals[1], false);
841  EXPECT_EQ(signals[2], true);
842  EXPECT_EQ(signals[3], false);
843  EXPECT_EQ(signals[4], true);
844 
845  // Using __weight__ should work as well,
846  // the only difference to using _weight__ instead of g is
847  // in setBranchAddresses which avoids calling makeROOTCompatible
848  // So we have to check the behaviour using __weight__ as well
849  general_options.m_weight_variable = "__weight__";
850  MVA::ROOTDataset y(general_options);
851 
852  weights = y.getWeights();
853  EXPECT_EQ(weights.size(), 5);
854  EXPECT_FLOAT_EQ(weights[0], 1.2);
855  EXPECT_FLOAT_EQ(weights[1], 2.2);
856  EXPECT_FLOAT_EQ(weights[2], 3.2);
857  EXPECT_FLOAT_EQ(weights[3], 4.2);
858  EXPECT_FLOAT_EQ(weights[4], 5.2);
859 
860  // Check TChain expansion
861  general_options.m_datafiles = {"datafile*.root"};
862  {
863  MVA::ROOTDataset chain_test(general_options);
864  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
865  }
866  boost::filesystem::copy_file("datafile.root", "datafile2.root");
867  {
868  MVA::ROOTDataset chain_test(general_options);
869  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
870  }
871  boost::filesystem::copy_file("datafile.root", "datafile3.root");
872  {
873  MVA::ROOTDataset chain_test(general_options);
874  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
875  }
876  // Test m_max_events feature
877  {
878  general_options.m_max_events = 10;
879  MVA::ROOTDataset chain_test(general_options);
880  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
881  general_options.m_max_events = 0;
882  }
883 
884  // Check for missing tree
885  general_options.m_treename = "missing tree";
886  try {
887  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
888  } catch (...) {
889 
890  }
891  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
892 
893  // Check for missing branch
894  general_options.m_treename = "tree";
895  general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
896  try {
897  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
898  } catch (...) {
899 
900  }
901  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
902 
903  // Check for missing branch
904  general_options.m_treename = "tree";
905  general_options.m_variables = {"a", "b", "e", "f"};
906  general_options.m_spectators = {"missing branch"};
907  try {
908  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
909  } catch (...) {
910 
911  }
912  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
913 
914  // Check for missing file
915  general_options.m_spectators = {};
916  general_options.m_datafiles = {"DOESNOTEXIST.root"};
917  general_options.m_treename = "tree";
918  try {
919  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
920  } catch (...) {
921 
922  }
923  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
924 
925  // Check for invalid file
926  general_options.m_datafiles = {"ISNotAValidROOTFile"};
927  general_options.m_treename = "tree";
928 
929  {
930  std::ofstream(general_options.m_datafiles[0]);
931  }
932  EXPECT_TRUE(boost::filesystem::exists(general_options.m_datafiles[0]));
933 
934  try {
935  EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
936  } catch (...) {
937 
938  }
939  EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
940  }
941 
942  TEST(DatasetTest, ROOTMultiDataset)
943  {
944 
946  TFile file("datafile.root", "RECREATE");
947  file.cd();
948  TTree tree("tree", "TreeTitle");
949  float a, b, c, d, e, f, g, v, w = 0;
950  tree.Branch("a", &a);
951  tree.Branch("b", &b);
952  tree.Branch("c", &c);
953  tree.Branch("d", &d);
954  tree.Branch("e__bo__bc", &e);
955  tree.Branch("f__bo__bc", &f);
956  tree.Branch("g", &g);
957  tree.Branch("__weight__", &c);
958  tree.Branch("v__bo__bc", &v);
959  tree.Branch("w", &w);
960 
961  for (unsigned int i = 0; i < 5; ++i) {
962  a = i + 1.0;
963  b = i + 1.1;
964  c = i + 1.2;
965  d = i + 1.3;
966  e = i + 1.4;
967  f = i + 1.5;
968  g = float(i % 2 == 0);
969  w = i + 1.6;
970  v = i + 1.7;
971  tree.Fill();
972  }
973 
974  file.Write("tree");
975 
976  TFile file2("datafile2.root", "RECREATE");
977  file2.cd();
978  TTree tree2("tree", "TreeTitle");
979  tree2.Branch("a", &a);
980  tree2.Branch("b", &b);
981  tree2.Branch("c", &c);
982  tree2.Branch("d", &d);
983  tree2.Branch("e__bo__bc", &e);
984  tree2.Branch("f__bo__bc", &f);
985  tree2.Branch("g", &g);
986  tree2.Branch("__weight__", &c);
987  tree2.Branch("v__bo__bc", &v);
988  tree2.Branch("w", &w);
989 
990  for (unsigned int i = 0; i < 5; ++i) {
991  a = i + 1.0;
992  b = i + 1.1;
993  c = i + 1.2;
994  d = i + 1.3;
995  e = i + 1.4;
996  f = i + 1.5;
997  g = float(i % 2 == 0);
998  w = i + 1.6;
999  v = i + 1.7;
1000  tree2.Fill();
1001  }
1002 
1003  file2.Write("tree");
1004 
1005  MVA::GeneralOptions general_options;
1006  // Both names with and without makeROOTCompatible should work
1007  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1008  general_options.m_spectators = {"w", "v()"};
1009  general_options.m_signal_class = 1;
1010  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1011  general_options.m_treename = "tree";
1012  general_options.m_target_variable = "g";
1013  general_options.m_weight_variable = "c";
1014  MVA::ROOTDataset x(general_options);
1015 
1016  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1017  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1018  EXPECT_EQ(x.getNumberOfEvents(), 10);
1019 
1020  // Should just work
1021  x.loadEvent(0);
1022  EXPECT_EQ(x.m_input.size(), 4);
1023  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1024  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1025  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1026  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1027  EXPECT_EQ(x.m_spectators.size(), 2);
1028  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1029  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1030  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1031  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1032  EXPECT_EQ(x.m_isSignal, true);
1033 
1034  x.loadEvent(5);
1035  EXPECT_EQ(x.m_input.size(), 4);
1036  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1037  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1038  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1039  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1040  EXPECT_EQ(x.m_spectators.size(), 2);
1041  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1042  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1043  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1044  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1045  EXPECT_EQ(x.m_isSignal, true);
1046 
1047  x.loadEvent(1);
1048  EXPECT_EQ(x.m_input.size(), 4);
1049  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1050  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1051  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1052  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1053  EXPECT_EQ(x.m_spectators.size(), 2);
1054  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1055  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1056  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1057  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1058  EXPECT_EQ(x.m_isSignal, false);
1059 
1060  x.loadEvent(6);
1061  EXPECT_EQ(x.m_input.size(), 4);
1062  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1063  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1064  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1065  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1066  EXPECT_EQ(x.m_spectators.size(), 2);
1067  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1068  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1069  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1070  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1071  EXPECT_EQ(x.m_isSignal, false);
1072 
1073  x.loadEvent(2);
1074  EXPECT_EQ(x.m_input.size(), 4);
1075  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1076  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1077  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1078  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1079  EXPECT_EQ(x.m_spectators.size(), 2);
1080  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1081  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1082  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1083  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1084  EXPECT_EQ(x.m_isSignal, true);
1085 
1086  x.loadEvent(7);
1087  EXPECT_EQ(x.m_input.size(), 4);
1088  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1089  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1090  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1091  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1092  EXPECT_EQ(x.m_spectators.size(), 2);
1093  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1094  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1095  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1096  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1097  EXPECT_EQ(x.m_isSignal, true);
1098 
1099  x.loadEvent(3);
1100  EXPECT_EQ(x.m_input.size(), 4);
1101  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1102  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1103  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1104  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1105  EXPECT_EQ(x.m_spectators.size(), 2);
1106  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1107  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1108  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1109  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1110  EXPECT_EQ(x.m_isSignal, false);
1111 
1112  x.loadEvent(8);
1113  EXPECT_EQ(x.m_input.size(), 4);
1114  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1115  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1116  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1117  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1118  EXPECT_EQ(x.m_spectators.size(), 2);
1119  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1120  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1121  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1122  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1123  EXPECT_EQ(x.m_isSignal, false);
1124 
1125  x.loadEvent(4);
1126  EXPECT_EQ(x.m_input.size(), 4);
1127  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1128  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1129  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1130  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1131  EXPECT_EQ(x.m_spectators.size(), 2);
1132  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1133  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1134  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1135  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1136  EXPECT_EQ(x.m_isSignal, true);
1137 
1138  x.loadEvent(9);
1139  EXPECT_EQ(x.m_input.size(), 4);
1140  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1141  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1142  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1143  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1144  EXPECT_EQ(x.m_spectators.size(), 2);
1145  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1146  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1147  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1148  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1149  EXPECT_EQ(x.m_isSignal, true);
1150 
1151  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1152 
1153  auto feature = x.getFeature(1);
1154  EXPECT_EQ(feature.size(), 10);
1155  EXPECT_FLOAT_EQ(feature[0], 1.1);
1156  EXPECT_FLOAT_EQ(feature[1], 2.1);
1157  EXPECT_FLOAT_EQ(feature[2], 3.1);
1158  EXPECT_FLOAT_EQ(feature[3], 4.1);
1159  EXPECT_FLOAT_EQ(feature[4], 5.1);
1160  EXPECT_FLOAT_EQ(feature[5], 1.1);
1161  EXPECT_FLOAT_EQ(feature[6], 2.1);
1162  EXPECT_FLOAT_EQ(feature[7], 3.1);
1163  EXPECT_FLOAT_EQ(feature[8], 4.1);
1164  EXPECT_FLOAT_EQ(feature[9], 5.1);
1165 
1166  // Same result for mother class implementation
1167  feature = x.Dataset::getFeature(1);
1168  EXPECT_EQ(feature.size(), 10);
1169  EXPECT_FLOAT_EQ(feature[0], 1.1);
1170  EXPECT_FLOAT_EQ(feature[1], 2.1);
1171  EXPECT_FLOAT_EQ(feature[2], 3.1);
1172  EXPECT_FLOAT_EQ(feature[3], 4.1);
1173  EXPECT_FLOAT_EQ(feature[4], 5.1);
1174  EXPECT_FLOAT_EQ(feature[5], 1.1);
1175  EXPECT_FLOAT_EQ(feature[6], 2.1);
1176  EXPECT_FLOAT_EQ(feature[7], 3.1);
1177  EXPECT_FLOAT_EQ(feature[8], 4.1);
1178  EXPECT_FLOAT_EQ(feature[9], 5.1);
1179 
1180  auto spectator = x.getSpectator(1);
1181  EXPECT_EQ(spectator.size(), 10);
1182  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1183  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1184  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1185  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1186  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1187  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1188  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1189  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1190  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1191  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1192 
1193  // Same result for mother class implementation
1194  spectator = x.Dataset::getSpectator(1);
1195  EXPECT_EQ(spectator.size(), 10);
1196  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1197  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1198  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1199  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1200  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1201  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1202  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1203  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1204  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1205  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1206 
1207  auto weights = x.getWeights();
1208  EXPECT_EQ(weights.size(), 10);
1209  EXPECT_FLOAT_EQ(weights[0], 1.2);
1210  EXPECT_FLOAT_EQ(weights[1], 2.2);
1211  EXPECT_FLOAT_EQ(weights[2], 3.2);
1212  EXPECT_FLOAT_EQ(weights[3], 4.2);
1213  EXPECT_FLOAT_EQ(weights[4], 5.2);
1214  EXPECT_FLOAT_EQ(weights[5], 1.2);
1215  EXPECT_FLOAT_EQ(weights[6], 2.2);
1216  EXPECT_FLOAT_EQ(weights[7], 3.2);
1217  EXPECT_FLOAT_EQ(weights[8], 4.2);
1218  EXPECT_FLOAT_EQ(weights[9], 5.2);
1219 
1220  // Same result for mother class implementation
1221  weights = x.Dataset::getWeights();
1222  EXPECT_EQ(weights.size(), 10);
1223  EXPECT_FLOAT_EQ(weights[0], 1.2);
1224  EXPECT_FLOAT_EQ(weights[1], 2.2);
1225  EXPECT_FLOAT_EQ(weights[2], 3.2);
1226  EXPECT_FLOAT_EQ(weights[3], 4.2);
1227  EXPECT_FLOAT_EQ(weights[4], 5.2);
1228  EXPECT_FLOAT_EQ(weights[5], 1.2);
1229  EXPECT_FLOAT_EQ(weights[6], 2.2);
1230  EXPECT_FLOAT_EQ(weights[7], 3.2);
1231  EXPECT_FLOAT_EQ(weights[8], 4.2);
1232  EXPECT_FLOAT_EQ(weights[9], 5.2);
1233 
1234  auto targets = x.getTargets();
1235  EXPECT_EQ(targets.size(), 10);
1236  EXPECT_FLOAT_EQ(targets[0], 1.0);
1237  EXPECT_FLOAT_EQ(targets[1], 0.0);
1238  EXPECT_FLOAT_EQ(targets[2], 1.0);
1239  EXPECT_FLOAT_EQ(targets[3], 0.0);
1240  EXPECT_FLOAT_EQ(targets[4], 1.0);
1241  EXPECT_FLOAT_EQ(targets[5], 1.0);
1242  EXPECT_FLOAT_EQ(targets[6], 0.0);
1243  EXPECT_FLOAT_EQ(targets[7], 1.0);
1244  EXPECT_FLOAT_EQ(targets[8], 0.0);
1245  EXPECT_FLOAT_EQ(targets[9], 1.0);
1246 
1247  auto signals = x.getSignals();
1248  EXPECT_EQ(signals.size(), 10);
1249  EXPECT_EQ(signals[0], true);
1250  EXPECT_EQ(signals[1], false);
1251  EXPECT_EQ(signals[2], true);
1252  EXPECT_EQ(signals[3], false);
1253  EXPECT_EQ(signals[4], true);
1254  EXPECT_EQ(signals[5], true);
1255  EXPECT_EQ(signals[6], false);
1256  EXPECT_EQ(signals[7], true);
1257  EXPECT_EQ(signals[8], false);
1258  EXPECT_EQ(signals[9], true);
1259 
1260  // Using __weight__ should work as well,
1261  // the only difference to using _weight__ instead of g is
1262  // in setBranchAddresses which avoids calling makeROOTCompatible
1263  // So we have to check the behaviour using __weight__ as well
1264  general_options.m_weight_variable = "__weight__";
1265  MVA::ROOTDataset y(general_options);
1266 
1267  weights = y.getWeights();
1268  EXPECT_EQ(weights.size(), 10);
1269  EXPECT_FLOAT_EQ(weights[0], 1.2);
1270  EXPECT_FLOAT_EQ(weights[1], 2.2);
1271  EXPECT_FLOAT_EQ(weights[2], 3.2);
1272  EXPECT_FLOAT_EQ(weights[3], 4.2);
1273  EXPECT_FLOAT_EQ(weights[4], 5.2);
1274  EXPECT_FLOAT_EQ(weights[5], 1.2);
1275  EXPECT_FLOAT_EQ(weights[6], 2.2);
1276  EXPECT_FLOAT_EQ(weights[7], 3.2);
1277  EXPECT_FLOAT_EQ(weights[8], 4.2);
1278  EXPECT_FLOAT_EQ(weights[9], 5.2);
1279 
1280  // Check TChain expansion
1281  general_options.m_datafiles = {"datafile*.root"};
1282  {
1283  MVA::ROOTDataset chain_test(general_options);
1284  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1285  }
1286  boost::filesystem::copy_file("datafile.root", "datafile3.root");
1287  {
1288  MVA::ROOTDataset chain_test(general_options);
1289  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1290  }
1291  boost::filesystem::copy_file("datafile.root", "datafile4.root");
1292  {
1293  MVA::ROOTDataset chain_test(general_options);
1294  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1295  }
1296  // Test m_max_events feature
1297  {
1298  general_options.m_max_events = 10;
1299  MVA::ROOTDataset chain_test(general_options);
1300  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1301  general_options.m_max_events = 0;
1302  }
1303 
1304  // If a file exists with the specified expansion
1305  // the file takes precedence over the expansion
1306  boost::filesystem::copy_file("datafile.root", "datafile*.root");
1307  {
1308  general_options.m_max_events = 0;
1309  MVA::ROOTDataset chain_test(general_options);
1310  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1311  }
1312 
1313  }
1314  TEST(DatasetTest, ROOTMultiDatasetDouble)
1315  {
1316 
1318  TFile file("datafile.root", "RECREATE");
1319  file.cd();
1320  TTree tree("tree", "TreeTitle");
1321  double a, b, c, d, e, f, g, v, w = 0;
1322  tree.Branch("a", &a, "a/D");
1323  tree.Branch("b", &b, "b/D");
1324  tree.Branch("c", &c, "c/D");
1325  tree.Branch("d", &d, "d/D");
1326  tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
1327  tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
1328  tree.Branch("g", &g, "g/D");
1329  tree.Branch("__weight__", &c, "__weight__/D");
1330  tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
1331  tree.Branch("w", &w, "w/D");
1332 
1333  for (unsigned int i = 0; i < 5; ++i) {
1334  a = i + 1.0;
1335  b = i + 1.1;
1336  c = i + 1.2;
1337  d = i + 1.3;
1338  e = i + 1.4;
1339  f = i + 1.5;
1340  g = float(i % 2 == 0);
1341  w = i + 1.6;
1342  v = i + 1.7;
1343  tree.Fill();
1344  }
1345 
1346  file.Write("tree");
1347 
1348  TFile file2("datafile2.root", "RECREATE");
1349  file2.cd();
1350  TTree tree2("tree", "TreeTitle");
1351  tree2.Branch("a", &a);
1352  tree2.Branch("b", &b);
1353  tree2.Branch("c", &c);
1354  tree2.Branch("d", &d);
1355  tree2.Branch("e__bo__bc", &e);
1356  tree2.Branch("f__bo__bc", &f);
1357  tree2.Branch("g", &g);
1358  tree2.Branch("__weight__", &c);
1359  tree2.Branch("v__bo__bc", &v);
1360  tree2.Branch("w", &w);
1361 
1362  for (unsigned int i = 0; i < 5; ++i) {
1363  a = i + 1.0;
1364  b = i + 1.1;
1365  c = i + 1.2;
1366  d = i + 1.3;
1367  e = i + 1.4;
1368  f = i + 1.5;
1369  g = float(i % 2 == 0);
1370  w = i + 1.6;
1371  v = i + 1.7;
1372  tree2.Fill();
1373  }
1374 
1375  file2.Write("tree");
1376 
1377  MVA::GeneralOptions general_options;
1378  // Both names with and without makeROOTCompatible should work
1379  general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1380  general_options.m_spectators = {"w", "v()"};
1381  general_options.m_signal_class = 1;
1382  general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1383  general_options.m_treename = "tree";
1384  general_options.m_target_variable = "g";
1385  general_options.m_weight_variable = "c";
1386  MVA::ROOTDataset x(general_options);
1387 
1388  EXPECT_EQ(x.getNumberOfFeatures(), 4);
1389  EXPECT_EQ(x.getNumberOfSpectators(), 2);
1390  EXPECT_EQ(x.getNumberOfEvents(), 10);
1391 
1392  // Should just work
1393  x.loadEvent(0);
1394  EXPECT_EQ(x.m_input.size(), 4);
1395  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1396  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1397  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1398  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1399  EXPECT_EQ(x.m_spectators.size(), 2);
1400  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1401  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1402  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1403  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1404  EXPECT_EQ(x.m_isSignal, true);
1405 
1406  x.loadEvent(5);
1407  EXPECT_EQ(x.m_input.size(), 4);
1408  EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1409  EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1410  EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1411  EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1412  EXPECT_EQ(x.m_spectators.size(), 2);
1413  EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1414  EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1415  EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1416  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1417  EXPECT_EQ(x.m_isSignal, true);
1418 
1419  x.loadEvent(1);
1420  EXPECT_EQ(x.m_input.size(), 4);
1421  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1422  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1423  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1424  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1425  EXPECT_EQ(x.m_spectators.size(), 2);
1426  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1427  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1428  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1429  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1430  EXPECT_EQ(x.m_isSignal, false);
1431 
1432  x.loadEvent(6);
1433  EXPECT_EQ(x.m_input.size(), 4);
1434  EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1435  EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1436  EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1437  EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1438  EXPECT_EQ(x.m_spectators.size(), 2);
1439  EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1440  EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1441  EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1442  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1443  EXPECT_EQ(x.m_isSignal, false);
1444 
1445  x.loadEvent(2);
1446  EXPECT_EQ(x.m_input.size(), 4);
1447  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1448  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1449  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1450  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1451  EXPECT_EQ(x.m_spectators.size(), 2);
1452  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1453  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1454  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1455  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1456  EXPECT_EQ(x.m_isSignal, true);
1457 
1458  x.loadEvent(7);
1459  EXPECT_EQ(x.m_input.size(), 4);
1460  EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1461  EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1462  EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1463  EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1464  EXPECT_EQ(x.m_spectators.size(), 2);
1465  EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1466  EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1467  EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1468  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1469  EXPECT_EQ(x.m_isSignal, true);
1470 
1471  x.loadEvent(3);
1472  EXPECT_EQ(x.m_input.size(), 4);
1473  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1474  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1475  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1476  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1477  EXPECT_EQ(x.m_spectators.size(), 2);
1478  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1479  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1480  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1481  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1482  EXPECT_EQ(x.m_isSignal, false);
1483 
1484  x.loadEvent(8);
1485  EXPECT_EQ(x.m_input.size(), 4);
1486  EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1487  EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1488  EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1489  EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1490  EXPECT_EQ(x.m_spectators.size(), 2);
1491  EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1492  EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1493  EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1494  EXPECT_FLOAT_EQ(x.m_target, 0.0);
1495  EXPECT_EQ(x.m_isSignal, false);
1496 
1497  x.loadEvent(4);
1498  EXPECT_EQ(x.m_input.size(), 4);
1499  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1500  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1501  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1502  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1503  EXPECT_EQ(x.m_spectators.size(), 2);
1504  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1505  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1506  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1507  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1508  EXPECT_EQ(x.m_isSignal, true);
1509 
1510  x.loadEvent(9);
1511  EXPECT_EQ(x.m_input.size(), 4);
1512  EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1513  EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1514  EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1515  EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1516  EXPECT_EQ(x.m_spectators.size(), 2);
1517  EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1518  EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1519  EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1520  EXPECT_FLOAT_EQ(x.m_target, 1.0);
1521  EXPECT_EQ(x.m_isSignal, true);
1522 
1523  EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1524 
1525  auto feature = x.getFeature(1);
1526  EXPECT_EQ(feature.size(), 10);
1527  EXPECT_FLOAT_EQ(feature[0], 1.1);
1528  EXPECT_FLOAT_EQ(feature[1], 2.1);
1529  EXPECT_FLOAT_EQ(feature[2], 3.1);
1530  EXPECT_FLOAT_EQ(feature[3], 4.1);
1531  EXPECT_FLOAT_EQ(feature[4], 5.1);
1532  EXPECT_FLOAT_EQ(feature[5], 1.1);
1533  EXPECT_FLOAT_EQ(feature[6], 2.1);
1534  EXPECT_FLOAT_EQ(feature[7], 3.1);
1535  EXPECT_FLOAT_EQ(feature[8], 4.1);
1536  EXPECT_FLOAT_EQ(feature[9], 5.1);
1537 
1538  // Same result for mother class implementation
1539  feature = x.Dataset::getFeature(1);
1540  EXPECT_EQ(feature.size(), 10);
1541  EXPECT_FLOAT_EQ(feature[0], 1.1);
1542  EXPECT_FLOAT_EQ(feature[1], 2.1);
1543  EXPECT_FLOAT_EQ(feature[2], 3.1);
1544  EXPECT_FLOAT_EQ(feature[3], 4.1);
1545  EXPECT_FLOAT_EQ(feature[4], 5.1);
1546  EXPECT_FLOAT_EQ(feature[5], 1.1);
1547  EXPECT_FLOAT_EQ(feature[6], 2.1);
1548  EXPECT_FLOAT_EQ(feature[7], 3.1);
1549  EXPECT_FLOAT_EQ(feature[8], 4.1);
1550  EXPECT_FLOAT_EQ(feature[9], 5.1);
1551 
1552  auto spectator = x.getSpectator(1);
1553  EXPECT_EQ(spectator.size(), 10);
1554  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1555  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1556  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1557  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1558  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1559  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1560  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1561  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1562  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1563  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1564 
1565  // Same result for mother class implementation
1566  spectator = x.Dataset::getSpectator(1);
1567  EXPECT_EQ(spectator.size(), 10);
1568  EXPECT_FLOAT_EQ(spectator[0], 1.7);
1569  EXPECT_FLOAT_EQ(spectator[1], 2.7);
1570  EXPECT_FLOAT_EQ(spectator[2], 3.7);
1571  EXPECT_FLOAT_EQ(spectator[3], 4.7);
1572  EXPECT_FLOAT_EQ(spectator[4], 5.7);
1573  EXPECT_FLOAT_EQ(spectator[5], 1.7);
1574  EXPECT_FLOAT_EQ(spectator[6], 2.7);
1575  EXPECT_FLOAT_EQ(spectator[7], 3.7);
1576  EXPECT_FLOAT_EQ(spectator[8], 4.7);
1577  EXPECT_FLOAT_EQ(spectator[9], 5.7);
1578 
1579  auto weights = x.getWeights();
1580  EXPECT_EQ(weights.size(), 10);
1581  EXPECT_FLOAT_EQ(weights[0], 1.2);
1582  EXPECT_FLOAT_EQ(weights[1], 2.2);
1583  EXPECT_FLOAT_EQ(weights[2], 3.2);
1584  EXPECT_FLOAT_EQ(weights[3], 4.2);
1585  EXPECT_FLOAT_EQ(weights[4], 5.2);
1586  EXPECT_FLOAT_EQ(weights[5], 1.2);
1587  EXPECT_FLOAT_EQ(weights[6], 2.2);
1588  EXPECT_FLOAT_EQ(weights[7], 3.2);
1589  EXPECT_FLOAT_EQ(weights[8], 4.2);
1590  EXPECT_FLOAT_EQ(weights[9], 5.2);
1591 
1592  // Same result for mother class implementation
1593  weights = x.Dataset::getWeights();
1594  EXPECT_EQ(weights.size(), 10);
1595  EXPECT_FLOAT_EQ(weights[0], 1.2);
1596  EXPECT_FLOAT_EQ(weights[1], 2.2);
1597  EXPECT_FLOAT_EQ(weights[2], 3.2);
1598  EXPECT_FLOAT_EQ(weights[3], 4.2);
1599  EXPECT_FLOAT_EQ(weights[4], 5.2);
1600  EXPECT_FLOAT_EQ(weights[5], 1.2);
1601  EXPECT_FLOAT_EQ(weights[6], 2.2);
1602  EXPECT_FLOAT_EQ(weights[7], 3.2);
1603  EXPECT_FLOAT_EQ(weights[8], 4.2);
1604  EXPECT_FLOAT_EQ(weights[9], 5.2);
1605 
1606  auto targets = x.getTargets();
1607  EXPECT_EQ(targets.size(), 10);
1608  EXPECT_FLOAT_EQ(targets[0], 1.0);
1609  EXPECT_FLOAT_EQ(targets[1], 0.0);
1610  EXPECT_FLOAT_EQ(targets[2], 1.0);
1611  EXPECT_FLOAT_EQ(targets[3], 0.0);
1612  EXPECT_FLOAT_EQ(targets[4], 1.0);
1613  EXPECT_FLOAT_EQ(targets[5], 1.0);
1614  EXPECT_FLOAT_EQ(targets[6], 0.0);
1615  EXPECT_FLOAT_EQ(targets[7], 1.0);
1616  EXPECT_FLOAT_EQ(targets[8], 0.0);
1617  EXPECT_FLOAT_EQ(targets[9], 1.0);
1618 
1619  auto signals = x.getSignals();
1620  EXPECT_EQ(signals.size(), 10);
1621  EXPECT_EQ(signals[0], true);
1622  EXPECT_EQ(signals[1], false);
1623  EXPECT_EQ(signals[2], true);
1624  EXPECT_EQ(signals[3], false);
1625  EXPECT_EQ(signals[4], true);
1626  EXPECT_EQ(signals[5], true);
1627  EXPECT_EQ(signals[6], false);
1628  EXPECT_EQ(signals[7], true);
1629  EXPECT_EQ(signals[8], false);
1630  EXPECT_EQ(signals[9], true);
1631 
1632  // Using __weight__ should work as well,
1633  // the only difference to using _weight__ instead of g is
1634  // in setBranchAddresses which avoids calling makeROOTCompatible
1635  // So we have to check the behaviour using __weight__ as well
1636  general_options.m_weight_variable = "__weight__";
1637  MVA::ROOTDataset y(general_options);
1638 
1639  weights = y.getWeights();
1640  EXPECT_EQ(weights.size(), 10);
1641  EXPECT_FLOAT_EQ(weights[0], 1.2);
1642  EXPECT_FLOAT_EQ(weights[1], 2.2);
1643  EXPECT_FLOAT_EQ(weights[2], 3.2);
1644  EXPECT_FLOAT_EQ(weights[3], 4.2);
1645  EXPECT_FLOAT_EQ(weights[4], 5.2);
1646  EXPECT_FLOAT_EQ(weights[5], 1.2);
1647  EXPECT_FLOAT_EQ(weights[6], 2.2);
1648  EXPECT_FLOAT_EQ(weights[7], 3.2);
1649  EXPECT_FLOAT_EQ(weights[8], 4.2);
1650  EXPECT_FLOAT_EQ(weights[9], 5.2);
1651 
1652  // Check TChain expansion
1653  general_options.m_datafiles = {"datafile*.root"};
1654  {
1655  MVA::ROOTDataset chain_test(general_options);
1656  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1657  }
1658  boost::filesystem::copy_file("datafile.root", "datafile3.root");
1659  {
1660  MVA::ROOTDataset chain_test(general_options);
1661  EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1662  }
1663  boost::filesystem::copy_file("datafile.root", "datafile4.root");
1664  {
1665  MVA::ROOTDataset chain_test(general_options);
1666  EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1667  }
1668  // Test m_max_events feature
1669  {
1670  general_options.m_max_events = 10;
1671  MVA::ROOTDataset chain_test(general_options);
1672  EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1673  general_options.m_max_events = 0;
1674  }
1675 
1676  // If a file exists with the specified expansion
1677  // the file takes precedence over the expansion
1678  boost::filesystem::copy_file("datafile.root", "datafile*.root");
1679  {
1680  general_options.m_max_events = 0;
1681  MVA::ROOTDataset chain_test(general_options);
1682  EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1683  }
1684  }
1685 }
Belle2::MVA::MultiDataset
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:187
prepareAsicCrosstalkSimDB.e
e
aux.
Definition: prepareAsicCrosstalkSimDB.py:53
Belle2::MVA::Dataset
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:34
Belle2::TestHelpers::TempDirCreator
changes working directory into a newly created directory, and removes it (and contents) on destructio...
Definition: TestHelpers.h:57
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::SubDataset
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:234
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::TEST
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Definition: utilityFunctions.cc:18
Belle2::MVA::SingleDataset
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:136
Belle2::MVA::ROOTDataset
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
Definition: Dataset.h:349
Belle2::MVA::CombinedDataset
Wraps two other Datasets, one containing signal, the other background events Used by the reweighting ...
Definition: Dataset.h:294