Belle II Software  release-08-01-10
PIDNeuralNetworkParametersCreatorModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 
10 #include <analysis/modules/PIDNeuralNetworkParametersCreator/PIDNeuralNetworkParametersCreatorModule.h>
11 
12 
13 #include <framework/core/ModuleParam.templateDetails.h>
14 #include <framework/database/DBImportObjPtr.h>
15 #include <framework/logging/Logger.h>
16 
17 #define FDEEP_FLOAT_TYPE float
18 #include <fdeep/fdeep.hpp>
19 
20 using namespace Belle2;
21 
22 REG_MODULE(PIDNeuralNetworkParametersCreator);
23 
25 {
26  // Set module properties
27  setDescription(R"DOC(Module that creates PID neural network parameters and uploads them to the DB)DOC");
28 
29  // Parameter definitions
30  addParam("neuralNetworkParametersName", m_neuralNetworkParametersName, "Name of the set of parameters");
31  addParam("description", m_description, "Description of the neural network");
32 
33  addParam("inputNames", m_inputNames, "List of name of input variables in the required order");
34  addParam("modelDefinition", m_modelDefinition, "Keras string encoding the neural-network model and parameters");
35  addParam("outputSpeciesPdg", m_outputSpeciesPdg,
36  "List of PDG codes of the hypotheses that correspond to the neural network output probabilities");
37  addParam("meanValues", m_meanValues, "List of mean values of input variables for normalization");
38  addParam("standardDeviations", m_standardDeviations, "List of standard deviations of input variables for normalization");
39  addParam("handleMissingInputs", m_handleMissingInputs,
40  "List of indices and values set if the variable defined by the index is NaN");
41  addParam("inputsToCut", m_inputsToCut, "List of input values that are cut if another input value is in a given range");
42 
43  addParam("experimentLow", m_experimentLow, "Interval of validity, exp low");
44  addParam("experimentHigh", m_experimentHigh, "Interval of validity, exp high");
45  addParam("runLow", m_runLow, "Interval of validity, run low");
46  addParam("runHigh", m_runHigh, "Interval of validity, run high");
47 }
48 
50 {
51  bool isValid = true;
52  const size_t nInputs = m_inputNames.size();
53  const size_t nOutputs = m_outputSpeciesPdg.size();
54 
55  // this performs some tests and raises an exception if a test failes
56  const auto model = fdeep::read_model_from_string(m_modelDefinition);
57 
58  const auto inputShapes = model.get_input_shapes();
59  size_t nModelInputs = 0;
60  for (const auto& shape : inputShapes) {
61  if (shape.rank() != 1) {
62  std::cout << "Can handle only rank=1 inputs, but input has rank " << shape.rank() << std::endl;
63  isValid = false;
64  break;
65  }
66  nModelInputs += shape.minimal_volume();
67  }
68  if (nModelInputs != nInputs) {
69  std::cout << "Model requires " << nModelInputs << " inputs, but parameters have only " << nInputs << " inputs!" << std::endl;
70  isValid = false;
71  }
72 
73  const auto outputShapes = model.get_output_shapes();
74  size_t nModelOutputs = 0;
75  for (const auto& shape : outputShapes) {
76  if (shape.rank() != 1) {
77  std::cout << "Can handle only rank=1 outputs, but output has rank " << shape.rank() << std::endl;
78  isValid = false;
79  break;
80  }
81  nModelOutputs += shape.minimal_volume();
82  }
83  if (nModelOutputs != nOutputs) {
84  std::cout << "Model has " << nModelOutputs << " outputs, but parameters have only " << nOutputs << " outputs!" << std::endl;
85  isValid = false;
86  }
87 
88  if (nInputs != m_meanValues.size()) {
89  std::cout << "Parameters have " << m_meanValues.size() << " mean values, but " << nInputs << " inputs!";
90  isValid = false;
91  }
92 
93  if (nInputs != m_standardDeviations.size()) {
94  std::cout << "Parameters have " << m_standardDeviations.size() << " standard deviations, but " << nInputs << " inputs!";
95  isValid = false;
96  }
97 
98  for (auto const& index_value : m_handleMissingInputs) {
99  const auto [index, _] = index_value;
100  if (index >= nInputs) {
101  std::cout << "Index " << index << " of handleMissingInputs out of range!" << std::endl;
102  isValid = false;
103  }
104  }
105 
106  for (auto const& inputToCut : m_inputsToCut) {
107  const size_t inputSetIndex = std::get<0>(inputToCut);
108  const size_t inputCutIndex = std::get<1>(inputToCut);
109  if (inputSetIndex >= nInputs) {
110  std::cout << "inputSetIndex " << inputSetIndex << " of handleMissingInputs out of range!" << std::endl;
111  isValid = false;
112  }
113  if (inputCutIndex >= nInputs) {
114  std::cout << "inputCutIndex " << inputCutIndex << " of handleMissingInputs out of range!" << std::endl;
115  isValid = false;
116  }
117  }
118 
119  if (!isValid)
120  B2ERROR("The given neural-network parametes are invalid!");
121 
123  importer.construct(
125  m_inputNames,
128  m_meanValues,
132  );
134 }
135 
136 
Class for importing a single object to the database.
A class that describes the interval of experiments/runs for which an object in the database is valid.
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
PIDNNInputsToCut m_inputsToCut
overwrite certain input variables if another input variable lies in a given range
std::vector< float > m_standardDeviations
standard deviations of inputs
PIDNeuralNetworkParametersCreatorModule()
Constructor: Sets the description, the properties and the parameters of the module.
std::vector< std::string > m_inputNames
list of input names
std::string m_neuralNetworkParametersName
Name of the set of parameters.
std::string m_modelDefinition
neural network string for frugally-deep
std::vector< int > m_outputSpeciesPdg
PDG codes of hypotheses of neural-network output.
std::string m_description
description of neural network parameters
REG_MODULE(arichBtest)
Register the Module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
Abstract base class for different kinds of events.