Belle II Software light-2509-fornax
test_ONNX.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/interface/Interface.h>
10#include <mva/methods/ONNX.h>
11#include <framework/utilities/FileSystem.h>
12#include <framework/utilities/TestHelpers.h>
13
14#include <gtest/gtest.h>
15
16using namespace Belle2::MVA;
17
18namespace {
20
25 TEST(ONNXTest, ONNXExpert)
26 {
27 auto expert = interface.getExpert();
28 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX.xml"));
29 expert->load(weightfile);
30 GeneralOptions general_options;
31 weightfile.getOptions(general_options);
32 MultiDataset dataset(general_options, {
33 {
34 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
35 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
36 },
37 {
38 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
39 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
40 }
41 },
42 {}, {0.0, 1.0});
43 auto probabilities = expert->apply(dataset);
44 EXPECT_NEAR(probabilities[0], 0.3328, 0.0001);
45 EXPECT_NEAR(probabilities[1], 0.5779, 0.0001);
46 }
47
48 TEST(ONNXTest, ONNXExpertMulticlass)
49 {
50 auto expert = interface.getExpert();
51 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX_multiclass.xml"));
52 expert->load(weightfile);
53 GeneralOptions general_options;
54 weightfile.getOptions(general_options);
55 MultiDataset dataset(general_options, {
56 {
57 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
58 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
59 },
60 {
61 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
62 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
63 }
64 },
65 {}, {0.0, 1.0});
66 auto probabilities = expert->applyMulticlass(dataset);
67 EXPECT_NEAR(probabilities[0][0], 0.3331, 0.0001);
68 EXPECT_NEAR(probabilities[0][1], -0.5373, 0.0001);
69 EXPECT_NEAR(probabilities[1][0], 0.5782, 0.0001);
70 EXPECT_NEAR(probabilities[1][1], -0.3697, 0.0001);
71 }
72
73 TEST(ONNXTest, ONNXExpertTwoClassUseSingleOutput)
74 {
75 auto expert = interface.getExpert();
76 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX_multiclass.xml"));
77 expert->load(weightfile);
78 GeneralOptions general_options;
79 weightfile.getOptions(general_options);
80 MultiDataset dataset(general_options, {
81 {
82 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
83 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
84 },
85 {
86 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
87 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
88 }
89 },
90 {}, {0.0, 1.0});
91 // Running apply (not applyMulticlass) on this is supposed to assume a binary classifier and pick output index 1
92 auto probabilities = expert->apply(dataset);
93 EXPECT_NEAR(probabilities[0], -0.5373, 0.0001);
94 EXPECT_NEAR(probabilities[1], -0.3697, 0.0001);
95 }
96
97 TEST(ONNXTest, ONNXExpertMulticlassThreeClasses)
98 {
99 auto expert = interface.getExpert();
100 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX_multiclass_3.xml"));
101 expert->load(weightfile);
102 GeneralOptions general_options;
103 weightfile.getOptions(general_options);
104 MultiDataset dataset(general_options, {
105 {
106 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
107 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
108 },
109 {
110 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
111 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
112 }
113 },
114 {}, {0.0, 1.0});
115 auto probabilities = expert->applyMulticlass(dataset);
116 EXPECT_NEAR(probabilities[0][0], -0.5394, 0.0001);
117 EXPECT_NEAR(probabilities[0][1], 0.0529, 0.0001);
118 EXPECT_NEAR(probabilities[0][2], -0.0598, 0.0001);
119 EXPECT_NEAR(probabilities[1][0], -0.5089, 0.0001);
120 EXPECT_NEAR(probabilities[1][1], 0.0060, 0.0001);
121 EXPECT_NEAR(probabilities[1][2], 0.0132, 0.0001);
122 }
123
124 Weightfile getONNXWeightfile(std::string modelFilenameONNX, std::string outputName = "")
125 {
126 Weightfile weightfile;
127 GeneralOptions general_options;
128 ONNXOptions specific_options;
129 if (!outputName.empty()) {
130 specific_options.m_outputName = outputName;
131 }
132 general_options.m_method = specific_options.getMethod();
133 weightfile.addOptions(general_options);
134 weightfile.addOptions(specific_options);
135 weightfile.addFile("ONNX_Modelfile", modelFilenameONNX);
136 return weightfile;
137 }
138
139 TEST(ONNXTest, ONNXFatalMultipleInputs)
140 {
141 // The following modelfile has multiple inputs
142 // should fail when using with the ONNX MVA method
143 // (onnx file created with mva/examples/onnx/write_test_files.py)
144 auto weightfile = getONNXWeightfile(Belle2::FileSystem::findFile("mva/methods/tests/ModelABToAB.onnx"));
145 auto expert = interface.getExpert();
146 EXPECT_B2FATAL(expert->load(weightfile));
147 }
148
149 TEST(ONNXTest, ONNXFatalMultipleOutputs)
150 {
151 // The following modelfile has multiple outputs ("output_a", "output_twice_a"), so none of them named "output"
152 // should fail when using with the ONNX MVA method
153 // (onnx file created with mva/examples/onnx/write_test_files.py)
154 auto weightfile = getONNXWeightfile(Belle2::FileSystem::findFile("mva/methods/tests/ModelAToATwiceA.onnx"));
155 auto expert = interface.getExpert();
156 EXPECT_B2FATAL(expert->load(weightfile));
157 }
158
159 TEST(ONNXTest, ONNXMultipleOutputsOK)
160 {
161 // explicitly choosing to use "output_twice_a" should work
162 // (onnx file created with mva/examples/onnx/write_test_files.py)
163 auto weightfile = getONNXWeightfile(Belle2::FileSystem::findFile("mva/methods/tests/ModelAToATwiceA.onnx"), "output_twice_a");
164 auto expert = interface.getExpert();
165 expert->load(weightfile);
166 }
167}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
General options which are shared by all MVA trainings.
Definition Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition Interface.h:125
Wraps the data of multiple events into a Dataset.
Definition Dataset.h:186
Expert for the ONNX MVA method.
Definition ONNX.h:446
Options for the ONNX MVA method.
Definition ONNX.h:374
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.