Belle II Software development
test_write_onnx.py
1#!/usr/bin/env python3
2
3from pathlib import Path
4import unittest
5
6import basf2
7import b2test_utils
8import basf2_mva
9import torch
10from torch import nn
11
12
def save_onnx(model, general_options, specific_options, identifier):
    """
    Export ``model`` to ONNX and store it, together with the MVA options,
    in the weightfile ``identifier``.

    The export uses a single all-zero example row whose width equals the
    number of entries in ``general_options.m_variables``; the graph's input
    and output tensors are named "input" and "output". The intermediate ONNX
    file is written to ``model.onnx`` in the current working directory and
    then added to the weightfile under the key "ONNX_Modelfile".

    Note that ``general_options.m_method`` is overwritten with
    ``specific_options.getMethod()`` before the options are stored.
    """
    # ROOT is imported inside the function, so it is only loaded when a
    # weightfile is actually written
    import ROOT  # noqa

    print("convert to onnx")
    onnx_path = "model.onnx"
    example_row = torch.zeros(1, len(general_options.m_variables))
    torch.onnx.export(
        model,
        (example_row,),
        onnx_path,
        input_names=["input"],
        output_names=["output"],
    )

    weightfile = ROOT.Belle2.MVA.Weightfile()
    # tag the general options with the method name reported by the specific options
    general_options.m_method = specific_options.getMethod()
    # general options first, then the method-specific ones
    for options in (general_options, specific_options):
        weightfile.addOptions(options)
    weightfile.addFile("ONNX_Modelfile", onnx_path)

    print(f"save to {identifier}")
    ROOT.Belle2.MVA.Weightfile.save(weightfile, identifier)
34
35
36class TestWriteONNX(unittest.TestCase):
37 """
38 Tests for writing ONNX MVA weightfiles. In addition to testing the writing
39 mechanism, these serve the purpose of creating test files for other unit tests.
40 """
41
42 def create_and_save(self, n_outputs, filename, weights):
43 """
44 Set up example options for an ONNX MVA weightfile, save it and compare it to the reference file.

 Builds an ``nn.Linear`` layer over the 16 variables listed below with
 ``n_outputs`` outputs, loads ``weights`` into it, writes the weightfile
 ``filename`` via ``save_onnx`` and compares the written XML text with the
 file of the same name in ``mva/methods/tests``.

 Parameters:
 n_outputs (int): number of linear outputs; values > 1 also set ``m_nClasses``
 filename (str): name of the weightfile to write and of the reference file
 weights (dict): state dict for the ``nn.Linear`` layer, i.e. ``"weight"``
 of shape (n_outputs, 16) and ``"bias"`` of shape (n_outputs,)

 Raises:
 Exception: if the reference file does not exist yet (it is written first)
45 """
46
47 # like in mva/tests/all_classifiers.py
 # 16 variables -> the input width of the linear model below
48 variables = [
49 "p",
50 "pz",
51 "daughter(0, p)",
52 "daughter(0, pz)",
53 "daughter(1, p)",
54 "daughter(1, pz)",
55 "chiProb",
56 "dr",
57 "dz",
58 "daughter(0, dr)",
59 "daughter(1, dr)",
60 "daughter(0, chiProb)",
61 "daughter(1, chiProb)",
62 "daughter(0, kaonID)",
63 "daughter(0, pionID)",
64 "daughterAngle(0, 1)",
65 ]
66 model = nn.Linear(len(variables), n_outputs)
 # load_state_dict is strict by default: keys and tensor shapes must match the layer
67 model.load_state_dict(weights)
68 general_options = basf2_mva.GeneralOptions()
 # NOTE(review): nothing in this test reads this data file; presumably a placeholder — confirm
69 general_options.m_datafiles = basf2_mva.vector("dummy")
70 general_options.m_identifier = "Simple"
71 general_options.m_treename = "tree"
72 general_options.m_variables = basf2_mva.vector(*variables)
 # only multi-output models declare a class count
73 if n_outputs > 1:
74 general_options.m_nClasses = n_outputs
75 specific_options = basf2_mva.ONNXOptions()
 # NOTE(review): listing line 76 is absent from this rendering. b2test_utils is
 # imported but otherwise unused, and the page references clean_working_directory(),
 # so presumably line 76 opens `with b2test_utils.clean_working_directory():`
 # around the save/read below — confirm against the real file.
77 save_onnx(model, general_options, specific_options, filename)
78 with open(filename) as f:
79 xml_new = f.read()
 # reference weightfiles live in the source tree next to this test
80 ref_path = Path(basf2.find_file("mva/methods/tests")) / filename
81 try:
82 with open(ref_path) as f:
83 xml_ref = f.read()
84 except FileNotFoundError:
85 # if the file does not exist, recreate it, but still fail the test
 # NOTE(review): this writes into the source tree; self.fail(...) would report
 # the missing reference as a test failure rather than an error
86 with open(ref_path, "w") as f:
87 f.write(xml_new)
88 raise Exception(f"Wrote new reference file {str(ref_path)}")
 # exact text comparison of the freshly written and the reference XML
89 self.assertEqual(xml_new, xml_ref)
90
 # NOTE(review): listing line 91 (presumably this test method's `def` line) is absent from this rendering
92 """
93 Write example for single output
94 """
 # nn.Linear(16, 1): "weight" is 1 x 16, "bias" has 1 entry
95 self.create_and_save(
96 1,
97 "ONNX.xml",
98 {
99 "weight": torch.tensor(
100 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
101 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
102 ),
103 "bias": torch.tensor([0.1927])
104 }
105 )
106
 # NOTE(review): listing line 107 (presumably this test method's `def` line) is absent from this rendering
108 """
109 Write example for multiclass outputs
110 """
 # nn.Linear(16, 2): "weight" is 2 x 16, "bias" has 2 entries; n_outputs > 1 also sets m_nClasses
111 self.create_and_save(
112 2,
113 "ONNX_multiclass.xml",
114 {
115 "weight": torch.tensor(
116 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
117 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
118 [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
119 -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
120 ),
121 "bias": torch.tensor([0.1930, 0.0416])
122 }
123 )
124
 # NOTE(review): listing line 125 (presumably this test method's `def` line) is absent from this rendering
126 """
127 Write example for 3-class outputs
128 """
 # nn.Linear(16, 3): "weight" is 3 x 16, "bias" has 3 entries; n_outputs > 1 also sets m_nClasses
129 self.create_and_save(
130 3,
131 "ONNX_multiclass_3.xml",
132 {
133 "weight": torch.tensor(
134 [[-0.1648, 0.2103, 0.0204, -0.1267, -0.0719, -0.2464, -0.1342, -0.0418,
135 -0.0362, -0.0801, 0.0587, -0.1121, -0.1560, -0.1602, 0.1597, 0.1568],
136 [-0.2297, -0.1780, -0.0301, -0.2094, -0.1600, 0.1508, 0.1964, 0.1261,
137 -0.0792, 0.0605, -0.0064, 0.0450, 0.0671, -0.2036, 0.0768, 0.0442],
138 [-0.1490, -0.2286, 0.2232, -0.1404, 0.2207, -0.0696, -0.2392, 0.1917,
139 0.0795, -0.1850, 0.0989, -0.0802, 0.0483, 0.0772, 0.1347, -0.1316]]
140 ),
141 "bias": torch.tensor([-0.1484, 0.1209, -0.0164])
142 }
143 )
144
145
if __name__ == "__main__":
    # Hand control to unittest's command-line runner, which collects the
    # TestCase classes defined in this module ("__main__" is its default).
    unittest.main(module="__main__")
create_and_save(self, n_outputs, filename, weights)
clean_working_directory()
Definition __init__.py:198