Belle II Software  release-08-01-10
NeuroTriggerMLPToTextfile.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 /* This program loads a set of trained MLPs from a rootfile
9  * and saves the weights to a textfile with fixed point precision.
10  *
11  * output format:
12  *
13  * isector
14  * ID0min ID0max ID1min ID1max ... ID8min ID8max
15  * sectorpattern patternmask
16  * nodes1 nodes2 ... nodesN
17  * weight1 weight2 ... weightN
18  * ... (repeated for all sectors)
19  *
20  * ID<i>min/max: ID ranges for super layer <i>
21  * sectorpattern/patternmask: for sector selection based on hit pattern
22  * nodes<i>: number of nodes in layer <i>
23  * weight<i>: numerical value of weight <i>, multiplied by 2^precision
24  *
25  * weights connect input nodes of one layer to output nodes of the next layer.
26  * order:
27  * - ordered by layer
28  * - within one layer, ordered by index of output node
29  * - for fixed output node, ordered by index of input node
30  * (last input node = bias node)
31  */
32 
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <TFile.h>
#include <TObjArray.h>

#include <trg/cdc/dataobjects/CDCTriggerMLP.h>
39 
40 using namespace std;
41 using namespace Belle2;
42 
43 int
44 main(int argc, char* argv[])
45 {
46  // get arguments
47  if (argc < 3) {
48  cout << "Program needs at least 3 arguments:" << endl
49  << " 1: MLP rootfile" << endl
50  << " 2: precision for MLP weights" << endl
51  << " 3: output filename" << endl;
52  return -1;
53  }
54  string MLPFilename = argv[1];
55  unsigned precisionWeights = atoi(argv[2]);
56 
57  TFile MLPFile(MLPFilename.c_str(), "READ");
58  if (!MLPFile.IsOpen()) {
59  cout << "Could not open file " << MLPFilename << endl;
60  return -1;
61  }
62  TObjArray* MLPs = (TObjArray*)MLPFile.Get("MLPs");
63  if (!MLPs) {
64  MLPFile.Close();
65  cout << "File " << MLPFilename << " does not contain key MLPs" << endl;
66  return -1;
67  }
68 
69  // load MLPs and write them to file
70  ofstream weightstream(argv[3]);
71  for (int isector = 0; isector < MLPs->GetEntriesFast(); ++isector) {
72  CDCTriggerMLP* expert = dynamic_cast<CDCTriggerMLP*>(MLPs->At(isector));
73  if (!expert) {
74  cout << "Wrong type " << MLPs->At(isector)->ClassName()
75  << ", ignoring this entry." << endl;
76  continue;
77  }
78  // write sector number
79  weightstream << isector << endl;
80  // write ID ranges
81  for (unsigned isl = 0; isl < 9; ++isl) {
82  weightstream << expert->getIDRange(isl)[0] << " "
83  << expert->getIDRange(isl)[1] << " ";
84  }
85  weightstream << endl;
86  // write sector pattern
87  unsigned pattern = expert->getSLpatternUnmasked();
88  weightstream << pattern << " " << expert->getSLpatternMask() << endl;
89  // write number of nodes per network layer
90  vector<unsigned> nNodes = {};
91  for (unsigned il = 0; il < expert->nLayers(); ++il) {
92  nNodes.push_back(expert->nNodesLayer(il));
93  weightstream << nNodes.back() << " ";
94  }
95  weightstream << endl;
96  // write weights and check range
97  vector<float> weights = expert->getWeights();
98  float minWeight = 0;
99  float maxWeight = 0;
100  for (unsigned iw = 0; iw < weights.size(); ++iw) {
101  double weight = weights[iw];
102  // set weights for unused inputs to 0
103  if (iw < ((nNodes[0] + 1) * nNodes[1])) {
104  unsigned isl = (iw % (nNodes[0] + 1)) / 3;
105  if (isl < 9 && !((pattern >> isl) & 1)) weight = 0;
106  }
107  if (weight < minWeight) minWeight = weight;
108  if (weight > maxWeight) maxWeight = weight;
109  weightstream << round(weight * (1 << precisionWeights)) << " ";
110  }
111  weightstream << endl;
112  cout << weights.size() << " weights in [" << minWeight << "," << maxWeight << "]" << endl;
113  }
114  MLPs->Clear();
115  delete MLPs;
116  MLPFile.Close();
117  weightstream.close();
118 
119  return 0;
120 }
Class to keep all parameters of an expert MLP for the neuro trigger.
Definition: CDCTriggerMLP.h:20
Abstract base class for different kinds of events.
int main(int argc, char **argv)
Run all tests.
Definition: test_main.cc:91