Belle II Software development
PIDNeuralNetworkParametersCreatorModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9
10#include <analysis/modules/PIDNeuralNetworkParametersCreator/PIDNeuralNetworkParametersCreatorModule.h>
11
12
13#include <framework/core/ModuleParam.templateDetails.h>
14#include <framework/database/DBImportObjPtr.h>
15#include <framework/logging/Logger.h>
16
17#define FDEEP_FLOAT_TYPE float
18#include <fdeep/fdeep.hpp>
19
20using namespace Belle2;
21
// Register the module with the framework (name given without the 'Module' suffix).
22REG_MODULE(PIDNeuralNetworkParametersCreator);
23
25{
// Constructor (its signature line is not present in this listing): sets the
// module description and registers every steering parameter of the module.
// NOTE(review): no default values are passed to addParam below, so presumably
// every parameter must be set explicitly in the steering file — confirm
// against the Module::addParam overloads.
26 // Set module properties
27 setDescription(R"DOC(Module that creates PID neural network parameters and uploads them to the DB)DOC");
28
29 // Parameter definitions
// Identification of the parameter set.
30 addParam("neuralNetworkParametersName", m_neuralNetworkParametersName, "Name of the set of parameters");
31 addParam("description", m_description, "Description of the neural network");
32
// Model definition, input/output bookkeeping and input preprocessing.
// NOTE(review): the inputNames help string reads "List of name of input
// variables"; it was probably meant to be "List of names of input variables".
33 addParam("inputNames", m_inputNames, "List of name of input variables in the required order");
34 addParam("modelDefinition", m_modelDefinition, "Keras string encoding the neural-network model and parameters");
35 addParam("outputSpeciesPdg", m_outputSpeciesPdg,
36 "List of PDG codes of the hypotheses that correspond to the neural network output probabilities");
37 addParam("meanValues", m_meanValues, "List of mean values of input variables for normalization");
38 addParam("standardDeviations", m_standardDeviations, "List of standard deviations of input variables for normalization");
39 addParam("handleMissingInputs", m_handleMissingInputs,
40 "List of indices and values set if the variable defined by the index is NaN");
41 addParam("inputsToCut", m_inputsToCut, "List of input values that are cut if another input value is in a given range");
42
// Interval of validity (experiment/run range) of the payload to be created.
43 addParam("experimentLow", m_experimentLow, "Interval of validity, exp low");
44 addParam("experimentHigh", m_experimentHigh, "Interval of validity, exp high");
45 addParam("runLow", m_runLow, "Interval of validity, run low");
46 addParam("runHigh", m_runHigh, "Interval of validity, run high");
47}
48
50{
// Validates the steering parameters against the frugally-deep model given in
// m_modelDefinition, then constructs the database payload through `importer`.
// NOTE(review): the function signature, the declaration of `importer`, the
// arguments of importer.construct() and the import call itself are missing
// from this listing (original lines 49, 122, 124-131 and 133).
// NOTE(review): every failed check below reports via std::cout instead of the
// basf2 logger (framework/logging/Logger.h is included and B2ERROR is used at
// the end); consider B2ERROR for each individual diagnostic.
51 bool isValid = true;
52 const size_t nInputs = m_inputNames.size();
53 const size_t nOutputs = m_outputSpeciesPdg.size();
54
55 // this performs some tests and raises an exception if a test fails
56 const auto model = fdeep::read_model_from_string(m_modelDefinition);
57
// Every model input must have rank 1; the per-input shape.minimal_volume()
// values are summed and compared with the number of named inputs.
58 const auto inputShapes = model.get_input_shapes();
59 size_t nModelInputs = 0;
60 for (const auto& shape : inputShapes) {
61 if (shape.rank() != 1) {
62 std::cout << "Can handle only rank=1 inputs, but input has rank " << shape.rank() << std::endl;
63 isValid = false;
// NOTE(review): this break leaves nModelInputs only partially summed, so
// the count check below will usually report an additional, spurious mismatch.
64 break;
65 }
66 nModelInputs += shape.minimal_volume();
67 }
// NOTE(review): the message says "have only" even when the parameters define
// more inputs than the model expects.
68 if (nModelInputs != nInputs) {
69 std::cout << "Model requires " << nModelInputs << " inputs, but parameters have only " << nInputs << " inputs!" << std::endl;
70 isValid = false;
71 }
72
// Same rank-1 and count checks for the model outputs, compared with the
// number of output PDG codes (same "break" and "have only" caveats as above).
73 const auto outputShapes = model.get_output_shapes();
74 size_t nModelOutputs = 0;
75 for (const auto& shape : outputShapes) {
76 if (shape.rank() != 1) {
77 std::cout << "Can handle only rank=1 outputs, but output has rank " << shape.rank() << std::endl;
78 isValid = false;
79 break;
80 }
81 nModelOutputs += shape.minimal_volume();
82 }
83 if (nModelOutputs != nOutputs) {
84 std::cout << "Model has " << nModelOutputs << " outputs, but parameters have only " << nOutputs << " outputs!" << std::endl;
85 isValid = false;
86 }
87
// The normalization vectors must provide exactly one entry per input.
// NOTE(review): the two messages below lack std::endl, so they are not
// newline-terminated and run into the following output.
88 if (nInputs != m_meanValues.size()) {
89 std::cout << "Parameters have " << m_meanValues.size() << " mean values, but " << nInputs << " inputs!";
90 isValid = false;
91 }
92
93 if (nInputs != m_standardDeviations.size()) {
94 std::cout << "Parameters have " << m_standardDeviations.size() << " standard deviations, but " << nInputs << " inputs!";
95 isValid = false;
96 }
97
// Each handleMissingInputs entry is unpacked as (index, value); only the
// index is range-checked here.
98 for (auto const& index_value : m_handleMissingInputs) {
99 const auto [index, _] = index_value;
100 if (index >= nInputs) {
101 std::cout << "Index " << index << " of handleMissingInputs out of range!" << std::endl;
102 isValid = false;
103 }
104 }
105
// Elements 0 and 1 of each inputsToCut entry are input indices and must be in
// range (the meaning of any further elements is not visible here).
// NOTE(review): both messages below mention "handleMissingInputs" although
// they refer to inputsToCut; this looks like a copy-paste slip.
106 for (auto const& inputToCut : m_inputsToCut) {
107 const size_t inputSetIndex = std::get<0>(inputToCut);
108 const size_t inputCutIndex = std::get<1>(inputToCut);
109 if (inputSetIndex >= nInputs) {
110 std::cout << "inputSetIndex " << inputSetIndex << " of handleMissingInputs out of range!" << std::endl;
111 isValid = false;
112 }
113 if (inputCutIndex >= nInputs) {
114 std::cout << "inputCutIndex " << inputCutIndex << " of handleMissingInputs out of range!" << std::endl;
115 isValid = false;
116 }
117 }
118
// NOTE(review): nothing returns or throws after this B2ERROR, so control
// continues to importer.construct() even when isValid is false. If invalid
// payloads must never be uploaded, consider B2FATAL or an early return;
// confirm the intended behavior.
119 if (!isValid)
120 B2ERROR("The given neural-network parameters are invalid!");
121
123 importer.construct(
132 );
134}
135
136
Class for importing a single object to the database.
A class that describes the interval of experiments/runs for which an object in the database is valid.
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
PIDNNInputsToCut m_inputsToCut
overwrite certain input variables
std::vector< float > m_standardDeviations
Standard deviations of the inputs (used for normalization)
PIDNeuralNetworkParametersCreatorModule()
Constructor: Sets the description, the properties and the parameters of the module.
std::vector< std::string > m_inputNames
list of input names
std::string m_neuralNetworkParametersName
Name of the set of parameters.
std::string m_modelDefinition
Neural-network model definition string for frugally-deep
std::vector< int > m_outputSpeciesPdg
PDG codes of hypotheses of neural-network output.
std::string m_description
Description of the neural-network parameters
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Abstract base class for different kinds of events.