Belle II Software light-2406-ragdoll
test_FANN.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/FANN.h>
10#include <mva/interface/Interface.h>
11#include <framework/utilities/FileSystem.h>
12
13#include <gtest/gtest.h>
14
15using namespace Belle2;
16
17namespace {
18
19 TEST(FANNTest, WeightfilesAreReadCorrectly)
20 {
22
23 MVA::GeneralOptions general_options;
24 general_options.m_variables = {"M", "p", "pt"};
25 MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
26 {1.873689, 1.881940, 1.843310},
27 {1.863657, 1.774831, 1.753773},
28 {1.858293, 1.605311, 0.631336},
29 {1.837129, 1.575739, 1.490166},
30 {1.811395, 1.524029, 0.565220}
31 },
32 {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
33
34 auto expert = interface.getExpert();
35
36 auto weightfile = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FANN.xml"));
37 expert->load(weightfile);
38 auto probabilities = expert->apply(dataset);
39 EXPECT_NEAR(probabilities[0], 0.047535836696624756, 0.0001);
40 EXPECT_NEAR(probabilities[1], 0.7130427360534668, 0.0001);
41 EXPECT_NEAR(probabilities[2], 0.7729528546333313, 0.0001);
42 EXPECT_NEAR(probabilities[3], 0.16526281833648682, 0.0001);
43 EXPECT_NEAR(probabilities[4], 0.0091879460960626602, 0.0001);
44 EXPECT_NEAR(probabilities[5], -0.21771839261054993, 0.0001);
45 }
46
47}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:151
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:125
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24