Belle II Software development
test_ONNX.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/interface/Interface.h>
10#include <mva/methods/ONNX.h>
11#include <framework/utilities/FileSystem.h>
12#include <framework/utilities/TestHelpers.h>
13
14#include <gtest/gtest.h>
15
16using namespace Belle2::MVA;
17
18namespace {
20
25 TEST(ONNXTest, ONNXExpert)
26 {
27 auto expert = interface.getExpert();
28 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX.xml"));
29 expert->load(weightfile);
30 GeneralOptions general_options;
31 weightfile.getOptions(general_options);
32 MultiDataset dataset(general_options, {
33 {
34 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
35 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
36 },
37 {
38 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
39 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
40 }
41 },
42 {}, {0.0, 1.0});
43 auto probabilities = expert->apply(dataset);
44 EXPECT_NEAR(probabilities[0], 0.3328, 0.0001);
45 EXPECT_NEAR(probabilities[1], 0.5779, 0.0001);
46 }
47
48 TEST(ONNXTest, ONNXExpertMulticlass)
49 {
50 auto expert = interface.getExpert();
51 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX_multiclass.xml"));
52 expert->load(weightfile);
53 GeneralOptions general_options;
54 weightfile.getOptions(general_options);
55 MultiDataset dataset(general_options, {
56 {
57 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
58 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
59 },
60 {
61 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
62 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
63 }
64 },
65 {}, {0.0, 1.0});
66 auto probabilities = expert->applyMulticlass(dataset);
67 EXPECT_NEAR(probabilities[0][0], 0.3331, 0.0001);
68 EXPECT_NEAR(probabilities[0][1], -0.5373, 0.0001);
69 EXPECT_NEAR(probabilities[1][0], 0.5782, 0.0001);
70 EXPECT_NEAR(probabilities[1][1], -0.3697, 0.0001);
71 }
72
73 TEST(ONNXTest, ONNXExpertMulticlassThreeClasses)
74 {
75 auto expert = interface.getExpert();
76 auto weightfile = Weightfile::loadFromFile(Belle2::FileSystem::findFile("mva/methods/tests/ONNX_multiclass_3.xml"));
77 expert->load(weightfile);
78 GeneralOptions general_options;
79 weightfile.getOptions(general_options);
80 MultiDataset dataset(general_options, {
81 {
82 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
83 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
84 },
85 {
86 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
87 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
88 }
89 },
90 {}, {0.0, 1.0});
91 auto probabilities = expert->applyMulticlass(dataset);
92 EXPECT_NEAR(probabilities[0][0], -0.5394, 0.0001);
93 EXPECT_NEAR(probabilities[0][1], 0.0529, 0.0001);
94 EXPECT_NEAR(probabilities[0][2], -0.0598, 0.0001);
95 EXPECT_NEAR(probabilities[1][0], -0.5089, 0.0001);
96 EXPECT_NEAR(probabilities[1][1], 0.0060, 0.0001);
97 EXPECT_NEAR(probabilities[1][2], 0.0132, 0.0001);
98 }
99
100 Weightfile getONNXWeightfile(std::string modelFilenameONNX, std::string outputName = "")
101 {
102 Weightfile weightfile;
103 GeneralOptions general_options;
104 ONNXOptions specific_options;
105 if (!outputName.empty()) {
106 specific_options.m_outputName = outputName;
107 }
108 general_options.m_method = specific_options.getMethod();
109 weightfile.addOptions(general_options);
110 weightfile.addOptions(specific_options);
111 weightfile.addFile("ONNX_Modelfile", modelFilenameONNX);
112 return weightfile;
113 }
114
115 TEST(ONNXTest, ONNXFatalMultipleInputs)
116 {
117 // The following modelfile has multiple inputs
118 // should fail when using with the ONNX MVA method
119 // (onnx file created with mva/examples/onnx/write_test_files.py)
120 auto weightfile = getONNXWeightfile(Belle2::FileSystem::findFile("mva/methods/tests/ModelABToAB.onnx"));
121 auto expert = interface.getExpert();
122 EXPECT_B2FATAL(expert->load(weightfile));
123 }
124
125 TEST(ONNXTest, ONNXFatalMultipleOutputs)
126 {
127 // The following modelfile has multiple outputs ("output_a", "output_twice_a"), so none of them named "output"
128 // should fail when using with the ONNX MVA method
129 // (onnx file created with mva/examples/onnx/write_test_files.py)
130 auto weightfile = getONNXWeightfile(Belle2::FileSystem::findFile("mva/methods/tests/ModelAToATwiceA.onnx"));
131 auto expert = interface.getExpert();
132 EXPECT_B2FATAL(expert->load(weightfile));
133 }
134
135 TEST(ONNXTest, ONNXMultipleOutputsOK)
136 {
137 // explicitly choosing to use "output_twice_a" should work
138 // (onnx file created with mva/examples/onnx/write_test_files.py)
139 auto weightfile = getONNXWeightfile(Belle2::FileSystem::findFile("mva/methods/tests/ModelAToATwiceA.onnx"), "output_twice_a");
140 auto expert = interface.getExpert();
141 expert->load(weightfile);
142 }
143}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
General options which are shared by all MVA trainings.
Definition Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition Interface.h:125
Wraps the data of multiple events into a Dataset.
Definition Dataset.h:186
Expert for the ONNX MVA method.
Definition ONNX.h:437
Options for the ONNX MVA method.
Definition ONNX.h:374
The Weightfile class serializes all information about a training into an xml tree.
Definition Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition Weightfile.cc:62
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.