Belle II Software development
test_ONNX.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
#include <mva/interface/Interface.h>
#include <mva/methods/ONNX.h>
#include <framework/utilities/FileSystem.h>

#include <gtest/gtest.h>

#include <string>
#include <vector>
14
15using namespace Belle2::MVA;
16
17namespace {
22 TEST(ONNXTest, ONNXExpert)
23 {
25 auto expert = interface.getExpert();
26 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX.xml"));
27 expert->load(weightfile);
28 GeneralOptions general_options;
29 weightfile.getOptions(general_options);
30 MultiDataset dataset(general_options, {
31 {
32 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
33 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
34 },
35 {
36 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
37 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
38 }
39 },
40 {}, {0.0, 1.0});
41 auto probabilities = expert->apply(dataset);
42 EXPECT_NEAR(probabilities[0], 0.3328, 0.0001);
43 EXPECT_NEAR(probabilities[1], 0.5779, 0.0001);
44 }
45
46 TEST(ONNXTest, ONNXExpertMulticlass)
47 {
49 auto expert = interface.getExpert();
50 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX_multiclass.xml"));
51 expert->load(weightfile);
52 GeneralOptions general_options;
53 weightfile.getOptions(general_options);
54 MultiDataset dataset(general_options, {
55 {
56 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
57 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
58 },
59 {
60 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
61 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
62 }
63 },
64 {}, {0.0, 1.0});
65 auto probabilities = expert->applyMulticlass(dataset);
66 EXPECT_NEAR(probabilities[0][0], 0.3331, 0.0001);
67 EXPECT_NEAR(probabilities[0][1], -0.5373, 0.0001);
68 EXPECT_NEAR(probabilities[1][0], 0.5782, 0.0001);
69 EXPECT_NEAR(probabilities[1][1], -0.3697, 0.0001);
70 }
71
72 TEST(ONNXTest, ONNXExpertMulticlassThreeClasses)
73 {
75 auto expert = interface.getExpert();
76 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX_multiclass_3.xml"));
77 expert->load(weightfile);
78 GeneralOptions general_options;
79 weightfile.getOptions(general_options);
80 MultiDataset dataset(general_options, {
81 {
82 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
83 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
84 },
85 {
86 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
87 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
88 }
89 },
90 {}, {0.0, 1.0});
91 auto probabilities = expert->applyMulticlass(dataset);
92 EXPECT_NEAR(probabilities[0][0], -0.5394, 0.0001);
93 EXPECT_NEAR(probabilities[0][1], 0.0529, 0.0001);
94 EXPECT_NEAR(probabilities[0][2], -0.0598, 0.0001);
95 EXPECT_NEAR(probabilities[1][0], -0.5089, 0.0001);
96 EXPECT_NEAR(probabilities[1][1], 0.0060, 0.0001);
97 EXPECT_NEAR(probabilities[1][2], 0.0132, 0.0001);
98 }
99}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
General options which are shared by all MVA trainings.
Definition Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition Interface.h:125
Wraps the data of multiple events into a Dataset.
Definition Dataset.h:186
Expert for the ONNX MVA method.
Definition ONNX.h:149
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.