9#include <mva/interface/Interface.h>
10#include <mva/methods/ONNX.h>
11#include <framework/utilities/FileSystem.h>
12#include <framework/utilities/TestHelpers.h>
14#include <gtest/gtest.h>
16using namespace Belle2::MVA;
29 expert->load(weightfile);
31 weightfile.getOptions(general_options);
34 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
35 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
38 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
39 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
43 auto probabilities = expert->apply(dataset);
44 EXPECT_NEAR(probabilities[0], 0.3328, 0.0001);
45 EXPECT_NEAR(probabilities[1], 0.5779, 0.0001);
48 TEST(ONNXTest, ONNXExpertMulticlass)
52 expert->load(weightfile);
54 weightfile.getOptions(general_options);
57 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
58 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
61 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
62 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
66 auto probabilities = expert->applyMulticlass(dataset);
67 EXPECT_NEAR(probabilities[0][0], 0.3331, 0.0001);
68 EXPECT_NEAR(probabilities[0][1], -0.5373, 0.0001);
69 EXPECT_NEAR(probabilities[1][0], 0.5782, 0.0001);
70 EXPECT_NEAR(probabilities[1][1], -0.3697, 0.0001);
73 TEST(ONNXTest, ONNXExpertMulticlassThreeClasses)
77 expert->load(weightfile);
79 weightfile.getOptions(general_options);
82 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
83 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
86 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
87 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
91 auto probabilities = expert->applyMulticlass(dataset);
92 EXPECT_NEAR(probabilities[0][0], -0.5394, 0.0001);
93 EXPECT_NEAR(probabilities[0][1], 0.0529, 0.0001);
94 EXPECT_NEAR(probabilities[0][2], -0.0598, 0.0001);
95 EXPECT_NEAR(probabilities[1][0], -0.5089, 0.0001);
96 EXPECT_NEAR(probabilities[1][1], 0.0060, 0.0001);
97 EXPECT_NEAR(probabilities[1][2], 0.0132, 0.0001);
100 Weightfile getONNXWeightfile(std::string modelFilenameONNX, std::string outputName =
"")
105 if (!outputName.empty()) {
106 specific_options.m_outputName = outputName;
108 general_options.m_method = specific_options.getMethod();
111 weightfile.
addFile(
"ONNX_Modelfile", modelFilenameONNX);
115 TEST(ONNXTest, ONNXFatalMultipleInputs)
122 EXPECT_B2FATAL(expert->load(weightfile));
125 TEST(ONNXTest, ONNXFatalMultipleOutputs)
132 EXPECT_B2FATAL(expert->load(weightfile));
135 TEST(ONNXTest, ONNXMultipleOutputsOK)
141 expert->load(weightfile);
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Wraps the data of multiple events into a Dataset.
Expert for the ONNX MVA method.
Options for the ONNX MVA method.
The Weightfile class serializes all information about a training into an xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.