9#include <mva/interface/Interface.h>
10#include <mva/methods/ONNX.h>
11#include <framework/utilities/FileSystem.h>
12#include <framework/utilities/TestHelpers.h>
14#include <gtest/gtest.h>
16using namespace Belle2::MVA;
29 expert->load(weightfile);
31 weightfile.getOptions(general_options);
34 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
35 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
38 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
39 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
43 auto probabilities = expert->apply(dataset);
44 EXPECT_NEAR(probabilities[0], 0.3328, 0.0001);
45 EXPECT_NEAR(probabilities[1], 0.5779, 0.0001);
48 TEST(ONNXTest, ONNXExpertMulticlass)
52 expert->load(weightfile);
54 weightfile.getOptions(general_options);
57 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
58 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
61 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
62 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
66 auto probabilities = expert->applyMulticlass(dataset);
67 EXPECT_NEAR(probabilities[0][0], 0.3331, 0.0001);
68 EXPECT_NEAR(probabilities[0][1], -0.5373, 0.0001);
69 EXPECT_NEAR(probabilities[1][0], 0.5782, 0.0001);
70 EXPECT_NEAR(probabilities[1][1], -0.3697, 0.0001);
73 TEST(ONNXTest, ONNXExpertTwoClassUseSingleOutput)
77 expert->load(weightfile);
79 weightfile.getOptions(general_options);
82 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
83 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
86 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
87 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
92 auto probabilities = expert->apply(dataset);
93 EXPECT_NEAR(probabilities[0], -0.5373, 0.0001);
94 EXPECT_NEAR(probabilities[1], -0.3697, 0.0001);
97 TEST(ONNXTest, ONNXExpertMulticlassThreeClasses)
101 expert->load(weightfile);
103 weightfile.getOptions(general_options);
106 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
107 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
110 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
111 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
115 auto probabilities = expert->applyMulticlass(dataset);
116 EXPECT_NEAR(probabilities[0][0], -0.5394, 0.0001);
117 EXPECT_NEAR(probabilities[0][1], 0.0529, 0.0001);
118 EXPECT_NEAR(probabilities[0][2], -0.0598, 0.0001);
119 EXPECT_NEAR(probabilities[1][0], -0.5089, 0.0001);
120 EXPECT_NEAR(probabilities[1][1], 0.0060, 0.0001);
121 EXPECT_NEAR(probabilities[1][2], 0.0132, 0.0001);
124 Weightfile getONNXWeightfile(std::string modelFilenameONNX, std::string outputName =
"")
129 if (!outputName.empty()) {
130 specific_options.m_outputName = outputName;
132 general_options.m_method = specific_options.getMethod();
133 weightfile.addOptions(general_options);
134 weightfile.addOptions(specific_options);
135 weightfile.addFile(
"ONNX_Modelfile", modelFilenameONNX);
139 TEST(ONNXTest, ONNXFatalMultipleInputs)
146 EXPECT_B2FATAL(expert->load(weightfile));
149 TEST(ONNXTest, ONNXFatalMultipleOutputs)
156 EXPECT_B2FATAL(expert->load(weightfile));
159 TEST(ONNXTest, ONNXMultipleOutputsOK)
165 expert->load(weightfile);
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Wraps the data of multiple events into a Dataset.
Expert for the ONNX MVA method.
Options for the ONNX MVA method.
The Weightfile class serializes all information about a training into an xml tree.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.