Belle II Software development
NeuroTriggerMLPToTextfile.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
/* This program loads a set of trained MLPs from a rootfile
 * and saves the weights to a textfile with fixed point precision.
 *
 * arguments:
 *  1: MLP rootfile (must contain a TObjArray with key "MLPs")
 *  2: precision for MLP weights (number of fractional bits)
 *  3: output filename
 *
 * output format:
 *
 * isector
 * ID0min ID0max ID1min ID1max ... ID8min ID8max
 * sectorpattern patternmask
 * nodes1 nodes2 ... nodesN
 * weight1 weight2 ... weightN
 * ... (repeated for all sectors)
 *
 * ID<i>min/max: ID ranges for super layer <i>
 * sectorpattern/patternmask: for sector selection based on hit pattern
 * nodes<i>: number of nodes in layer <i>
 * weight<i>: numerical value of weight <i>, multiplied by 2^precision
 *            and rounded to the nearest integer
 *
 * weights connect input nodes of one layer to output nodes of the next layer.
 * order:
 * - ordered by layer
 * - within one layer, ordered by index of output node
 * - for fixed output node, ordered by index of input node
 *   (last input node = bias node)
 *
 * In the first weight layer, weights of input nodes that belong to a super
 * layer whose bit is not set in sectorpattern are written as 0 (the input
 * nodes are grouped in consecutive triples, one triple per super layer).
 */
32
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <TFile.h>
#include <TObjArray.h>

#include <trg/cdc/dataobjects/CDCTriggerMLP.h>
39
40using namespace std;
41using namespace Belle2;
42
43int
44main(int argc, char* argv[])
45{
46 // get arguments
47 if (argc < 3) {
48 cout << "Program needs at least 3 arguments:" << endl
49 << " 1: MLP rootfile" << endl
50 << " 2: precision for MLP weights" << endl
51 << " 3: output filename" << endl;
52 return -1;
53 }
54 string MLPFilename = argv[1];
55 unsigned precisionWeights = atoi(argv[2]);
56
57 TFile MLPFile(MLPFilename.c_str(), "READ");
58 if (!MLPFile.IsOpen()) {
59 cout << "Could not open file " << MLPFilename << endl;
60 return -1;
61 }
62 TObjArray* MLPs = (TObjArray*)MLPFile.Get("MLPs");
63 if (!MLPs) {
64 MLPFile.Close();
65 cout << "File " << MLPFilename << " does not contain key MLPs" << endl;
66 return -1;
67 }
68
69 // load MLPs and write them to file
70 ofstream weightstream(argv[3]);
71 for (int isector = 0; isector < MLPs->GetEntriesFast(); ++isector) {
72 CDCTriggerMLP* expert = dynamic_cast<CDCTriggerMLP*>(MLPs->At(isector));
73 if (!expert) {
74 cout << "Wrong type " << MLPs->At(isector)->ClassName()
75 << ", ignoring this entry." << endl;
76 continue;
77 }
78 // write sector number
79 weightstream << isector << endl;
80 // write ID ranges
81 for (unsigned isl = 0; isl < 9; ++isl) {
82 weightstream << expert->getIDRange(isl)[0] << " "
83 << expert->getIDRange(isl)[1] << " ";
84 }
85 weightstream << endl;
86 // write sector pattern
87 unsigned pattern = expert->getSLpatternUnmasked();
88 weightstream << pattern << " " << expert->getSLpatternMask() << endl;
89 // write number of nodes per network layer
90 vector<unsigned> nNodes = {};
91 for (unsigned il = 0; il < expert->nLayers(); ++il) {
92 nNodes.push_back(expert->nNodesLayer(il));
93 weightstream << nNodes.back() << " ";
94 }
95 weightstream << endl;
96 // write weights and check range
97 vector<float> weights = expert->getWeights();
98 float minWeight = 0;
99 float maxWeight = 0;
100 for (unsigned iw = 0; iw < weights.size(); ++iw) {
101 double weight = weights[iw];
102 // set weights for unused inputs to 0
103 if (iw < ((nNodes[0] + 1) * nNodes[1])) {
104 unsigned isl = (iw % (nNodes[0] + 1)) / 3;
105 if (isl < 9 && !((pattern >> isl) & 1)) weight = 0;
106 }
107 if (weight < minWeight) minWeight = weight;
108 if (weight > maxWeight) maxWeight = weight;
109 weightstream << round(weight * (1 << precisionWeights)) << " ";
110 }
111 weightstream << endl;
112 cout << weights.size() << " weights in [" << minWeight << "," << maxWeight << "]" << endl;
113 }
114 MLPs->Clear();
115 delete MLPs;
116 MLPFile.Close();
117 weightstream.close();
118
119 return 0;
120}
Class to keep all parameters of an expert MLP for the neuro trigger.
Definition: CDCTriggerMLP.h:20
Abstract base class for different kinds of events.
STL namespace.