Belle II Software development
test_Dataset.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/interface/Dataset.h>
10#include <framework/utilities/TestHelpers.h>
11
12#include <gtest/gtest.h>
13
14#include <filesystem>
15#include <fstream>
16#include <numeric>
17
18using namespace Belle2;
19
20namespace {
21
22 class TestDataset : public MVA::Dataset {
23 public:
24 explicit TestDataset(MVA::GeneralOptions& general_options) : MVA::Dataset(general_options)
25 {
26 m_input = {1.0, 2.0, 3.0, 4.0, 5.0};
27 m_target = 3.0;
28 m_isSignal = true;
29 m_weight = -3.0;
30 }
31
32 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 5; }
33 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 2; }
34 [[nodiscard]] unsigned int getNumberOfEvents() const override { return 20; }
35 void loadEvent(unsigned int iEvent) override
36 {
37 auto f = static_cast<float>(iEvent);
38 m_input = {f + 1, f + 2, f + 3, f + 4, f + 5};
39 m_spectators = {f + 6, f + 7};
40 };
41 float getSignalFraction() override { return 0.1; };
42 std::vector<float> getFeature(unsigned int iFeature) override
43 {
44 std::vector<float> a(20, 0.0);
45 std::iota(a.begin(), a.end(), iFeature + 1);
46 return a;
47 }
48 std::vector<float> getSpectator(unsigned int iSpectator) override
49 {
50 std::vector<float> a(20, 0.0);
51 std::iota(a.begin(), a.end(), iSpectator + 6);
52 return a;
53 }
54
55 };
56
57 TEST(DatasetTest, SingleDataset)
58 {
59
60 MVA::GeneralOptions general_options;
61 general_options.m_variables = {"a", "b", "c"};
62 general_options.m_spectators = {"e", "f"};
63 general_options.m_signal_class = 2;
64 std::vector<float> input = {1.0, 2.0, 3.0};
65 std::vector<float> spectators = {4.0, 5.0};
66
67 MVA::SingleDataset x(general_options, input, 2.0, spectators);
68
69 EXPECT_EQ(x.getNumberOfFeatures(), 3);
70 EXPECT_EQ(x.getNumberOfSpectators(), 2);
71 EXPECT_EQ(x.getNumberOfEvents(), 1);
72
73 EXPECT_EQ(x.getFeatureIndex("a"), 0);
74 EXPECT_EQ(x.getFeatureIndex("b"), 1);
75 EXPECT_EQ(x.getFeatureIndex("c"), 2);
76 EXPECT_B2ERROR(x.getFeatureIndex("bla"));
77
78 EXPECT_EQ(x.getSpectatorIndex("e"), 0);
79 EXPECT_EQ(x.getSpectatorIndex("f"), 1);
80 EXPECT_B2ERROR(x.getSpectatorIndex("bla"));
81
82 // Should just work
83 x.loadEvent(0);
84
85 EXPECT_EQ(x.m_input.size(), 3);
86 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
87 EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
88 EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
89
90 EXPECT_EQ(x.m_spectators.size(), 2);
91 EXPECT_FLOAT_EQ(x.m_spectators[0], 4.0);
92 EXPECT_FLOAT_EQ(x.m_spectators[1], 5.0);
93
94 EXPECT_FLOAT_EQ(x.m_weight, 1.0);
95 EXPECT_FLOAT_EQ(x.m_target, 2.0);
96 EXPECT_EQ(x.m_isSignal, true);
97
98 EXPECT_FLOAT_EQ(x.getSignalFraction(), 1.0);
99
100 auto feature = x.getFeature(1);
101 EXPECT_EQ(feature.size(), 1);
102 EXPECT_FLOAT_EQ(feature[0], 2.0);
103
104 auto spectator = x.getSpectator(1);
105 EXPECT_EQ(spectator.size(), 1);
106 EXPECT_FLOAT_EQ(spectator[0], 5.0);
107
108 // Same result for mother class implementation
109 feature = x.Dataset::getFeature(1);
110 EXPECT_EQ(feature.size(), 1);
111 EXPECT_FLOAT_EQ(feature[0], 2.0);
112
113 }
114
115 TEST(DatasetTest, MultiDataset)
116 {
117
118 MVA::GeneralOptions general_options;
119 general_options.m_variables = {"a", "b", "c"};
120 general_options.m_spectators = {"e", "f"};
121 general_options.m_signal_class = 2;
122 std::vector<std::vector<float>> matrix = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}};
123 std::vector<std::vector<float>> spectator_matrix = {{12.0, 13.0}, {15.0, 16.0}, {18.0, 19.0}};
124 std::vector<float> targets = {2.0, 0.0, 2.0};
125 std::vector<float> weights = {1.0, 2.0, 3.0};
126
127 EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, {1.0}, weights));
128
129 EXPECT_B2ERROR(MVA::MultiDataset(general_options, matrix, spectator_matrix, targets, {1.0}));
130
131 MVA::MultiDataset x(general_options, matrix, spectator_matrix, targets, weights);
132
133 EXPECT_EQ(x.getNumberOfFeatures(), 3);
134 EXPECT_EQ(x.getNumberOfEvents(), 3);
135
136 // Should just work
137 x.loadEvent(0);
138
139 EXPECT_EQ(x.m_input.size(), 3);
140 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
141 EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
142 EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
143
144 EXPECT_EQ(x.m_spectators.size(), 2);
145 EXPECT_FLOAT_EQ(x.m_spectators[0], 12.0);
146 EXPECT_FLOAT_EQ(x.m_spectators[1], 13.0);
147
148 EXPECT_FLOAT_EQ(x.m_weight, 1.0);
149 EXPECT_FLOAT_EQ(x.m_target, 2.0);
150 EXPECT_EQ(x.m_isSignal, true);
151
152 // Should just work
153 x.loadEvent(1);
154
155 EXPECT_EQ(x.m_input.size(), 3);
156 EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
157 EXPECT_FLOAT_EQ(x.m_input[1], 5.0);
158 EXPECT_FLOAT_EQ(x.m_input[2], 6.0);
159
160 EXPECT_EQ(x.m_spectators.size(), 2);
161 EXPECT_FLOAT_EQ(x.m_spectators[0], 15.0);
162 EXPECT_FLOAT_EQ(x.m_spectators[1], 16.0);
163
164 EXPECT_FLOAT_EQ(x.m_weight, 2.0);
165 EXPECT_FLOAT_EQ(x.m_target, 0.0);
166 EXPECT_EQ(x.m_isSignal, false);
167
168 // Should just work
169 x.loadEvent(2);
170
171 EXPECT_EQ(x.m_input.size(), 3);
172 EXPECT_FLOAT_EQ(x.m_input[0], 7.0);
173 EXPECT_FLOAT_EQ(x.m_input[1], 8.0);
174 EXPECT_FLOAT_EQ(x.m_input[2], 9.0);
175
176 EXPECT_EQ(x.m_spectators.size(), 2);
177 EXPECT_FLOAT_EQ(x.m_spectators[0], 18.0);
178 EXPECT_FLOAT_EQ(x.m_spectators[1], 19.0);
179
180 EXPECT_FLOAT_EQ(x.m_weight, 3.0);
181 EXPECT_FLOAT_EQ(x.m_target, 2.0);
182 EXPECT_EQ(x.m_isSignal, true);
183
184 EXPECT_FLOAT_EQ(x.getSignalFraction(), 4.0 / 6.0);
185
186 auto feature = x.getFeature(1);
187 EXPECT_EQ(feature.size(), 3);
188 EXPECT_FLOAT_EQ(feature[0], 2.0);
189 EXPECT_FLOAT_EQ(feature[1], 5.0);
190 EXPECT_FLOAT_EQ(feature[2], 8.0);
191
192 auto spectator = x.getSpectator(1);
193 EXPECT_EQ(spectator.size(), 3);
194 EXPECT_FLOAT_EQ(spectator[0], 13.0);
195 EXPECT_FLOAT_EQ(spectator[1], 16.0);
196 EXPECT_FLOAT_EQ(spectator[2], 19.0);
197
198
199 }
200
201 TEST(DatasetTest, SubDataset)
202 {
203
204 MVA::GeneralOptions general_options;
205 general_options.m_signal_class = 3;
206 general_options.m_variables = {"a", "b", "c", "d", "e"};
207 general_options.m_spectators = {"f", "g"};
208 TestDataset test_dataset(general_options);
209
210 general_options.m_variables = {"a", "d", "e"};
211 general_options.m_spectators = {"g"};
212 std::vector<bool> events = {true, false, true, false, true, false, true, false, true, false,
213 true, false, true, false, true, false, true, false, true, false
214 };
215 MVA::SubDataset x(general_options, events, test_dataset);
216
217 EXPECT_EQ(x.getNumberOfFeatures(), 3);
218 EXPECT_EQ(x.getNumberOfEvents(), 10);
219
220 // Should just work
221 x.loadEvent(0);
222
223 EXPECT_EQ(x.m_input.size(), 3);
224 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
225 EXPECT_FLOAT_EQ(x.m_input[1], 4.0);
226 EXPECT_FLOAT_EQ(x.m_input[2], 5.0);
227
228 EXPECT_EQ(x.m_spectators.size(), 1);
229 EXPECT_FLOAT_EQ(x.m_spectators[0], 7.0);
230
231 EXPECT_FLOAT_EQ(x.m_weight, -3.0);
232 EXPECT_FLOAT_EQ(x.m_target, 3.0);
233 EXPECT_EQ(x.m_isSignal, true);
234
235 EXPECT_FLOAT_EQ(x.getSignalFraction(), 1);
236
237 auto feature = x.getFeature(1);
238 EXPECT_EQ(feature.size(), 10);
239 for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
240 EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
241 };
242
243 auto spectator = x.getSpectator(0);
244 EXPECT_EQ(spectator.size(), 10);
245 for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
246 EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
247 };
248
249 // Same result for mother class implementation
250 feature = x.Dataset::getFeature(1);
251 EXPECT_EQ(feature.size(), 10);
252 for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
253 EXPECT_FLOAT_EQ(feature[iEvent], iEvent * 2 + 4);
254 };
255
256 spectator = x.Dataset::getSpectator(0);
257 EXPECT_EQ(spectator.size(), 10);
258 for (unsigned int iEvent = 0; iEvent < 10; ++iEvent) {
259 EXPECT_FLOAT_EQ(spectator[iEvent], iEvent * 2 + 7);
260 };
261
262 // Test without event indices
263 MVA::SubDataset y(general_options, {}, test_dataset);
264 feature = y.getFeature(1);
265 EXPECT_EQ(feature.size(), 20);
266 for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
267 EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
268 };
269
270 spectator = y.getSpectator(0);
271 EXPECT_EQ(spectator.size(), 20);
272 for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
273 EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
274 };
275
276 // Same result for mother class implementation
277 feature = y.Dataset::getFeature(1);
278 EXPECT_EQ(feature.size(), 20);
279 for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
280 EXPECT_FLOAT_EQ(feature[iEvent], iEvent + 4);
281 };
282
283 spectator = y.Dataset::getSpectator(0);
284 EXPECT_EQ(spectator.size(), 20);
285 for (unsigned int iEvent = 0; iEvent < 20; ++iEvent) {
286 EXPECT_FLOAT_EQ(spectator[iEvent], iEvent + 7);
287 };
288
289 general_options.m_variables = {"a", "d", "e", "DOESNOTEXIST"};
290 try {
291 EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
292 } catch (...) {
293
294 }
295 EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
296
297 general_options.m_variables = {"a", "d", "e"};
298 general_options.m_spectators = {"DOESNOTEXIST"};
299 try {
300 EXPECT_B2ERROR(MVA::SubDataset(general_options, events, test_dataset));
301 } catch (...) {
302
303 }
304 EXPECT_THROW(MVA::SubDataset(general_options, events, test_dataset), std::runtime_error);
305
306 }
307
308 TEST(DatasetTest, CombinedDataset)
309 {
310
311 MVA::GeneralOptions general_options;
312 general_options.m_signal_class = 1;
313 general_options.m_variables = {"a", "b", "c", "d", "e"};
314 general_options.m_spectators = {"f", "g"};
315 TestDataset signal_dataset(general_options);
316 TestDataset bckgrd_dataset(general_options);
317
318 MVA::CombinedDataset x(general_options, signal_dataset, bckgrd_dataset);
319
320 EXPECT_EQ(x.getNumberOfFeatures(), 5);
321 EXPECT_EQ(x.getNumberOfEvents(), 40);
322
323 // Should just work
324 x.loadEvent(0);
325
326 EXPECT_EQ(x.m_input.size(), 5);
327 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
328 EXPECT_FLOAT_EQ(x.m_input[1], 2.0);
329 EXPECT_FLOAT_EQ(x.m_input[2], 3.0);
330 EXPECT_FLOAT_EQ(x.m_input[3], 4.0);
331 EXPECT_FLOAT_EQ(x.m_input[4], 5.0);
332
333 EXPECT_EQ(x.m_spectators.size(), 2);
334 EXPECT_FLOAT_EQ(x.m_spectators[0], 6.0);
335 EXPECT_FLOAT_EQ(x.m_spectators[1], 7.0);
336
337 EXPECT_FLOAT_EQ(x.m_weight, -3.0);
338 EXPECT_FLOAT_EQ(x.m_target, 1.0);
339 EXPECT_EQ(x.m_isSignal, true);
340
341 EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.5);
342
343 auto feature = x.getFeature(1);
344 EXPECT_EQ(feature.size(), 40);
345 for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
346 EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
347 };
348
349 auto spectator = x.getSpectator(0);
350 EXPECT_EQ(spectator.size(), 40);
351 for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
352 EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
353 };
354
355 // Same result for mother class implementation
356 feature = x.Dataset::getFeature(1);
357 EXPECT_EQ(feature.size(), 40);
358 for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
359 EXPECT_FLOAT_EQ(feature[iEvent], (iEvent % 20) + 2);
360 };
361
362 spectator = x.Dataset::getSpectator(0);
363 EXPECT_EQ(spectator.size(), 40);
364 for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
365 EXPECT_FLOAT_EQ(spectator[iEvent], (iEvent % 20) + 6);
366 };
367
368 for (unsigned int iEvent = 0; iEvent < 40; ++iEvent) {
369 x.loadEvent(iEvent);
370 EXPECT_EQ(x.m_isSignal, iEvent < 20);
371 EXPECT_FLOAT_EQ(x.m_target, iEvent < 20 ? 1.0 : 0.0);
372 }
373
374 }
375
376 TEST(DatasetTest, ROOTDataset)
377 {
378
380 TFile file("datafile.root", "RECREATE");
381 file.cd();
382 TTree tree("tree", "TreeTitle");
383 float a, b, c, d, e, f, v, w = 0;
384 bool target;
385 tree.Branch("a", &a);
386 tree.Branch("b", &b);
387 tree.Branch("c", &c);
388 tree.Branch("d", &d);
389 tree.Branch("e__bo__bc", &e);
390 tree.Branch("f__bo__bc", &f);
391 tree.Branch("g", &target);
392 tree.Branch("__weight__", &c);
393 tree.Branch("v__bo__bc", &v);
394 tree.Branch("w", &w);
395
396 for (unsigned int i = 0; i < 5; ++i) {
397 a = i + 1.0;
398 b = i + 1.1;
399 c = i + 1.2;
400 d = i + 1.3;
401 e = i + 1.4;
402 f = i + 1.5;
403 target = (i % 2 == 0);
404 w = i + 1.6;
405 v = i + 1.7;
406 tree.Fill();
407 }
408
409 file.Write("tree");
410
411 MVA::GeneralOptions general_options;
412 // Both names with and without makeROOTCompatible should work
413 general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
414 general_options.m_spectators = {"w", "v()"};
415 general_options.m_signal_class = 1;
416 general_options.m_datafiles = {"datafile.root"};
417 general_options.m_treename = "tree";
418 general_options.m_target_variable = "g";
419 general_options.m_weight_variable = "c";
420 MVA::ROOTDataset x(general_options);
421
422 EXPECT_EQ(x.getNumberOfFeatures(), 4);
423 EXPECT_EQ(x.getNumberOfSpectators(), 2);
424 EXPECT_EQ(x.getNumberOfEvents(), 5);
425
426 // Should just work
427 x.loadEvent(0);
428 EXPECT_EQ(x.m_input.size(), 4);
429 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
430 EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
431 EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
432 EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
433 EXPECT_EQ(x.m_spectators.size(), 2);
434 EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
435 EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
436 EXPECT_FLOAT_EQ(x.m_weight, 1.2);
437 EXPECT_EQ(x.m_target, true);
438 EXPECT_EQ(x.m_isSignal, true);
439
440 x.loadEvent(1);
441 EXPECT_EQ(x.m_input.size(), 4);
442 EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
443 EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
444 EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
445 EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
446 EXPECT_EQ(x.m_spectators.size(), 2);
447 EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
448 EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
449 EXPECT_FLOAT_EQ(x.m_weight, 2.2);
450 EXPECT_EQ(x.m_target, false);
451 EXPECT_EQ(x.m_isSignal, false);
452
453 x.loadEvent(2);
454 EXPECT_EQ(x.m_input.size(), 4);
455 EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
456 EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
457 EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
458 EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
459 EXPECT_EQ(x.m_spectators.size(), 2);
460 EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
461 EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
462 EXPECT_FLOAT_EQ(x.m_weight, 3.2);
463 EXPECT_EQ(x.m_target, true);
464 EXPECT_EQ(x.m_isSignal, true);
465
466 x.loadEvent(3);
467 EXPECT_EQ(x.m_input.size(), 4);
468 EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
469 EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
470 EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
471 EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
472 EXPECT_EQ(x.m_spectators.size(), 2);
473 EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
474 EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
475 EXPECT_FLOAT_EQ(x.m_weight, 4.2);
476 EXPECT_EQ(x.m_target, false);
477 EXPECT_EQ(x.m_isSignal, false);
478
479 x.loadEvent(4);
480 EXPECT_EQ(x.m_input.size(), 4);
481 EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
482 EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
483 EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
484 EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
485 EXPECT_EQ(x.m_spectators.size(), 2);
486 EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
487 EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
488 EXPECT_FLOAT_EQ(x.m_weight, 5.2);
489 EXPECT_EQ(x.m_target, true);
490 EXPECT_EQ(x.m_isSignal, true);
491
492 EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
493
494 auto feature = x.getFeature(1);
495 EXPECT_EQ(feature.size(), 5);
496 EXPECT_FLOAT_EQ(feature[0], 1.1);
497 EXPECT_FLOAT_EQ(feature[1], 2.1);
498 EXPECT_FLOAT_EQ(feature[2], 3.1);
499 EXPECT_FLOAT_EQ(feature[3], 4.1);
500 EXPECT_FLOAT_EQ(feature[4], 5.1);
501
502 // Same result for mother class implementation
503 feature = x.Dataset::getFeature(1);
504 EXPECT_EQ(feature.size(), 5);
505 EXPECT_FLOAT_EQ(feature[0], 1.1);
506 EXPECT_FLOAT_EQ(feature[1], 2.1);
507 EXPECT_FLOAT_EQ(feature[2], 3.1);
508 EXPECT_FLOAT_EQ(feature[3], 4.1);
509 EXPECT_FLOAT_EQ(feature[4], 5.1);
510
511 auto spectator = x.getSpectator(1);
512 EXPECT_EQ(spectator.size(), 5);
513 EXPECT_FLOAT_EQ(spectator[0], 1.7);
514 EXPECT_FLOAT_EQ(spectator[1], 2.7);
515 EXPECT_FLOAT_EQ(spectator[2], 3.7);
516 EXPECT_FLOAT_EQ(spectator[3], 4.7);
517 EXPECT_FLOAT_EQ(spectator[4], 5.7);
518
519 // Same result for mother class implementation
520 spectator = x.Dataset::getSpectator(1);
521 EXPECT_EQ(spectator.size(), 5);
522 EXPECT_FLOAT_EQ(spectator[0], 1.7);
523 EXPECT_FLOAT_EQ(spectator[1], 2.7);
524 EXPECT_FLOAT_EQ(spectator[2], 3.7);
525 EXPECT_FLOAT_EQ(spectator[3], 4.7);
526 EXPECT_FLOAT_EQ(spectator[4], 5.7);
527
528 auto weights = x.getWeights();
529 EXPECT_EQ(weights.size(), 5);
530 EXPECT_FLOAT_EQ(weights[0], 1.2);
531 EXPECT_FLOAT_EQ(weights[1], 2.2);
532 EXPECT_FLOAT_EQ(weights[2], 3.2);
533 EXPECT_FLOAT_EQ(weights[3], 4.2);
534 EXPECT_FLOAT_EQ(weights[4], 5.2);
535
536 // Same result for mother class implementation
537 weights = x.Dataset::getWeights();
538 EXPECT_EQ(weights.size(), 5);
539 EXPECT_FLOAT_EQ(weights[0], 1.2);
540 EXPECT_FLOAT_EQ(weights[1], 2.2);
541 EXPECT_FLOAT_EQ(weights[2], 3.2);
542 EXPECT_FLOAT_EQ(weights[3], 4.2);
543 EXPECT_FLOAT_EQ(weights[4], 5.2);
544
545 auto targets = x.getTargets();
546 EXPECT_EQ(targets.size(), 5);
547 EXPECT_EQ(targets[0], true);
548 EXPECT_EQ(targets[1], false);
549 EXPECT_EQ(targets[2], true);
550 EXPECT_EQ(targets[3], false);
551 EXPECT_EQ(targets[4], true);
552
553 auto signals = x.getSignals();
554 EXPECT_EQ(signals.size(), 5);
555 EXPECT_EQ(signals[0], true);
556 EXPECT_EQ(signals[1], false);
557 EXPECT_EQ(signals[2], true);
558 EXPECT_EQ(signals[3], false);
559 EXPECT_EQ(signals[4], true);
560
561 // Using __weight__ should work as well,
562 // the only difference to using _weight__ instead of c is
563 // in setBranchAddresses which avoids calling makeROOTCompatible
564 // So we have to check the behaviour using __weight__ as well
565 general_options.m_weight_variable = "__weight__";
566 MVA::ROOTDataset y(general_options);
567
568 weights = y.getWeights();
569 EXPECT_EQ(weights.size(), 5);
570 EXPECT_FLOAT_EQ(weights[0], 1.2);
571 EXPECT_FLOAT_EQ(weights[1], 2.2);
572 EXPECT_FLOAT_EQ(weights[2], 3.2);
573 EXPECT_FLOAT_EQ(weights[3], 4.2);
574 EXPECT_FLOAT_EQ(weights[4], 5.2);
575
576 // Check TChain expansion
577 general_options.m_datafiles = {"datafile*.root"};
578 {
579 MVA::ROOTDataset chain_test(general_options);
580 EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
581 }
582 std::filesystem::copy_file("datafile.root", "datafile2.root");
583 {
584 MVA::ROOTDataset chain_test(general_options);
585 EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
586 }
587 std::filesystem::copy_file("datafile.root", "datafile3.root");
588 {
589 MVA::ROOTDataset chain_test(general_options);
590 EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
591 }
592 // Test m_max_events feature
593 {
594 general_options.m_max_events = 10;
595 MVA::ROOTDataset chain_test(general_options);
596 EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
597 general_options.m_max_events = 0;
598 }
599
600 // Check for missing tree
601 general_options.m_treename = "missing tree";
602 try {
603 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
604 } catch (...) {
605
606 }
607 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
608
609 // Check for missing branch
610 general_options.m_treename = "tree";
611 general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
612 try {
613 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
614 } catch (...) {
615
616 }
617 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
618
619 // Check for missing branch
620 general_options.m_treename = "tree";
621 general_options.m_variables = {"a", "b", "e", "f"};
622 general_options.m_spectators = {"missing branch"};
623 try {
624 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
625 } catch (...) {
626
627 }
628 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
629
630 // Check for missing file
631 general_options.m_spectators = {};
632 general_options.m_datafiles = {"DOESNOTEXIST.root"};
633 general_options.m_treename = "tree";
634 try {
635 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
636 } catch (...) {
637
638 }
639 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
640
641 // Check for invalid file
642 general_options.m_datafiles = {"ISNotAValidROOTFile"};
643 general_options.m_treename = "tree";
644
645 {
646 std::ofstream(general_options.m_datafiles[0]);
647 }
648 EXPECT_TRUE(std::filesystem::exists(general_options.m_datafiles[0]));
649
650 try {
651 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
652 } catch (...) {
653
654 }
655 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
656 }
657
658
659 TEST(DatasetTest, ROOTDatasetMixedTypes)
660 {
661
663 TFile file("datafile.root", "RECREATE");
664 file.cd();
665 TTree tree("tree", "TreeTitle");
666 double a, c, d, g, v, w = 0;
667 float b = 0.0f;
668 int e = 0;
669 bool f = false;
670 tree.Branch("a", &a, "a/D");
671 tree.Branch("b", &b, "b/F");
672 tree.Branch("c", &c, "c/D");
673 tree.Branch("d", &d, "d/D");
674 tree.Branch("e__bo__bc", &e, "e__bo__bc/I");
675 tree.Branch("f__bo__bc", &f, "f__bo__bc/O");
676 tree.Branch("g", &g, "g/D");
677 tree.Branch("__weight__", &c, "__weight__/D");
678 tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
679 tree.Branch("w", &w, "w/D");
680
681 for (unsigned int i = 0; i < 5; ++i) {
682 a = i + 1.0;
683 b = i + 1.1;
684 c = i + 1.2;
685 d = i + 1.3;
686 e = i + 1;
687 f = i % 2 == 0;
688 g = float(i % 2 == 0);
689 w = i + 1.6;
690 v = i + 1.7;
691 tree.Fill();
692 }
693
694 file.Write("tree");
695
696 MVA::GeneralOptions general_options;
697 // Both names with and without makeROOTCompatible should work
698 general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
699 general_options.m_spectators = {"w", "v()"};
700 general_options.m_signal_class = 1;
701 general_options.m_datafiles = {"datafile.root"};
702 general_options.m_treename = "tree";
703 general_options.m_target_variable = "g";
704 general_options.m_weight_variable = "c";
705 MVA::ROOTDataset x(general_options);
706
707 EXPECT_EQ(x.getNumberOfFeatures(), 4);
708 EXPECT_EQ(x.getNumberOfSpectators(), 2);
709 EXPECT_EQ(x.getNumberOfEvents(), 5);
710
711 // Should just work
712 x.loadEvent(0);
713 EXPECT_EQ(x.m_input.size(), 4);
714 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
715 EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
716 EXPECT_EQ(x.m_input[2], 1);
717 EXPECT_EQ(x.m_input[3], true);
718 EXPECT_EQ(x.m_spectators.size(), 2);
719 EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
720 EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
721 EXPECT_FLOAT_EQ(x.m_weight, 1.2);
722 EXPECT_FLOAT_EQ(x.m_target, 1.0);
723 EXPECT_EQ(x.m_isSignal, true);
724
725 x.loadEvent(1);
726 EXPECT_EQ(x.m_input.size(), 4);
727 EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
728 EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
729 EXPECT_EQ(x.m_input[2], 2);
730 EXPECT_EQ(x.m_input[3], false);
731 EXPECT_EQ(x.m_spectators.size(), 2);
732 EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
733 EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
734 EXPECT_FLOAT_EQ(x.m_weight, 2.2);
735 EXPECT_FLOAT_EQ(x.m_target, 0.0);
736 EXPECT_EQ(x.m_isSignal, false);
737
738 x.loadEvent(2);
739 EXPECT_EQ(x.m_input.size(), 4);
740 EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
741 EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
742 EXPECT_EQ(x.m_input[2], 3);
743 EXPECT_EQ(x.m_input[3], true);
744 EXPECT_EQ(x.m_spectators.size(), 2);
745 EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
746 EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
747 EXPECT_FLOAT_EQ(x.m_weight, 3.2);
748 EXPECT_FLOAT_EQ(x.m_target, 1.0);
749 EXPECT_EQ(x.m_isSignal, true);
750
751 x.loadEvent(3);
752 EXPECT_EQ(x.m_input.size(), 4);
753 EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
754 EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
755 EXPECT_EQ(x.m_input[2], 4);
756 EXPECT_EQ(x.m_input[3], false);
757 EXPECT_EQ(x.m_spectators.size(), 2);
758 EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
759 EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
760 EXPECT_FLOAT_EQ(x.m_weight, 4.2);
761 EXPECT_FLOAT_EQ(x.m_target, 0.0);
762 EXPECT_EQ(x.m_isSignal, false);
763
764 x.loadEvent(4);
765 EXPECT_EQ(x.m_input.size(), 4);
766 EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
767 EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
768 EXPECT_EQ(x.m_input[2], 5);
769 EXPECT_EQ(x.m_input[3], true);
770 EXPECT_EQ(x.m_spectators.size(), 2);
771 EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
772 EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
773 EXPECT_FLOAT_EQ(x.m_weight, 5.2);
774 EXPECT_FLOAT_EQ(x.m_target, 1.0);
775 EXPECT_EQ(x.m_isSignal, true);
776
777 EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
778
779 auto feature = x.getFeature(1);
780 EXPECT_EQ(feature.size(), 5);
781 EXPECT_FLOAT_EQ(feature[0], 1.1);
782 EXPECT_FLOAT_EQ(feature[1], 2.1);
783 EXPECT_FLOAT_EQ(feature[2], 3.1);
784 EXPECT_FLOAT_EQ(feature[3], 4.1);
785 EXPECT_FLOAT_EQ(feature[4], 5.1);
786
787 // Same result for mother class implementation
788 feature = x.Dataset::getFeature(1);
789 EXPECT_EQ(feature.size(), 5);
790 EXPECT_FLOAT_EQ(feature[0], 1.1);
791 EXPECT_FLOAT_EQ(feature[1], 2.1);
792 EXPECT_FLOAT_EQ(feature[2], 3.1);
793 EXPECT_FLOAT_EQ(feature[3], 4.1);
794 EXPECT_FLOAT_EQ(feature[4], 5.1);
795
796 auto spectator = x.getSpectator(1);
797 EXPECT_EQ(spectator.size(), 5);
798 EXPECT_FLOAT_EQ(spectator[0], 1.7);
799 EXPECT_FLOAT_EQ(spectator[1], 2.7);
800 EXPECT_FLOAT_EQ(spectator[2], 3.7);
801 EXPECT_FLOAT_EQ(spectator[3], 4.7);
802 EXPECT_FLOAT_EQ(spectator[4], 5.7);
803
804 // Same result for mother class implementation
805 spectator = x.Dataset::getSpectator(1);
806 EXPECT_EQ(spectator.size(), 5);
807 EXPECT_FLOAT_EQ(spectator[0], 1.7);
808 EXPECT_FLOAT_EQ(spectator[1], 2.7);
809 EXPECT_FLOAT_EQ(spectator[2], 3.7);
810 EXPECT_FLOAT_EQ(spectator[3], 4.7);
811 EXPECT_FLOAT_EQ(spectator[4], 5.7);
812
813 auto weights = x.getWeights();
814 EXPECT_EQ(weights.size(), 5);
815 EXPECT_FLOAT_EQ(weights[0], 1.2);
816 EXPECT_FLOAT_EQ(weights[1], 2.2);
817 EXPECT_FLOAT_EQ(weights[2], 3.2);
818 EXPECT_FLOAT_EQ(weights[3], 4.2);
819 EXPECT_FLOAT_EQ(weights[4], 5.2);
820
821 // Same result for mother class implementation
822 weights = x.Dataset::getWeights();
823 EXPECT_EQ(weights.size(), 5);
824 EXPECT_FLOAT_EQ(weights[0], 1.2);
825 EXPECT_FLOAT_EQ(weights[1], 2.2);
826 EXPECT_FLOAT_EQ(weights[2], 3.2);
827 EXPECT_FLOAT_EQ(weights[3], 4.2);
828 EXPECT_FLOAT_EQ(weights[4], 5.2);
829
830 auto targets = x.getTargets();
831 EXPECT_EQ(targets.size(), 5);
832 EXPECT_FLOAT_EQ(targets[0], 1.0);
833 EXPECT_FLOAT_EQ(targets[1], 0.0);
834 EXPECT_FLOAT_EQ(targets[2], 1.0);
835 EXPECT_FLOAT_EQ(targets[3], 0.0);
836 EXPECT_FLOAT_EQ(targets[4], 1.0);
837
838 auto signals = x.getSignals();
839 EXPECT_EQ(signals.size(), 5);
840 EXPECT_EQ(signals[0], true);
841 EXPECT_EQ(signals[1], false);
842 EXPECT_EQ(signals[2], true);
843 EXPECT_EQ(signals[3], false);
844 EXPECT_EQ(signals[4], true);
845
846 // Using __weight__ should work as well,
847 // the only difference to using _weight__ instead of g is
848 // in setBranchAddresses which avoids calling makeROOTCompatible
849 // So we have to check the behaviour using __weight__ as well
850 general_options.m_weight_variable = "__weight__";
851 MVA::ROOTDataset y(general_options);
852
853 weights = y.getWeights();
854 EXPECT_EQ(weights.size(), 5);
855 EXPECT_FLOAT_EQ(weights[0], 1.2);
856 EXPECT_FLOAT_EQ(weights[1], 2.2);
857 EXPECT_FLOAT_EQ(weights[2], 3.2);
858 EXPECT_FLOAT_EQ(weights[3], 4.2);
859 EXPECT_FLOAT_EQ(weights[4], 5.2);
860
861 // Check TChain expansion
862 general_options.m_datafiles = {"datafile*.root"};
863 {
864 MVA::ROOTDataset chain_test(general_options);
865 EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
866 }
867 std::filesystem::copy_file("datafile.root", "datafile2.root");
868 {
869 MVA::ROOTDataset chain_test(general_options);
870 EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
871 }
872 std::filesystem::copy_file("datafile.root", "datafile3.root");
873 {
874 MVA::ROOTDataset chain_test(general_options);
875 EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
876 }
877 // Test m_max_events feature
878 {
879 general_options.m_max_events = 10;
880 MVA::ROOTDataset chain_test(general_options);
881 EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
882 general_options.m_max_events = 0;
883 }
884
885 // Check for missing tree
886 general_options.m_treename = "missing tree";
887 try {
888 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
889 } catch (...) {
890
891 }
892 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
893
894 // Check for missing branch
895 general_options.m_treename = "tree";
896 general_options.m_variables = {"a", "b", "e", "f", "missing branch"};
897 try {
898 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
899 } catch (...) {
900
901 }
902 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
903
904 // Check for missing branch
905 general_options.m_treename = "tree";
906 general_options.m_variables = {"a", "b", "e", "f"};
907 general_options.m_spectators = {"missing branch"};
908 try {
909 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
910 } catch (...) {
911
912 }
913 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
914
915 // Check for missing file
916 general_options.m_spectators = {};
917 general_options.m_datafiles = {"DOESNOTEXIST.root"};
918 general_options.m_treename = "tree";
919 try {
920 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
921 } catch (...) {
922
923 }
924 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
925
926 // Check for invalid file
927 general_options.m_datafiles = {"ISNotAValidROOTFile"};
928 general_options.m_treename = "tree";
929
930 {
931 std::ofstream(general_options.m_datafiles[0]);
932 }
933 EXPECT_TRUE(std::filesystem::exists(general_options.m_datafiles[0]));
934
935 try {
936 EXPECT_B2ERROR(MVA::ROOTDataset{general_options});
937 } catch (...) {
938
939 }
940 EXPECT_THROW(MVA::ROOTDataset{general_options}, std::runtime_error);
941 }
942
943 TEST(DatasetTest, ROOTMultiDataset)
944 {
945
947 TFile file("datafile.root", "RECREATE");
948 file.cd();
949 TTree tree("tree", "TreeTitle");
950 double a, b, c, d, e, f, v, w = 0;
951 bool target;
952 tree.Branch("a", &a);
953 tree.Branch("b", &b);
954 tree.Branch("c", &c);
955 tree.Branch("d", &d);
956 tree.Branch("e__bo__bc", &e);
957 tree.Branch("f__bo__bc", &f);
958 tree.Branch("g", &target);
959 tree.Branch("__weight__", &c);
960 tree.Branch("v__bo__bc", &v);
961 tree.Branch("w", &w);
962
963 for (unsigned int i = 0; i < 5; ++i) {
964 a = i + 1.0;
965 b = i + 1.1;
966 c = i + 1.2;
967 d = i + 1.3;
968 e = i + 1.4;
969 f = i + 1.5;
970 target = (i % 2 == 0);
971 w = i + 1.6;
972 v = i + 1.7;
973 tree.Fill();
974 }
975
976 file.Write("tree");
977
978 TFile file2("datafile2.root", "RECREATE");
979 file2.cd();
980 TTree tree2("tree", "TreeTitle");
981 tree2.Branch("a", &a);
982 tree2.Branch("b", &b);
983 tree2.Branch("c", &c);
984 tree2.Branch("d", &d);
985 tree2.Branch("e__bo__bc", &e);
986 tree2.Branch("f__bo__bc", &f);
987 tree2.Branch("g", &target);
988 tree2.Branch("__weight__", &c);
989 tree2.Branch("v__bo__bc", &v);
990 tree2.Branch("w", &w);
991
992 for (unsigned int i = 0; i < 5; ++i) {
993 a = i + 1.0;
994 b = i + 1.1;
995 c = i + 1.2;
996 d = i + 1.3;
997 e = i + 1.4;
998 f = i + 1.5;
999 target = (i % 2 == 0);
1000 w = i + 1.6;
1001 v = i + 1.7;
1002 tree2.Fill();
1003 }
1004
1005 file2.Write("tree");
1006
1007 MVA::GeneralOptions general_options;
1008 // Both names with and without makeROOTCompatible should work
1009 general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1010 general_options.m_spectators = {"w", "v()"};
1011 general_options.m_signal_class = 1;
1012 general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1013 general_options.m_treename = "tree";
1014 general_options.m_target_variable = "g";
1015 general_options.m_weight_variable = "c";
1016 MVA::ROOTDataset x(general_options);
1017
1018 EXPECT_EQ(x.getNumberOfFeatures(), 4);
1019 EXPECT_EQ(x.getNumberOfSpectators(), 2);
1020 EXPECT_EQ(x.getNumberOfEvents(), 10);
1021
1022 // Should just work
1023 x.loadEvent(0);
1024 EXPECT_EQ(x.m_input.size(), 4);
1025 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1026 EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1027 EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1028 EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1029 EXPECT_EQ(x.m_spectators.size(), 2);
1030 EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1031 EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1032 EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1033 EXPECT_EQ(x.m_target, true);
1034 EXPECT_EQ(x.m_isSignal, true);
1035
1036 x.loadEvent(5);
1037 EXPECT_EQ(x.m_input.size(), 4);
1038 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1039 EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1040 EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1041 EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1042 EXPECT_EQ(x.m_spectators.size(), 2);
1043 EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1044 EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1045 EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1046 EXPECT_EQ(x.m_target, true);
1047 EXPECT_EQ(x.m_isSignal, true);
1048
1049 x.loadEvent(1);
1050 EXPECT_EQ(x.m_input.size(), 4);
1051 EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1052 EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1053 EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1054 EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1055 EXPECT_EQ(x.m_spectators.size(), 2);
1056 EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1057 EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1058 EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1059 EXPECT_EQ(x.m_target, false);
1060 EXPECT_EQ(x.m_isSignal, false);
1061
1062 x.loadEvent(6);
1063 EXPECT_EQ(x.m_input.size(), 4);
1064 EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1065 EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1066 EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1067 EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1068 EXPECT_EQ(x.m_spectators.size(), 2);
1069 EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1070 EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1071 EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1072 EXPECT_EQ(x.m_target, false);
1073 EXPECT_EQ(x.m_isSignal, false);
1074
1075 x.loadEvent(2);
1076 EXPECT_EQ(x.m_input.size(), 4);
1077 EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1078 EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1079 EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1080 EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1081 EXPECT_EQ(x.m_spectators.size(), 2);
1082 EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1083 EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1084 EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1085 EXPECT_EQ(x.m_target, true);
1086 EXPECT_EQ(x.m_isSignal, true);
1087
1088 x.loadEvent(7);
1089 EXPECT_EQ(x.m_input.size(), 4);
1090 EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1091 EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1092 EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1093 EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1094 EXPECT_EQ(x.m_spectators.size(), 2);
1095 EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1096 EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1097 EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1098 EXPECT_EQ(x.m_target, true);
1099 EXPECT_EQ(x.m_isSignal, true);
1100
1101 x.loadEvent(3);
1102 EXPECT_EQ(x.m_input.size(), 4);
1103 EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1104 EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1105 EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1106 EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1107 EXPECT_EQ(x.m_spectators.size(), 2);
1108 EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1109 EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1110 EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1111 EXPECT_EQ(x.m_target, false);
1112 EXPECT_EQ(x.m_isSignal, false);
1113
1114 x.loadEvent(8);
1115 EXPECT_EQ(x.m_input.size(), 4);
1116 EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1117 EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1118 EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1119 EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1120 EXPECT_EQ(x.m_spectators.size(), 2);
1121 EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1122 EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1123 EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1124 EXPECT_EQ(x.m_target, false);
1125 EXPECT_EQ(x.m_isSignal, false);
1126
1127 x.loadEvent(4);
1128 EXPECT_EQ(x.m_input.size(), 4);
1129 EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1130 EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1131 EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1132 EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1133 EXPECT_EQ(x.m_spectators.size(), 2);
1134 EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1135 EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1136 EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1137 EXPECT_EQ(x.m_target, true);
1138 EXPECT_EQ(x.m_isSignal, true);
1139
1140 x.loadEvent(9);
1141 EXPECT_EQ(x.m_input.size(), 4);
1142 EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1143 EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1144 EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1145 EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1146 EXPECT_EQ(x.m_spectators.size(), 2);
1147 EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1148 EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1149 EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1150 EXPECT_EQ(x.m_target, true);
1151 EXPECT_EQ(x.m_isSignal, true);
1152
1153 EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1154
1155 auto feature = x.getFeature(1);
1156 EXPECT_EQ(feature.size(), 10);
1157 EXPECT_FLOAT_EQ(feature[0], 1.1);
1158 EXPECT_FLOAT_EQ(feature[1], 2.1);
1159 EXPECT_FLOAT_EQ(feature[2], 3.1);
1160 EXPECT_FLOAT_EQ(feature[3], 4.1);
1161 EXPECT_FLOAT_EQ(feature[4], 5.1);
1162 EXPECT_FLOAT_EQ(feature[5], 1.1);
1163 EXPECT_FLOAT_EQ(feature[6], 2.1);
1164 EXPECT_FLOAT_EQ(feature[7], 3.1);
1165 EXPECT_FLOAT_EQ(feature[8], 4.1);
1166 EXPECT_FLOAT_EQ(feature[9], 5.1);
1167
1168 // Same result for mother class implementation
1169 feature = x.Dataset::getFeature(1);
1170 EXPECT_EQ(feature.size(), 10);
1171 EXPECT_FLOAT_EQ(feature[0], 1.1);
1172 EXPECT_FLOAT_EQ(feature[1], 2.1);
1173 EXPECT_FLOAT_EQ(feature[2], 3.1);
1174 EXPECT_FLOAT_EQ(feature[3], 4.1);
1175 EXPECT_FLOAT_EQ(feature[4], 5.1);
1176 EXPECT_FLOAT_EQ(feature[5], 1.1);
1177 EXPECT_FLOAT_EQ(feature[6], 2.1);
1178 EXPECT_FLOAT_EQ(feature[7], 3.1);
1179 EXPECT_FLOAT_EQ(feature[8], 4.1);
1180 EXPECT_FLOAT_EQ(feature[9], 5.1);
1181
1182 auto spectator = x.getSpectator(1);
1183 EXPECT_EQ(spectator.size(), 10);
1184 EXPECT_FLOAT_EQ(spectator[0], 1.7);
1185 EXPECT_FLOAT_EQ(spectator[1], 2.7);
1186 EXPECT_FLOAT_EQ(spectator[2], 3.7);
1187 EXPECT_FLOAT_EQ(spectator[3], 4.7);
1188 EXPECT_FLOAT_EQ(spectator[4], 5.7);
1189 EXPECT_FLOAT_EQ(spectator[5], 1.7);
1190 EXPECT_FLOAT_EQ(spectator[6], 2.7);
1191 EXPECT_FLOAT_EQ(spectator[7], 3.7);
1192 EXPECT_FLOAT_EQ(spectator[8], 4.7);
1193 EXPECT_FLOAT_EQ(spectator[9], 5.7);
1194
1195 // Same result for mother class implementation
1196 spectator = x.Dataset::getSpectator(1);
1197 EXPECT_EQ(spectator.size(), 10);
1198 EXPECT_FLOAT_EQ(spectator[0], 1.7);
1199 EXPECT_FLOAT_EQ(spectator[1], 2.7);
1200 EXPECT_FLOAT_EQ(spectator[2], 3.7);
1201 EXPECT_FLOAT_EQ(spectator[3], 4.7);
1202 EXPECT_FLOAT_EQ(spectator[4], 5.7);
1203 EXPECT_FLOAT_EQ(spectator[5], 1.7);
1204 EXPECT_FLOAT_EQ(spectator[6], 2.7);
1205 EXPECT_FLOAT_EQ(spectator[7], 3.7);
1206 EXPECT_FLOAT_EQ(spectator[8], 4.7);
1207 EXPECT_FLOAT_EQ(spectator[9], 5.7);
1208
1209 auto weights = x.getWeights();
1210 EXPECT_EQ(weights.size(), 10);
1211 EXPECT_FLOAT_EQ(weights[0], 1.2);
1212 EXPECT_FLOAT_EQ(weights[1], 2.2);
1213 EXPECT_FLOAT_EQ(weights[2], 3.2);
1214 EXPECT_FLOAT_EQ(weights[3], 4.2);
1215 EXPECT_FLOAT_EQ(weights[4], 5.2);
1216 EXPECT_FLOAT_EQ(weights[5], 1.2);
1217 EXPECT_FLOAT_EQ(weights[6], 2.2);
1218 EXPECT_FLOAT_EQ(weights[7], 3.2);
1219 EXPECT_FLOAT_EQ(weights[8], 4.2);
1220 EXPECT_FLOAT_EQ(weights[9], 5.2);
1221
1222 // Same result for mother class implementation
1223 weights = x.Dataset::getWeights();
1224 EXPECT_EQ(weights.size(), 10);
1225 EXPECT_FLOAT_EQ(weights[0], 1.2);
1226 EXPECT_FLOAT_EQ(weights[1], 2.2);
1227 EXPECT_FLOAT_EQ(weights[2], 3.2);
1228 EXPECT_FLOAT_EQ(weights[3], 4.2);
1229 EXPECT_FLOAT_EQ(weights[4], 5.2);
1230 EXPECT_FLOAT_EQ(weights[5], 1.2);
1231 EXPECT_FLOAT_EQ(weights[6], 2.2);
1232 EXPECT_FLOAT_EQ(weights[7], 3.2);
1233 EXPECT_FLOAT_EQ(weights[8], 4.2);
1234 EXPECT_FLOAT_EQ(weights[9], 5.2);
1235
1236 auto targets = x.getTargets();
1237 EXPECT_EQ(targets.size(), 10);
1238 EXPECT_EQ(targets[0], true);
1239 EXPECT_EQ(targets[1], false);
1240 EXPECT_EQ(targets[2], true);
1241 EXPECT_EQ(targets[3], false);
1242 EXPECT_EQ(targets[4], true);
1243 EXPECT_EQ(targets[5], true);
1244 EXPECT_EQ(targets[6], false);
1245 EXPECT_EQ(targets[7], true);
1246 EXPECT_EQ(targets[8], false);
1247 EXPECT_EQ(targets[9], true);
1248
1249 auto signals = x.getSignals();
1250 EXPECT_EQ(signals.size(), 10);
1251 EXPECT_EQ(signals[0], true);
1252 EXPECT_EQ(signals[1], false);
1253 EXPECT_EQ(signals[2], true);
1254 EXPECT_EQ(signals[3], false);
1255 EXPECT_EQ(signals[4], true);
1256 EXPECT_EQ(signals[5], true);
1257 EXPECT_EQ(signals[6], false);
1258 EXPECT_EQ(signals[7], true);
1259 EXPECT_EQ(signals[8], false);
1260 EXPECT_EQ(signals[9], true);
1261
1262 // Using __weight__ should work as well,
1263 // the only difference to using _weight__ instead of c is
1264 // in setBranchAddresses which avoids calling makeROOTCompatible
1265 // So we have to check the behaviour using __weight__ as well
1266 general_options.m_weight_variable = "__weight__";
1267 MVA::ROOTDataset y(general_options);
1268
1269 weights = y.getWeights();
1270 EXPECT_EQ(weights.size(), 10);
1271 EXPECT_FLOAT_EQ(weights[0], 1.2);
1272 EXPECT_FLOAT_EQ(weights[1], 2.2);
1273 EXPECT_FLOAT_EQ(weights[2], 3.2);
1274 EXPECT_FLOAT_EQ(weights[3], 4.2);
1275 EXPECT_FLOAT_EQ(weights[4], 5.2);
1276 EXPECT_FLOAT_EQ(weights[5], 1.2);
1277 EXPECT_FLOAT_EQ(weights[6], 2.2);
1278 EXPECT_FLOAT_EQ(weights[7], 3.2);
1279 EXPECT_FLOAT_EQ(weights[8], 4.2);
1280 EXPECT_FLOAT_EQ(weights[9], 5.2);
1281
1282 // Check TChain expansion
1283 general_options.m_datafiles = {"datafile*.root"};
1284 {
1285 MVA::ROOTDataset chain_test(general_options);
1286 EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1287 }
1288 std::filesystem::copy_file("datafile.root", "datafile3.root");
1289 {
1290 MVA::ROOTDataset chain_test(general_options);
1291 EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1292 }
1293 std::filesystem::copy_file("datafile.root", "datafile4.root");
1294 {
1295 MVA::ROOTDataset chain_test(general_options);
1296 EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1297 }
1298 // Test m_max_events feature
1299 {
1300 general_options.m_max_events = 10;
1301 MVA::ROOTDataset chain_test(general_options);
1302 EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1303 general_options.m_max_events = 0;
1304 }
1305
1306 // If a file exists with the specified expansion
1307 // the file takes precedence over the expansion
1308 std::filesystem::copy_file("datafile.root", "datafile*.root");
1309 {
1310 general_options.m_max_events = 0;
1311 MVA::ROOTDataset chain_test(general_options);
1312 EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1313 }
1314
1315 }
1316 TEST(DatasetTest, ROOTMultiDatasetDouble)
1317 {
1318
1320 TFile file("datafile.root", "RECREATE");
1321 file.cd();
1322 TTree tree("tree", "TreeTitle");
1323 double a, b, c, d, e, f, g, v, w = 0;
1324 tree.Branch("a", &a, "a/D");
1325 tree.Branch("b", &b, "b/D");
1326 tree.Branch("c", &c, "c/D");
1327 tree.Branch("d", &d, "d/D");
1328 tree.Branch("e__bo__bc", &e, "e__bo__bc/D");
1329 tree.Branch("f__bo__bc", &f, "f__bo__bc/D");
1330 tree.Branch("g", &g, "g/D");
1331 tree.Branch("__weight__", &c, "__weight__/D");
1332 tree.Branch("v__bo__bc", &v, "v__bo__bc/D");
1333 tree.Branch("w", &w, "w/D");
1334
1335 for (unsigned int i = 0; i < 5; ++i) {
1336 a = i + 1.0;
1337 b = i + 1.1;
1338 c = i + 1.2;
1339 d = i + 1.3;
1340 e = i + 1.4;
1341 f = i + 1.5;
1342 g = float(i % 2 == 0);
1343 w = i + 1.6;
1344 v = i + 1.7;
1345 tree.Fill();
1346 }
1347
1348 file.Write("tree");
1349
1350 TFile file2("datafile2.root", "RECREATE");
1351 file2.cd();
1352 TTree tree2("tree", "TreeTitle");
1353 tree2.Branch("a", &a);
1354 tree2.Branch("b", &b);
1355 tree2.Branch("c", &c);
1356 tree2.Branch("d", &d);
1357 tree2.Branch("e__bo__bc", &e);
1358 tree2.Branch("f__bo__bc", &f);
1359 tree2.Branch("g", &g);
1360 tree2.Branch("__weight__", &c);
1361 tree2.Branch("v__bo__bc", &v);
1362 tree2.Branch("w", &w);
1363
1364 for (unsigned int i = 0; i < 5; ++i) {
1365 a = i + 1.0;
1366 b = i + 1.1;
1367 c = i + 1.2;
1368 d = i + 1.3;
1369 e = i + 1.4;
1370 f = i + 1.5;
1371 g = float(i % 2 == 0);
1372 w = i + 1.6;
1373 v = i + 1.7;
1374 tree2.Fill();
1375 }
1376
1377 file2.Write("tree");
1378
1379 MVA::GeneralOptions general_options;
1380 // Both names with and without makeROOTCompatible should work
1381 general_options.m_variables = {"a", "b", "e__bo__bc", "f()"};
1382 general_options.m_spectators = {"w", "v()"};
1383 general_options.m_signal_class = 1;
1384 general_options.m_datafiles = {"datafile.root", "datafile2.root"};
1385 general_options.m_treename = "tree";
1386 general_options.m_target_variable = "g";
1387 general_options.m_weight_variable = "c";
1388 MVA::ROOTDataset x(general_options);
1389
1390 EXPECT_EQ(x.getNumberOfFeatures(), 4);
1391 EXPECT_EQ(x.getNumberOfSpectators(), 2);
1392 EXPECT_EQ(x.getNumberOfEvents(), 10);
1393
1394 // Should just work
1395 x.loadEvent(0);
1396 EXPECT_EQ(x.m_input.size(), 4);
1397 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1398 EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1399 EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1400 EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1401 EXPECT_EQ(x.m_spectators.size(), 2);
1402 EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1403 EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1404 EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1405 EXPECT_FLOAT_EQ(x.m_target, 1.0);
1406 EXPECT_EQ(x.m_isSignal, true);
1407
1408 x.loadEvent(5);
1409 EXPECT_EQ(x.m_input.size(), 4);
1410 EXPECT_FLOAT_EQ(x.m_input[0], 1.0);
1411 EXPECT_FLOAT_EQ(x.m_input[1], 1.1);
1412 EXPECT_FLOAT_EQ(x.m_input[2], 1.4);
1413 EXPECT_FLOAT_EQ(x.m_input[3], 1.5);
1414 EXPECT_EQ(x.m_spectators.size(), 2);
1415 EXPECT_FLOAT_EQ(x.m_spectators[0], 1.6);
1416 EXPECT_FLOAT_EQ(x.m_spectators[1], 1.7);
1417 EXPECT_FLOAT_EQ(x.m_weight, 1.2);
1418 EXPECT_FLOAT_EQ(x.m_target, 1.0);
1419 EXPECT_EQ(x.m_isSignal, true);
1420
1421 x.loadEvent(1);
1422 EXPECT_EQ(x.m_input.size(), 4);
1423 EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1424 EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1425 EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1426 EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1427 EXPECT_EQ(x.m_spectators.size(), 2);
1428 EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1429 EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1430 EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1431 EXPECT_FLOAT_EQ(x.m_target, 0.0);
1432 EXPECT_EQ(x.m_isSignal, false);
1433
1434 x.loadEvent(6);
1435 EXPECT_EQ(x.m_input.size(), 4);
1436 EXPECT_FLOAT_EQ(x.m_input[0], 2.0);
1437 EXPECT_FLOAT_EQ(x.m_input[1], 2.1);
1438 EXPECT_FLOAT_EQ(x.m_input[2], 2.4);
1439 EXPECT_FLOAT_EQ(x.m_input[3], 2.5);
1440 EXPECT_EQ(x.m_spectators.size(), 2);
1441 EXPECT_FLOAT_EQ(x.m_spectators[0], 2.6);
1442 EXPECT_FLOAT_EQ(x.m_spectators[1], 2.7);
1443 EXPECT_FLOAT_EQ(x.m_weight, 2.2);
1444 EXPECT_FLOAT_EQ(x.m_target, 0.0);
1445 EXPECT_EQ(x.m_isSignal, false);
1446
1447 x.loadEvent(2);
1448 EXPECT_EQ(x.m_input.size(), 4);
1449 EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1450 EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1451 EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1452 EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1453 EXPECT_EQ(x.m_spectators.size(), 2);
1454 EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1455 EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1456 EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1457 EXPECT_FLOAT_EQ(x.m_target, 1.0);
1458 EXPECT_EQ(x.m_isSignal, true);
1459
1460 x.loadEvent(7);
1461 EXPECT_EQ(x.m_input.size(), 4);
1462 EXPECT_FLOAT_EQ(x.m_input[0], 3.0);
1463 EXPECT_FLOAT_EQ(x.m_input[1], 3.1);
1464 EXPECT_FLOAT_EQ(x.m_input[2], 3.4);
1465 EXPECT_FLOAT_EQ(x.m_input[3], 3.5);
1466 EXPECT_EQ(x.m_spectators.size(), 2);
1467 EXPECT_FLOAT_EQ(x.m_spectators[0], 3.6);
1468 EXPECT_FLOAT_EQ(x.m_spectators[1], 3.7);
1469 EXPECT_FLOAT_EQ(x.m_weight, 3.2);
1470 EXPECT_FLOAT_EQ(x.m_target, 1.0);
1471 EXPECT_EQ(x.m_isSignal, true);
1472
1473 x.loadEvent(3);
1474 EXPECT_EQ(x.m_input.size(), 4);
1475 EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1476 EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1477 EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1478 EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1479 EXPECT_EQ(x.m_spectators.size(), 2);
1480 EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1481 EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1482 EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1483 EXPECT_FLOAT_EQ(x.m_target, 0.0);
1484 EXPECT_EQ(x.m_isSignal, false);
1485
1486 x.loadEvent(8);
1487 EXPECT_EQ(x.m_input.size(), 4);
1488 EXPECT_FLOAT_EQ(x.m_input[0], 4.0);
1489 EXPECT_FLOAT_EQ(x.m_input[1], 4.1);
1490 EXPECT_FLOAT_EQ(x.m_input[2], 4.4);
1491 EXPECT_FLOAT_EQ(x.m_input[3], 4.5);
1492 EXPECT_EQ(x.m_spectators.size(), 2);
1493 EXPECT_FLOAT_EQ(x.m_spectators[0], 4.6);
1494 EXPECT_FLOAT_EQ(x.m_spectators[1], 4.7);
1495 EXPECT_FLOAT_EQ(x.m_weight, 4.2);
1496 EXPECT_FLOAT_EQ(x.m_target, 0.0);
1497 EXPECT_EQ(x.m_isSignal, false);
1498
1499 x.loadEvent(4);
1500 EXPECT_EQ(x.m_input.size(), 4);
1501 EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1502 EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1503 EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1504 EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1505 EXPECT_EQ(x.m_spectators.size(), 2);
1506 EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1507 EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1508 EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1509 EXPECT_FLOAT_EQ(x.m_target, 1.0);
1510 EXPECT_EQ(x.m_isSignal, true);
1511
1512 x.loadEvent(9);
1513 EXPECT_EQ(x.m_input.size(), 4);
1514 EXPECT_FLOAT_EQ(x.m_input[0], 5.0);
1515 EXPECT_FLOAT_EQ(x.m_input[1], 5.1);
1516 EXPECT_FLOAT_EQ(x.m_input[2], 5.4);
1517 EXPECT_FLOAT_EQ(x.m_input[3], 5.5);
1518 EXPECT_EQ(x.m_spectators.size(), 2);
1519 EXPECT_FLOAT_EQ(x.m_spectators[0], 5.6);
1520 EXPECT_FLOAT_EQ(x.m_spectators[1], 5.7);
1521 EXPECT_FLOAT_EQ(x.m_weight, 5.2);
1522 EXPECT_FLOAT_EQ(x.m_target, 1.0);
1523 EXPECT_EQ(x.m_isSignal, true);
1524
1525 EXPECT_FLOAT_EQ(x.getSignalFraction(), 0.6);
1526
1527 auto feature = x.getFeature(1);
1528 EXPECT_EQ(feature.size(), 10);
1529 EXPECT_FLOAT_EQ(feature[0], 1.1);
1530 EXPECT_FLOAT_EQ(feature[1], 2.1);
1531 EXPECT_FLOAT_EQ(feature[2], 3.1);
1532 EXPECT_FLOAT_EQ(feature[3], 4.1);
1533 EXPECT_FLOAT_EQ(feature[4], 5.1);
1534 EXPECT_FLOAT_EQ(feature[5], 1.1);
1535 EXPECT_FLOAT_EQ(feature[6], 2.1);
1536 EXPECT_FLOAT_EQ(feature[7], 3.1);
1537 EXPECT_FLOAT_EQ(feature[8], 4.1);
1538 EXPECT_FLOAT_EQ(feature[9], 5.1);
1539
1540 // Same result for mother class implementation
1541 feature = x.Dataset::getFeature(1);
1542 EXPECT_EQ(feature.size(), 10);
1543 EXPECT_FLOAT_EQ(feature[0], 1.1);
1544 EXPECT_FLOAT_EQ(feature[1], 2.1);
1545 EXPECT_FLOAT_EQ(feature[2], 3.1);
1546 EXPECT_FLOAT_EQ(feature[3], 4.1);
1547 EXPECT_FLOAT_EQ(feature[4], 5.1);
1548 EXPECT_FLOAT_EQ(feature[5], 1.1);
1549 EXPECT_FLOAT_EQ(feature[6], 2.1);
1550 EXPECT_FLOAT_EQ(feature[7], 3.1);
1551 EXPECT_FLOAT_EQ(feature[8], 4.1);
1552 EXPECT_FLOAT_EQ(feature[9], 5.1);
1553
1554 auto spectator = x.getSpectator(1);
1555 EXPECT_EQ(spectator.size(), 10);
1556 EXPECT_FLOAT_EQ(spectator[0], 1.7);
1557 EXPECT_FLOAT_EQ(spectator[1], 2.7);
1558 EXPECT_FLOAT_EQ(spectator[2], 3.7);
1559 EXPECT_FLOAT_EQ(spectator[3], 4.7);
1560 EXPECT_FLOAT_EQ(spectator[4], 5.7);
1561 EXPECT_FLOAT_EQ(spectator[5], 1.7);
1562 EXPECT_FLOAT_EQ(spectator[6], 2.7);
1563 EXPECT_FLOAT_EQ(spectator[7], 3.7);
1564 EXPECT_FLOAT_EQ(spectator[8], 4.7);
1565 EXPECT_FLOAT_EQ(spectator[9], 5.7);
1566
1567 // Same result for mother class implementation
1568 spectator = x.Dataset::getSpectator(1);
1569 EXPECT_EQ(spectator.size(), 10);
1570 EXPECT_FLOAT_EQ(spectator[0], 1.7);
1571 EXPECT_FLOAT_EQ(spectator[1], 2.7);
1572 EXPECT_FLOAT_EQ(spectator[2], 3.7);
1573 EXPECT_FLOAT_EQ(spectator[3], 4.7);
1574 EXPECT_FLOAT_EQ(spectator[4], 5.7);
1575 EXPECT_FLOAT_EQ(spectator[5], 1.7);
1576 EXPECT_FLOAT_EQ(spectator[6], 2.7);
1577 EXPECT_FLOAT_EQ(spectator[7], 3.7);
1578 EXPECT_FLOAT_EQ(spectator[8], 4.7);
1579 EXPECT_FLOAT_EQ(spectator[9], 5.7);
1580
1581 auto weights = x.getWeights();
1582 EXPECT_EQ(weights.size(), 10);
1583 EXPECT_FLOAT_EQ(weights[0], 1.2);
1584 EXPECT_FLOAT_EQ(weights[1], 2.2);
1585 EXPECT_FLOAT_EQ(weights[2], 3.2);
1586 EXPECT_FLOAT_EQ(weights[3], 4.2);
1587 EXPECT_FLOAT_EQ(weights[4], 5.2);
1588 EXPECT_FLOAT_EQ(weights[5], 1.2);
1589 EXPECT_FLOAT_EQ(weights[6], 2.2);
1590 EXPECT_FLOAT_EQ(weights[7], 3.2);
1591 EXPECT_FLOAT_EQ(weights[8], 4.2);
1592 EXPECT_FLOAT_EQ(weights[9], 5.2);
1593
1594 // Same result for mother class implementation
1595 weights = x.Dataset::getWeights();
1596 EXPECT_EQ(weights.size(), 10);
1597 EXPECT_FLOAT_EQ(weights[0], 1.2);
1598 EXPECT_FLOAT_EQ(weights[1], 2.2);
1599 EXPECT_FLOAT_EQ(weights[2], 3.2);
1600 EXPECT_FLOAT_EQ(weights[3], 4.2);
1601 EXPECT_FLOAT_EQ(weights[4], 5.2);
1602 EXPECT_FLOAT_EQ(weights[5], 1.2);
1603 EXPECT_FLOAT_EQ(weights[6], 2.2);
1604 EXPECT_FLOAT_EQ(weights[7], 3.2);
1605 EXPECT_FLOAT_EQ(weights[8], 4.2);
1606 EXPECT_FLOAT_EQ(weights[9], 5.2);
1607
1608 auto targets = x.getTargets();
1609 EXPECT_EQ(targets.size(), 10);
1610 EXPECT_FLOAT_EQ(targets[0], 1.0);
1611 EXPECT_FLOAT_EQ(targets[1], 0.0);
1612 EXPECT_FLOAT_EQ(targets[2], 1.0);
1613 EXPECT_FLOAT_EQ(targets[3], 0.0);
1614 EXPECT_FLOAT_EQ(targets[4], 1.0);
1615 EXPECT_FLOAT_EQ(targets[5], 1.0);
1616 EXPECT_FLOAT_EQ(targets[6], 0.0);
1617 EXPECT_FLOAT_EQ(targets[7], 1.0);
1618 EXPECT_FLOAT_EQ(targets[8], 0.0);
1619 EXPECT_FLOAT_EQ(targets[9], 1.0);
1620
1621 auto signals = x.getSignals();
1622 EXPECT_EQ(signals.size(), 10);
1623 EXPECT_EQ(signals[0], true);
1624 EXPECT_EQ(signals[1], false);
1625 EXPECT_EQ(signals[2], true);
1626 EXPECT_EQ(signals[3], false);
1627 EXPECT_EQ(signals[4], true);
1628 EXPECT_EQ(signals[5], true);
1629 EXPECT_EQ(signals[6], false);
1630 EXPECT_EQ(signals[7], true);
1631 EXPECT_EQ(signals[8], false);
1632 EXPECT_EQ(signals[9], true);
1633
1634 // Using __weight__ should work as well,
1635 // the only difference to using _weight__ instead of g is
1636 // in setBranchAddresses which avoids calling makeROOTCompatible
1637 // So we have to check the behaviour using __weight__ as well
1638 general_options.m_weight_variable = "__weight__";
1639 MVA::ROOTDataset y(general_options);
1640
1641 weights = y.getWeights();
1642 EXPECT_EQ(weights.size(), 10);
1643 EXPECT_FLOAT_EQ(weights[0], 1.2);
1644 EXPECT_FLOAT_EQ(weights[1], 2.2);
1645 EXPECT_FLOAT_EQ(weights[2], 3.2);
1646 EXPECT_FLOAT_EQ(weights[3], 4.2);
1647 EXPECT_FLOAT_EQ(weights[4], 5.2);
1648 EXPECT_FLOAT_EQ(weights[5], 1.2);
1649 EXPECT_FLOAT_EQ(weights[6], 2.2);
1650 EXPECT_FLOAT_EQ(weights[7], 3.2);
1651 EXPECT_FLOAT_EQ(weights[8], 4.2);
1652 EXPECT_FLOAT_EQ(weights[9], 5.2);
1653
1654 // Check TChain expansion
1655 general_options.m_datafiles = {"datafile*.root"};
1656 {
1657 MVA::ROOTDataset chain_test(general_options);
1658 EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1659 }
1660 std::filesystem::copy_file("datafile.root", "datafile3.root");
1661 {
1662 MVA::ROOTDataset chain_test(general_options);
1663 EXPECT_EQ(chain_test.getNumberOfEvents(), 15);
1664 }
1665 std::filesystem::copy_file("datafile.root", "datafile4.root");
1666 {
1667 MVA::ROOTDataset chain_test(general_options);
1668 EXPECT_EQ(chain_test.getNumberOfEvents(), 20);
1669 }
1670 // Test m_max_events feature
1671 {
1672 general_options.m_max_events = 10;
1673 MVA::ROOTDataset chain_test(general_options);
1674 EXPECT_EQ(chain_test.getNumberOfEvents(), 10);
1675 general_options.m_max_events = 0;
1676 }
1677
1678 // If a file exists with the specified expansion
1679 // the file takes precedence over the expansion
1680 std::filesystem::copy_file("datafile.root", "datafile*.root");
1681 {
1682 general_options.m_max_events = 0;
1683 MVA::ROOTDataset chain_test(general_options);
1684 EXPECT_EQ(chain_test.getNumberOfEvents(), 5);
1685 }
1686 }
1687}
Wraps two other Datasets, one containing signal, the other background events. Used by the reweighting ...
Definition: Dataset.h:294
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual std::vector< float > getSpectator(unsigned int iSpectator)
Returns all values of one spectator in a std::vector<float>
Definition: Dataset.cc:86
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
General options which are shared by all MVA trainings.
Definition: Options.h:62
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
Definition: Dataset.h:349
Wraps the data of a single event into a Dataset.
Definition: Dataset.h:135
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:234
changes working directory into a newly created directory, and removes it (and contents) on destructio...
Definition: TestHelpers.h:66
Abstract base class for different kinds of events.