Belle II Software light-2509-fornax
TestWriteONNX Class Reference
Inheritance diagram for TestWriteONNX:
Collaboration diagram for TestWriteONNX:

Public Member Functions

 create_and_save (self, n_outputs, filename, weights)
 
 test_singleclass (self)
 
 test_multiclass (self)
 
 test_multiclass_3 (self)
 

Static Public Attributes

 maxDiff = None
Show the full (untruncated) diff when the generated and reference XML files do not match.
 

Detailed Description

Tests for writing ONNX MVA weightfiles. In addition to testing the writing
mechanism, these serve the purpose of creating test files for other unit tests.

Definition at line 32 of file test_write_onnx.py.

Member Function Documentation

◆ create_and_save()

create_and_save ( self,
n_outputs,
filename,
weights )
Set up example options for an ONNX MVA weightfile, save it, and compare it to the reference file.

Definition at line 41 of file test_write_onnx.py.

41 def create_and_save(self, n_outputs, filename, weights):
42 """
43 Setup some example options for an ONNX MVA weightfile, save it and compare to reference
44 """
45
46 # like in mva/tests/all_classifiers.py
47 variables = [
48 "p",
49 "pz",
50 "daughter(0, p)",
51 "daughter(0, pz)",
52 "daughter(1, p)",
53 "daughter(1, pz)",
54 "chiProb",
55 "dr",
56 "dz",
57 "daughter(0, dr)",
58 "daughter(1, dr)",
59 "daughter(0, chiProb)",
60 "daughter(1, chiProb)",
61 "daughter(0, kaonID)",
62 "daughter(0, pionID)",
63 "daughterAngle(0, 1)",
64 ]
65 model = nn.Linear(len(variables), n_outputs)
66 model.load_state_dict(weights)
67 if n_outputs > 1:
68 nClasses = n_outputs
69 else:
70 nClasses = 2
72 save_onnx(
73 model,
74 filename,
75 variables=variables,
76 datafiles=["dummy"],
77 identifier="Simple",
78 treename="tree",
79 nClasses=nClasses,
80 )
81 with open(filename) as f:
82 xml_new = f.read()
83 ref_path = Path(basf2.find_file("mva/methods/tests")) / filename
84 try:
85 with open(ref_path) as f:
86 xml_ref = f.read()
87 except FileNotFoundError:
88 # if the file does not exist, recreate it, but still fail the test
89 #
90 # This has to be done when new options are added to ONNXOptions and
91 # therefore the xml changes. In this case, just delete the xmls and
92 # rerun the test to generate new reference files.
93 with open(ref_path, "w") as f:
94 f.write(xml_new)
95 raise Exception(f"Wrote new reference file {str(ref_path)}")
96 self.assertEqual(xml_new, xml_ref)
97
Referenced function: clean_working_directory() (defined in __init__.py, line 198).

◆ test_multiclass()

test_multiclass ( self)
Write example for multiclass outputs

Definition at line 114 of file test_write_onnx.py.

114 def test_multiclass(self):
115 """
116 Write example for multiclass outputs
117 """
118 self.create_and_save(
119 2,
120 "ONNX_multiclass.xml",
121 {
122 "weight": torch.tensor(
123 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
124 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
125 [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
126 -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
127 ),
128 "bias": torch.tensor([0.1930, 0.0416])
129 }
130 )
131

◆ test_multiclass_3()

test_multiclass_3 ( self)
Write example for 3-class outputs

Definition at line 132 of file test_write_onnx.py.

132 def test_multiclass_3(self):
133 """
134 Write example for 3-class outputs
135 """
136 self.create_and_save(
137 3,
138 "ONNX_multiclass_3.xml",
139 {
140 "weight": torch.tensor(
141 [[-0.1648, 0.2103, 0.0204, -0.1267, -0.0719, -0.2464, -0.1342, -0.0418,
142 -0.0362, -0.0801, 0.0587, -0.1121, -0.1560, -0.1602, 0.1597, 0.1568],
143 [-0.2297, -0.1780, -0.0301, -0.2094, -0.1600, 0.1508, 0.1964, 0.1261,
144 -0.0792, 0.0605, -0.0064, 0.0450, 0.0671, -0.2036, 0.0768, 0.0442],
145 [-0.1490, -0.2286, 0.2232, -0.1404, 0.2207, -0.0696, -0.2392, 0.1917,
146 0.0795, -0.1850, 0.0989, -0.0802, 0.0483, 0.0772, 0.1347, -0.1316]]
147 ),
148 "bias": torch.tensor([-0.1484, 0.1209, -0.0164])
149 }
150 )
151
152

◆ test_singleclass()

test_singleclass ( self)
Write example for single output

Definition at line 98 of file test_write_onnx.py.

98 def test_singleclass(self):
99 """
100 Write example for single output
101 """
102 self.create_and_save(
103 1,
104 "ONNX.xml",
105 {
106 "weight": torch.tensor(
107 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
108 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
109 ),
110 "bias": torch.tensor([0.1927])
111 }
112 )
113

Member Data Documentation

◆ maxDiff

maxDiff = None
static

Show the full (untruncated) diff when the generated and reference XML files do not match.

Definition at line 39 of file test_write_onnx.py.


The documentation for this class was generated from the following file: test_write_onnx.py