11from gnn_tracking
import CDCNet
# Tensor names for the network's coordinate and momentum outputs. Both are
# listed in ``output_names`` in ``produce_onnx_model`` and are stored via
# ``setOutputTCoordinatesName`` / ``setOutputTMomentumName`` in
# ``produce_payloads``, so the two places always agree.
_tCoordinates, _tMomentum = 'coordinates', 'momentum'
def produce_onnx_model():
    """Rebuild the trained CDCNet from its checkpoint and prepare its export.

    Steps visible in this function:

    * read the YAML training config at ``config_path``;
    * build network keyword arguments from its ``'dataset'`` and ``'model'``
      sections (``space_dimensions`` falls back to 4 and ``momentum`` to 0.6
      when the config does not set them);
    * load the checkpoint at ``model_path`` onto the CPU and take its
      ``'model_state_dict'`` entry;
    * rename state-dict keys: drop one leading ``'module.'`` and turn a
      leading ``'p_ccoords_layer'`` into ``'p_coords_layer'``;
    * load the renamed weights into ``_net`` and save ``_net.state_dict()``
      to ``cdcnet.pt`` (relative path, i.e. the current working directory);
    * create a random ``(1000, 7)`` dummy input ``x`` and specify the input
      and output tensor names plus ``dynamic_shapes=[{0: "hits"}]``.

    NOTE(review): this body appears incomplete as it stands. Before running,
    restore these missing pieces:

    * the statement that receives the keyword arguments after the config is
      read (presumably ``_net = CDCNet(...)``, given the ``gnn_tracking``
      import);
    * the initialisation of ``new_state_dict`` (presumably ``{}``);
    * the per-iteration assignment of ``key`` (presumably ``keys[i]``);
    * the call that consumes ``input_names`` / ``output_names`` /
      ``dynamic_shapes`` (presumably ``torch.onnx.export(_net, (x,), ...)``).
    """
    # Placeholder locations of the trained checkpoint and its training config.
    model_path = '/path/to/model.pt'
    config_path = '/path/to/config.yaml'

    # safe_load only constructs plain Python types (dicts, lists, scalars).
    with open(config_path) as f:
        config = yaml.safe_load(f)

    # NOTE(review): the call that takes the keyword arguments below (presumably
    # ``_net = CDCNet(``) and its closing ``)`` are not present here - restore.
        # Input width = number of configured input features.
        input_dim=len(config['dataset']['input_features']),
        k=config['model']['k'],
        # Config key is 'blocks' although the argument is named ``nblocks``.
        nblocks=config['model']['blocks'],
        coord_dim=config['model']['coord_dim'],
        dim1=config['model']['dim1'],
        dim2=config['model']['dim2'],
        # Optional config entries fall back to these defaults.
        space_dimensions=config['model'].get('space_dimensions', 4),
        momentum=config['model'].get('momentum', 0.6),

    # map_location='cpu' loads every tensor on the CPU regardless of the
    # device it was saved from.
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints; since PyTorch 2.6 the default is weights_only=True.
    loaded_model = torch.load(model_path, map_location='cpu')
    state_dict = loaded_model['model_state_dict']
    keys = list(state_dict.keys())
    # NOTE(review): ``new_state_dict`` is never initialised and ``key`` is never
    # assigned in this loop as shown (presumably ``new_state_dict = {}`` before
    # the loop and ``key = keys[i]`` as its first statement) - restore them.
    for i in range(len(keys)):
        # Strip one leading 'module.' (the first occurrence is the prefix just
        # matched by startswith). Presumably added by a DataParallel /
        # DistributedDataParallel wrapper during training - confirm.
        if key.startswith('module.'):
            key = key.replace('module.', '', 1)
        # Rename the 'p_ccoords_layer' prefix to 'p_coords_layer', presumably
        # so the key matches the parameter names of the current network class.
        if key.startswith('p_ccoords_layer'):
            key = key.replace('p_ccoords_layer', 'p_coords_layer', 1)
        new_state_dict[key] = state_dict[keys[i]]

    _net.load_state_dict(new_state_dict)

    # Save the weights under their remapped names.
    torch.save(_net.state_dict(), "cdcnet.pt")

    # Random dummy input with 1000 rows and 7 columns (values in [0, 1)).
    # NOTE(review): 7 is hard-coded; it should equal input_dim
    # (len(config['dataset']['input_features'])) and the setNInputFeatures(7)
    # value stored by produce_payloads - confirm.
    # NOTE(review): no ``_net.eval()`` is visible before the export; if the
    # network has dropout or batch-norm layers, export in eval mode.
    x = torch.rand(1000, 7)
    # NOTE(review): the call that takes the keyword arguments below (presumably
    # ``torch.onnx.export(_net, (x,), ...)``) and its closing ``)`` are not
    # present here - restore. ``dynamic_shapes`` is only used by the
    # dynamo-based exporter (``dynamo=True``); the legacy exporter expects
    # ``dynamic_axes`` instead - confirm which exporter is used.
        input_names=[_tInput],
        output_names=[_tBeta, _tCoordinates, _tMomentum, _tVertex, _tCharge],
        # Dimension 0 of the single input is marked dynamic under the name "hits".
        dynamic_shapes=[{0: "hits"}],
def produce_payloads():
    """Store the CATFinder ONNX weight file and its parameters as payloads.

    Both payloads share one ``ROOT.Belle2.IntervalOfValidity(0, 0, -1, -1)``
    and are written through ``ROOT.Belle2.Database.Instance()``:

    * ``'CATFinderWeightFile'`` - the file at ``/path/to/cdcnet.onnx``;
    * ``'CATFinderParameters'`` - a ``ROOT.Belle2.CATFinderParameters``
      object filled with the numeric settings and tensor names listed below.
    """
    validity = ROOT.Belle2.IntervalOfValidity(0, 0, -1, -1)
    db = ROOT.Belle2.Database.Instance()

    # Placeholder path of the exported network; adjust before running.
    db.addPayload('CATFinderWeightFile', '/path/to/cdcnet.onnx', validity)

    params = ROOT.Belle2.CATFinderParameters()
    # Each (setter, value) pair is applied in exactly the listed order.
    settings = (
        (params.setTDCOffset, 4100),
        (params.setTDCScale, 1100),
        (params.setADCClip, 600),
        (params.setSLayerScale, 10),
        (params.setCLayerScale, 56),
        (params.setLayerScale, 10),
        (params.setSpatialCoordinatesScale, 100),
        # NOTE(review): must agree with the 7 columns of the dummy input in
        # produce_onnx_model - confirm.
        (params.setNInputFeatures, 7),
        (params.setLatentSpaceNDim, 3),
        (params.setTBeta, 0.3),
        (params.setTDistance, 0.3),
        (params.setMaxRadius, 0.15),
        (params.setMinNumberHits, 7),
        # Tensor names: the same module-level names passed as input_names /
        # output_names in produce_onnx_model.
        (params.setInputTFeaturesName, _tInput),
        (params.setOutputTBetaName, _tBeta),
        (params.setOutputTCoordinatesName, _tCoordinates),
        (params.setOutputTMomentumName, _tMomentum),
        (params.setOutputTVertexName, _tVertex),
        (params.setOutputTChargeName, _tCharge),
    )
    for apply_setting, value in settings:
        apply_setting(value)

    db.storeData('CATFinderParameters', params, validity)
107if __name__ ==
'__main__':