Belle II Software development
TestWriteONNX Class Reference
Inheritance diagram for TestWriteONNX:

Public Member Functions

 create_and_save (self, n_outputs, filename, weights)
 
 test_singleclass (self)
 
 test_multiclass (self)
 
 test_multiclass_3 (self)
 

Detailed Description

Tests for writing ONNX MVA weightfiles. Besides exercising the writing
mechanism, these tests also produce the test files used by other unit tests.

Definition at line 36 of file test_write_onnx.py.

Member Function Documentation

◆ create_and_save()

create_and_save ( self,
n_outputs,
filename,
weights )
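Set up example options for an ONNX MVA weightfile, save the weightfile, and compare it to the reference file. If no reference file exists yet, the new file is written as the reference and the test fails.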

Definition at line 42 of file test_write_onnx.py.

42 def create_and_save(self, n_outputs, filename, weights):
43 """
44 Setup some example options for an ONNX MVA weightfile, save it and compare to reference
45 """
46
47 # like in mva/tests/all_classifiers.py
48 variables = [
49 "p",
50 "pz",
51 "daughter(0, p)",
52 "daughter(0, pz)",
53 "daughter(1, p)",
54 "daughter(1, pz)",
55 "chiProb",
56 "dr",
57 "dz",
58 "daughter(0, dr)",
59 "daughter(1, dr)",
60 "daughter(0, chiProb)",
61 "daughter(1, chiProb)",
62 "daughter(0, kaonID)",
63 "daughter(0, pionID)",
64 "daughterAngle(0, 1)",
65 ]
66 model = nn.Linear(len(variables), n_outputs)
67 model.load_state_dict(weights)
68 general_options = basf2_mva.GeneralOptions()
69 general_options.m_datafiles = basf2_mva.vector("dummy")
70 general_options.m_identifier = "Simple"
71 general_options.m_treename = "tree"
72 general_options.m_variables = basf2_mva.vector(*variables)
73 if n_outputs > 1:
74 general_options.m_nClasses = n_outputs
75 specific_options = basf2_mva.ONNXOptions()
77 save_onnx(model, general_options, specific_options, filename)
78 with open(filename) as f:
79 xml_new = f.read()
80 ref_path = Path(basf2.find_file("mva/methods/tests")) / filename
81 try:
82 with open(ref_path) as f:
83 xml_ref = f.read()
84 except FileNotFoundError:
85 # if the file does not exist, recreate it, but still fail the test
86 with open(ref_path, "w") as f:
87 f.write(xml_new)
88 raise Exception(f"Wrote new reference file {str(ref_path)}")
89 self.assertEqual(xml_new, xml_ref)
90
Referenced symbol `clean_working_directory()`: definition at line 198 of `__init__.py`.

◆ test_multiclass()

test_multiclass ( self)
Write an example weightfile for a multiclass model with two outputs.

Definition at line 107 of file test_write_onnx.py.

107 def test_multiclass(self):
108 """
109 Write example for multiclass outputs
110 """
111 self.create_and_save(
112 2,
113 "ONNX_multiclass.xml",
114 {
115 "weight": torch.tensor(
116 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
117 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
118 [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
119 -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
120 ),
121 "bias": torch.tensor([0.1930, 0.0416])
122 }
123 )
124

◆ test_multiclass_3()

test_multiclass_3 ( self)
Write an example weightfile for a multiclass model with three outputs.

Definition at line 125 of file test_write_onnx.py.

125 def test_multiclass_3(self):
126 """
127 Write example for 3-class outputs
128 """
129 self.create_and_save(
130 3,
131 "ONNX_multiclass_3.xml",
132 {
133 "weight": torch.tensor(
134 [[-0.1648, 0.2103, 0.0204, -0.1267, -0.0719, -0.2464, -0.1342, -0.0418,
135 -0.0362, -0.0801, 0.0587, -0.1121, -0.1560, -0.1602, 0.1597, 0.1568],
136 [-0.2297, -0.1780, -0.0301, -0.2094, -0.1600, 0.1508, 0.1964, 0.1261,
137 -0.0792, 0.0605, -0.0064, 0.0450, 0.0671, -0.2036, 0.0768, 0.0442],
138 [-0.1490, -0.2286, 0.2232, -0.1404, 0.2207, -0.0696, -0.2392, 0.1917,
139 0.0795, -0.1850, 0.0989, -0.0802, 0.0483, 0.0772, 0.1347, -0.1316]]
140 ),
141 "bias": torch.tensor([-0.1484, 0.1209, -0.0164])
142 }
143 )
144
145

◆ test_singleclass()

test_singleclass ( self)
Write an example weightfile for a model with a single output.

Definition at line 91 of file test_write_onnx.py.

91 def test_singleclass(self):
92 """
93 Write example for single output
94 """
95 self.create_and_save(
96 1,
97 "ONNX.xml",
98 {
99 "weight": torch.tensor(
100 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
101 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
102 ),
103 "bias": torch.tensor([0.1927])
104 }
105 )
106

The documentation for this class was generated from the following file: