Belle II Software prerelease-10-00-00a
test_write_onnx.py
1#!/usr/bin/env python3
2
3from pathlib import Path
4import unittest
5
6import basf2
7import b2test_utils
8import basf2_mva
9import torch
10from torch import nn
11
12
def save_onnx(model, general_options, specific_options, identifier):
    """
    Export a torch model to ONNX and embed it into an MVA weightfile.

    The model is exported by ``torch.onnx.export`` using one all-zeros input
    row whose width equals the number of entries in
    ``general_options.m_variables``. The resulting ``model.onnx`` is written to
    the current working directory and then added to the weightfile under the
    key ``ONNX_Modelfile``.

    Parameters:
        model: torch module to export.
        general_options: basf2_mva GeneralOptions; note that its ``m_method``
            is overwritten with ``specific_options.getMethod()``.
        specific_options: method-specific options object (e.g. ONNXOptions).
        identifier: destination passed to ``Weightfile.save``.
    """
    # basf2 is imported only for its side effects (hence the noqa);
    # presumably it makes the Belle2 classes reachable through ROOT -- confirm.
    import basf2  # noqa
    import ROOT

    onnx_path = "model.onnx"

    print("convert to onnx")
    n_variables = len(general_options.m_variables)
    example_input = torch.zeros(1, n_variables)
    torch.onnx.export(
        model,
        (example_input,),
        onnx_path,
        input_names=["input"],
        output_names=["output"],
    )

    weightfile = ROOT.Belle2.MVA.Weightfile()
    # record which MVA method the weightfile belongs to before storing options
    general_options.m_method = specific_options.getMethod()
    for options in (general_options, specific_options):
        weightfile.addOptions(options)
    weightfile.addFile("ONNX_Modelfile", onnx_path)

    print(f"save to {identifier}")
    ROOT.Belle2.MVA.Weightfile.save(weightfile, identifier)
35
36
class TestWriteONNX(unittest.TestCase):
    """
    Tests for writing ONNX MVA weightfiles. In addition to testing the writing
    mechanism, these serve the purpose of creating test files for other unit tests.
    """

    #: show the complete diff when a written weightfile differs from its reference
    maxDiff = None

    def create_and_save(self, n_outputs, filename, weights):
        """
        Build a linear model, save it as an ONNX MVA weightfile and compare it to the reference.

        Parameters:
            n_outputs: number of model outputs (1 for the single-output case,
                2 for the multiclass case in this file).
            filename: name of the weightfile to write; the reference file with
                the same name is looked up in ``mva/methods/tests``.
            weights: state dict with ``weight`` and ``bias`` tensors loaded
                into the ``nn.Linear`` model.

        If the reference file does not exist yet, it is created from the newly
        written weightfile and the test fails. This means a missing reference
        can never make the test pass silently.
        """

        # like in mva/tests/all_classifiers.py
        variables = [
            "p",
            "pz",
            "daughter(0, p)",
            "daughter(0, pz)",
            "daughter(1, p)",
            "daughter(1, pz)",
            "chiProb",
            "dr",
            "dz",
            "daughter(0, dr)",
            "daughter(1, dr)",
            "daughter(0, chiProb)",
            "daughter(1, chiProb)",
            "daughter(0, kaonID)",
            "daughter(0, pionID)",
            "daughterAngle(0, 1)",
        ]
        model = nn.Linear(len(variables), n_outputs)
        model.load_state_dict(weights)
        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector("dummy")
        general_options.m_identifier = "Simple"
        general_options.m_treename = "tree"
        general_options.m_variables = basf2_mva.vector(*variables)
        specific_options = basf2_mva.ONNXOptions()
        # save_onnx writes model.onnx and the weightfile into the cwd, so work
        # in a temporary directory to keep the caller's directory clean
        with b2test_utils.clean_working_directory():
            save_onnx(model, general_options, specific_options, filename)
            with open(filename) as f:
                xml_new = f.read()
            ref_path = Path(basf2.find_file("mva/methods/tests")) / filename
            try:
                with open(ref_path) as f:
                    xml_ref = f.read()
            except FileNotFoundError:
                # if the file does not exist, recreate it, but still fail the test
                with open(ref_path, "w") as f:
                    f.write(xml_new)
                self.fail(f"Wrote new reference file {ref_path}")
            self.assertEqual(xml_new, xml_ref)

    def test_write_onnx_single_output(self):
        """
        Write example for single output
        """
        self.create_and_save(
            1,
            "ONNX.xml",
            {
                "weight": torch.tensor(
                    [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
                      0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
                ),
                "bias": torch.tensor([0.1927])
            }
        )

    def test_write_onnx_multiclass(self):
        """
        Write example for multiclass outputs
        """
        self.create_and_save(
            2,
            "ONNX_multiclass.xml",
            {
                "weight": torch.tensor(
                    [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
                      0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
                     [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
                      -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
                ),
                "bias": torch.tensor([0.1930, 0.0416])
            }
        )
123
if __name__ == "__main__":
    # When run as a script, let the unittest runner load and execute the
    # test cases defined in this module (module="__main__" is the default).
    unittest.main(module="__main__")
create_and_save(self, n_outputs, filename, weights)
clean_working_directory()
Definition __init__.py:194