9#include <mva/interface/Interface.h>
10#include <mva/methods/ONNX.h>
11#include <framework/utilities/FileSystem.h>
13#include <gtest/gtest.h>
15using namespace Belle2::MVA;
27 expert->load(weightfile);
29 weightfile.getOptions(general_options);
32 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
33 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
36 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
37 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
41 auto probabilities = expert->apply(dataset);
42 EXPECT_NEAR(probabilities[0], 0.3328, 0.0001);
43 EXPECT_NEAR(probabilities[1], 0.5779, 0.0001);
46 TEST(ONNXTest, ONNXExpertMulticlass)
51 expert->load(weightfile);
53 weightfile.getOptions(general_options);
56 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
57 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
60 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
61 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
65 auto probabilities = expert->applyMulticlass(dataset);
66 EXPECT_NEAR(probabilities[0][0], 0.3331, 0.0001);
67 EXPECT_NEAR(probabilities[0][1], -0.5373, 0.0001);
68 EXPECT_NEAR(probabilities[1][0], 0.5782, 0.0001);
69 EXPECT_NEAR(probabilities[1][1], -0.3697, 0.0001);
72 TEST(ONNXTest, ONNXExpertMulticlassThreeClasses)
77 expert->load(weightfile);
79 weightfile.getOptions(general_options);
82 0.338, 0.079, 0.16, 0.048, 0.877, 0.367, 0.5, 0.436,
83 0.33, 0.76, 0.176, 0.899, 0.062, 0.794, 0.477, 0.725
86 0.438, 0.222, 0.959, 0.551, 0.987, 0.509, 0.141, 0.005, 0.387,
87 0.926, 0.099, 0.990, 0.870, 0.050, 0.924, 0.767
91 auto probabilities = expert->applyMulticlass(dataset);
92 EXPECT_NEAR(probabilities[0][0], -0.5394, 0.0001);
93 EXPECT_NEAR(probabilities[0][1], 0.0529, 0.0001);
94 EXPECT_NEAR(probabilities[0][2], -0.0598, 0.0001);
95 EXPECT_NEAR(probabilities[1][0], -0.5089, 0.0001);
96 EXPECT_NEAR(probabilities[1][1], 0.0060, 0.0001);
97 EXPECT_NEAR(probabilities[1][2], 0.0132, 0.0001);
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Wraps the data of multiple events into a Dataset.
Expert for the ONNX MVA method.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.