Belle II Software light-2509-fornax
test_write_onnx.py
1#!/usr/bin/env python3
2
3from pathlib import Path
4import unittest
5
6import basf2
7import b2test_utils
8from basf2_mva_util import create_onnx_mva_weightfile
9import torch
10from torch import nn
11
12
def save_onnx(model, filename, variables, **kwargs):
    """
    Export a torch model to ONNX and store it inside an MVA weightfile.

    The model is first exported to the intermediate file ``model.onnx`` in the
    current working directory. That file is then turned into a weightfile via
    ``create_onnx_mva_weightfile``, and the weightfile is saved to ``filename``.

    :param model: torch module to export; it is traced with a single all-zero
        input row whose width is ``len(variables)``
    :param filename: destination path of the MVA weightfile
    :param variables: names of the input variables; forwarded to
        ``create_onnx_mva_weightfile`` and used to size the export input
    :param kwargs: further options forwarded unchanged to
        ``create_onnx_mva_weightfile``
    """
    # ROOT is not referenced below; the import is kept for its side effects
    # (``noqa`` silences the unused-import warning).
    import ROOT  # noqa

    onnx_path = "model.onnx"

    print("convert to onnx")
    dummy_input = torch.zeros(1, len(variables))
    torch.onnx.export(
        model,
        (dummy_input,),
        onnx_path,
        input_names=["input"],
        output_names=["output"],
    )

    weightfile = create_onnx_mva_weightfile(onnx_path, variables=variables, **kwargs)
    print(f"save to {filename}")
    weightfile.save(filename)
30
31
class TestWriteONNX(unittest.TestCase):
    """
    Tests for writing ONNX MVA weightfiles. In addition to testing the writing
    mechanism, these serve the purpose of creating test files for other unit tests.
    """

    #: Show the complete diff when the written XML differs from the reference
    maxDiff = None

    def create_and_save(self, n_outputs, filename, weights):
        """
        Build a linear example model with the given weights, save it as an ONNX
        MVA weightfile and compare the written XML to the reference file.

        If the reference file ``mva/methods/tests/<filename>`` does not exist
        yet, it is written from the freshly created XML and the test fails.

        :param n_outputs: number of model outputs; a single output is saved
            with ``nClasses=2``, otherwise ``nClasses`` equals ``n_outputs``
        :param filename: name of the weightfile to write and of the reference
            file to compare against
        :param weights: state dict with ``weight`` and ``bias`` tensors that is
            loaded into the ``nn.Linear`` model
        """
        # like in mva/tests/all_classifiers.py
        variables = [
            "p",
            "pz",
            "daughter(0, p)",
            "daughter(0, pz)",
            "daughter(1, p)",
            "daughter(1, pz)",
            "chiProb",
            "dr",
            "dz",
            "daughter(0, dr)",
            "daughter(1, dr)",
            "daughter(0, chiProb)",
            "daughter(1, chiProb)",
            "daughter(0, kaonID)",
            "daughter(0, pionID)",
            "daughterAngle(0, 1)",
        ]
        model = nn.Linear(len(variables), n_outputs)
        model.load_state_dict(weights)
        # A single output means a binary (2-class) classifier; otherwise there
        # is one output per class.
        nClasses = max(n_outputs, 2)

        # save_onnx writes model.onnx and the weightfile relative to the current
        # directory; run inside clean_working_directory() so they are not left
        # behind in the directory the test is started from.
        with b2test_utils.clean_working_directory():
            save_onnx(
                model,
                filename,
                variables=variables,
                datafiles=["dummy"],
                identifier="Simple",
                treename="tree",
                nClasses=nClasses,
            )
            with open(filename) as f:
                xml_new = f.read()
            ref_path = Path(basf2.find_file("mva/methods/tests")) / filename
            try:
                with open(ref_path) as f:
                    xml_ref = f.read()
            except FileNotFoundError:
                # if the file does not exist, recreate it, but still fail the test
                #
                # This has to be done when new options are added to ONNXOptions and
                # therefore the xml changes. In this case, just delete the xmls and
                # rerun the test to generate new reference files.
                with open(ref_path, "w") as f:
                    f.write(xml_new)
                # self.fail reports a test failure (AssertionError) rather than
                # an unexpected error, matching the intent described above.
                self.fail(f"Wrote new reference file {ref_path}")
            self.assertEqual(xml_new, xml_ref)

    def test_single_output(self):
        """
        Write example for single output
        """
        self.create_and_save(
            1,
            "ONNX.xml",
            {
                "weight": torch.tensor(
                    [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
                      0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
                ),
                "bias": torch.tensor([0.1927]),
            },
        )

    def test_multiclass(self):
        """
        Write example for multiclass outputs
        """
        self.create_and_save(
            2,
            "ONNX_multiclass.xml",
            {
                "weight": torch.tensor(
                    [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
                      0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
                     [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
                      -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
                ),
                "bias": torch.tensor([0.1930, 0.0416]),
            },
        )

    def test_multiclass_3(self):
        """
        Write example for 3-class outputs
        """
        self.create_and_save(
            3,
            "ONNX_multiclass_3.xml",
            {
                "weight": torch.tensor(
                    [[-0.1648, 0.2103, 0.0204, -0.1267, -0.0719, -0.2464, -0.1342, -0.0418,
                      -0.0362, -0.0801, 0.0587, -0.1121, -0.1560, -0.1602, 0.1597, 0.1568],
                     [-0.2297, -0.1780, -0.0301, -0.2094, -0.1600, 0.1508, 0.1964, 0.1261,
                      -0.0792, 0.0605, -0.0064, 0.0450, 0.0671, -0.2036, 0.0768, 0.0442],
                     [-0.1490, -0.2286, 0.2232, -0.1404, 0.2207, -0.0696, -0.2392, 0.1917,
                      0.0795, -0.1850, 0.0989, -0.0802, 0.0483, 0.0772, 0.1347, -0.1316]]
                ),
                "bias": torch.tensor([-0.1484, 0.1209, -0.0164]),
            },
        )
152
if __name__ == "__main__":
    # Run the test cases of this module only when it is executed as a script,
    # not when it is imported.
    unittest.main()
create_and_save(self, n_outputs, filename, weights)
clean_working_directory()
Definition __init__.py:198