Belle II Software development
trgcdc_3dhjsontoroot.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <string>
10#include <iostream>
11#include <fstream>
12#include <cmath>
13#include <algorithm>
14#include <vector>
15
16#include <nlohmann/json.hpp>
17
18#include "trg/cdc/dataobjects/CDCTriggerHoughMLP.h"
19
20using namespace Belle2;
21
22// Converts a plaintext .json network to a .root network that can be used in basf2
23int main(int argc, const char* argv[])
24{
25 if (argc != 4) {
26 std::cout << "Program requires the following 3 arguments:\n"
27 << " 1: json weights\n"
28 << " 2: configuration file\n"
29 << " 3: root output file name\n";
30 return -1;
31 }
32 const std::string jsonWeights = argv[1];
33 const std::string configFile = argv[2];
34 const std::string outputFile = argv[3];
35
36 NeuroParametersHough neuroParameters3DH = CDCTrigger3DHMLP::loadConfigFromFile(configFile);
37 CDCTrigger3DHMLP mlp(neuroParameters3DH);
38 std::ifstream netfile(jsonWeights, std::ifstream::binary);
39 nlohmann::json network;
40 netfile >> network;
41
42 std::vector<float> floatWeights;
43 const size_t nLayers = neuroParameters3DH.nHidden.size();
44 const size_t layerIndices = (nLayers + 1) * 2;
45 for (size_t layerIdx = 0; layerIdx < layerIndices; layerIdx += 2) {
46 int nodeIdx = 0;
47 const std::string layerWeightName = "model." + std::to_string(layerIdx) + ".weight";
48 const std::string layerBiasName = "model." + std::to_string(layerIdx) + ".bias";
49 for (const auto& node : network[layerWeightName]) {
50 for (float weight : node) {
51 floatWeights.push_back(weight);
52 }
53 float bias = network[layerBiasName][nodeIdx];
54 floatWeights.push_back(bias);
55 ++nodeIdx;
56 }
57 }
58 mlp.setFloatWeights(floatWeights);
59
60 std::cout << "Writing " << floatWeights.size() << " weights\n";
61 mlp.saveMLPToFile(outputFile, "MLP");
62
63 float minFloatWeight = *std::min_element(floatWeights.begin(), floatWeights.end());
64 float maxFloatWeight = *std::max_element(floatWeights.begin(), floatWeights.end());
65 std::cout << "Max weight: " << maxFloatWeight << ", Min weight: " << minFloatWeight << std::endl;
66
67 return 0;
68}
Abstract base class for different kinds of events.