# Belle II Software prerelease-11-00-00a
# produce_cat_finder_payloads.py
9import torch
10import yaml
11from gnn_tracking import CDCNet
12
13
14_tInput = 'input'
15_tBeta = 'beta'
16_tCoordinates = 'coordinates'
17_tMomentum = 'momentum'
18_tVertex = 'vertex'
19_tCharge = 'charge'
20
21
def _remap_state_dict_keys(state_dict):
    """Return a copy of *state_dict* whose keys match the current CDCNet attributes.

    A leading ``'module.'`` is stripped (presumably added by a
    ``torch.nn.DataParallel``-style wrapper when the checkpoint was saved --
    verify), and the legacy ``'p_ccoords_layer'`` prefix is renamed to
    ``'p_coords_layer'``. Values are passed through unchanged and the original
    key order is preserved.
    """
    wrapper_prefix = 'module.'
    legacy_prefix = 'p_ccoords_layer'
    current_prefix = 'p_coords_layer'

    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        if new_key.startswith(wrapper_prefix):
            new_key = new_key[len(wrapper_prefix):]
        # Remap old checkpoint name to new attribute name
        if new_key.startswith(legacy_prefix):
            new_key = current_prefix + new_key[len(legacy_prefix):]
        remapped[new_key] = value
    return remapped


def produce_onnx_model(
    model_path='/path/to/model.pt',
    config_path='/path/to/config.yaml',
    state_dict_path='cdcnet.pt',
    onnx_path='cdcnet.onnx',
    n_dummy_hits=1000,
):
    """Convert a trained CDCNet checkpoint into a plain state dict and an ONNX model.

    The network is rebuilt from the YAML training configuration, the weights are
    loaded from the checkpoint's ``'model_state_dict'`` entry (after remapping
    legacy key names), and the model is saved both as a PyTorch state dict and
    as a self-contained ONNX file. The ONNX input/output tensor names are the
    module-level ``_t*`` constants.

    Parameters
    ----------
    model_path : str
        Path to the training checkpoint; it must contain a ``'model_state_dict'`` key.
    config_path : str
        Path to the YAML training configuration used to build the network.
    state_dict_path : str
        Output path for the cleaned PyTorch state dict.
    onnx_path : str
        Output path for the exported ONNX model.
    n_dummy_hits : int
        Number of rows in the random example input used to trace the model.
        Dimension 0 is exported as dynamic, so this only affects tracing.

    Raises
    ------
    KeyError
        If the configuration or checkpoint lacks a required entry.
    """
    with open(config_path, encoding='utf-8') as f:
        config = yaml.safe_load(f)

    model_config = config['model']
    # The network's input width comes from the configured feature list; the
    # dummy tracing input below must have the same number of columns.
    n_input_features = len(config['dataset']['input_features'])

    net = CDCNet(
        input_dim=n_input_features,
        k=model_config['k'],
        nblocks=model_config['blocks'],
        coord_dim=model_config['coord_dim'],
        dim1=model_config['dim1'],
        dim2=model_config['dim2'],
        space_dimensions=model_config.get('space_dimensions', 4),
        momentum=model_config.get('momentum', 0.6),
    ).to('cpu')

    # NOTE(review): since PyTorch 2.6, torch.load defaults to weights_only=True,
    # which rejects checkpoints containing arbitrary pickled Python objects.
    # Confirm this checkpoint loads under the installed torch version.
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(_remap_state_dict_keys(checkpoint['model_state_dict']))
    net.eval()

    torch.save(net.state_dict(), state_dict_path)

    with torch.no_grad():
        # Random example input used only to trace the graph; dimension 0
        # (named "hits") is declared dynamic below.
        example_input = torch.rand(n_dummy_hits, n_input_features)
        torch.onnx.export(
            net,
            (example_input,),
            onnx_path,
            input_names=[_tInput],
            output_names=[_tBeta, _tCoordinates, _tMomentum, _tVertex, _tCharge],
            dynamic_shapes=[{0: "hits"}],
            dynamo=True,
            # Keep all weights inside the .onnx file (no separate data file).
            external_data=False,
        )
71
72
def produce_payloads(onnx_path='/path/to/cdcnet.onnx'):
    """Store the CAT finder ONNX weights and its configuration as database payloads.

    Two payloads are written through ``ROOT.Belle2.Database.Instance()``, both
    with ``IntervalOfValidity(0, 0, -1, -1)``:

    * ``CATFinderWeightFile`` -- the ONNX file found at *onnx_path*;
    * ``CATFinderParameters`` -- the numeric settings and ONNX tensor names
      set below.

    Parameters
    ----------
    onnx_path : str
        Path of the ONNX model to store, e.g. the file exported by
        ``produce_onnx_model``. Defaults to the original placeholder path.
    """
    import ROOT  # noqa

    # NOTE(review): (0, 0, -1, -1) presumably means "from experiment 0, run 0,
    # with no upper bound" -- confirm against the IntervalOfValidity docs.
    iov = ROOT.Belle2.IntervalOfValidity(0, 0, -1, -1)
    database = ROOT.Belle2.Database.Instance()

    database.addPayload('CATFinderWeightFile', onnx_path, iov)

    parameters = ROOT.Belle2.CATFinderParameters()
    # Input-feature scaling constants, going by the setter names.
    # NOTE(review): these presumably must match the normalization used in training.
    parameters.setTDCOffset(4100)
    parameters.setTDCScale(1100)
    parameters.setADCClip(600)
    parameters.setSLayerScale(10)
    parameters.setCLayerScale(56)
    parameters.setLayerScale(10)
    parameters.setSpatialCoordinatesScale(100)
    # NOTE(review): should equal the input width of the exported ONNX model.
    parameters.setNInputFeatures(7)
    # Latent-space / track-building settings (meaning inferred from names only).
    parameters.setLatentSpaceNDim(3)
    parameters.setTBeta(0.3)
    parameters.setTDistance(0.3)
    parameters.setMaxRadius(0.15)
    parameters.setMinNumberHits(7)
    # ONNX tensor names: the same constants are passed to torch.onnx.export as
    # input_names/output_names, so the stored names match the exported graph.
    parameters.setInputTFeaturesName(_tInput)
    parameters.setOutputTBetaName(_tBeta)
    parameters.setOutputTCoordinatesName(_tCoordinates)
    parameters.setOutputTMomentumName(_tMomentum)
    parameters.setOutputTVertexName(_tVertex)
    parameters.setOutputTChargeName(_tCharge)

    database.storeData('CATFinderParameters', parameters, iov)
105
106
def main():
    """Export the ONNX model, then store it and its parameters as payloads."""
    produce_onnx_model()
    produce_payloads()


if __name__ == '__main__':
    main()