Belle II Software development
PIDNeuralNetwork.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10
11#include <analysis/VariableManager/Manager.h>
12#include <analysis/dbobjects/PIDNeuralNetworkParameters.h>
13#include <framework/database/DBObjPtr.h>
14#include <framework/logging/Logger.h>
15
16#include <framework/gearbox/Const.h>
17
18#define FDEEP_FLOAT_TYPE float
19#include <fdeep/fdeep.hpp>
20
21
22
23namespace Belle2 {
33
34 public:
35
41 m_model(nullptr)
42 {
44 check();
45 };
46
47
51 PIDNeuralNetwork(const std::string& parameterName):
54 m_model(nullptr)
55 {
57 check();
58 }
59
66 m_model(std::move(other.m_model)),
69 {
70 }
71
78 std::map<int, double> predict(std::vector<float> inputs) const;
79
84 size_t getInputSize() const { return (*m_pidNeuralNetworkParametersDB)->getInputSize(); }
85
90 const std::vector<std::string>& getInputBasf2Names() const { return m_inputBasf2Names; }
91
97 const std::vector<int>& getOutputSpeciesPdg() const {return (*m_pidNeuralNetworkParametersDB)->getOutputSpeciesPdg();}
98
104
111 bool hasPdgCode(const int pdg, const bool throwException = false) const {return (*m_pidNeuralNetworkParametersDB)->hasPdgCode(pdg, throwException);}
112
117 const std::string& getExtraInfoName(const int pdg) const {return m_extraInfoNames.at(pdg);}
118
119
120 private:
121
129 void check();
130
131 std::string m_pidNeuralNetworkParametersName = "PIDNeuralNetworkParameters";
132 std::unique_ptr<DBObjPtr<PIDNeuralNetworkParameters>> m_pidNeuralNetworkParametersDB;
133 std::unique_ptr<const fdeep::model> m_model;
134 std::vector<std::string> m_inputBasf2Names;
135 std::map<int, std::string> m_extraInfoNames;
137 };
138
140} // Belle2 namespace
141
142
144{
145 for (const auto& name : getInputBasf2Names()) {
146 if (!Variable::Manager::Instance().getVariable(name))
147 B2FATAL("PID neural network needs input '" + name + "', but this input is not available!");
148 }
149
150}
Class to call PID neural network.
PIDNeuralNetwork(const std::string &parameterName)
Constructor with a given parameter-set name.
size_t getInputSize() const
Get number of inputs.
PIDNeuralNetwork()
Constructor with the default parameter-set name.
const std::string & getExtraInfoName(const int pdg) const
std::vector< std::string > m_inputBasf2Names
List of the names of the input variables in the basf2 naming scheme.
std::string m_pidNeuralNetworkParametersName
name of the parameter set
const std::string & getPIDNeuralNetworkParametersName() const
Get the name of the parameter set of the neural network in use.
void loadParametersFromDB()
Load neural-network parameters with name m_pidNeuralNetworkParametersName from the conditions database.
std::map< int, double > predict(std::vector< float > inputs) const
Predict neural-network output for all implemented hypotheses using the given inputs.
void check()
Check that the neural network can be evaluated, e.g. that all required input variables are available.
bool hasPdgCode(const int pdg, const bool throwException=false) const
std::unique_ptr< const fdeep::model > m_model
frugally-deep neural network
const std::vector< std::string > & getInputBasf2Names() const
Get names of input variables in the basf2 naming scheme, which may be different from the one in the p...
PIDNeuralNetwork(PIDNeuralNetwork &&other)
Move constructor.
const std::vector< int > & getOutputSpeciesPdg() const
Get the list of pdg codes of species hypotheses, for which the network predicts the probability in th...
std::map< int, std::string > m_extraInfoNames
map from PDG code to extraInfo name that stores the output of this network
std::unique_ptr< DBObjPtr< PIDNeuralNetworkParameters > > m_pidNeuralNetworkParametersDB
Database object for the parameter set.
static Manager & Instance()
Get the singleton instance.
Definition: Manager.cc:25
Abstract base class for different kinds of events.
STL namespace.