Belle II Software prerelease-10-00-00a
TestWriteONNX Class Reference
Inheritance diagram for TestWriteONNX:
Collaboration diagram for TestWriteONNX:

Public Member Functions

 create_and_save (self, n_outputs, filename, weights)
 
 test_singleclass (self)
 
 test_multiclass (self)
 

Detailed Description

Tests for writing ONNX MVA weightfiles. In addition to testing the writing
mechanism, these serve the purpose of creating test files for other unit tests.

Definition at line 37 of file test_write_onnx.py.

Member Function Documentation

◆ create_and_save()

create_and_save ( self,
n_outputs,
filename,
weights )
Set up some example options for an ONNX MVA weightfile, save it, and compare it to the reference file.

Definition at line 42 of file test_write_onnx.py.

42 def create_and_save(self, n_outputs, filename, weights):
43 """
44 Setup some example options for an ONNX MVA weightfile, save it and compare to reference
45 """
46
47 # like in mva/tests/all_classifiers.py
48 variables = [
49 "p",
50 "pz",
51 "daughter(0, p)",
52 "daughter(0, pz)",
53 "daughter(1, p)",
54 "daughter(1, pz)",
55 "chiProb",
56 "dr",
57 "dz",
58 "daughter(0, dr)",
59 "daughter(1, dr)",
60 "daughter(0, chiProb)",
61 "daughter(1, chiProb)",
62 "daughter(0, kaonID)",
63 "daughter(0, pionID)",
64 "daughterAngle(0, 1)",
65 ]
66 model = nn.Linear(len(variables), n_outputs)
67 model.load_state_dict(weights)
68 general_options = basf2_mva.GeneralOptions()
69 general_options.m_datafiles = basf2_mva.vector("dummy")
70 general_options.m_identifier = "Simple"
71 general_options.m_treename = "tree"
72 general_options.m_variables = basf2_mva.vector(*variables)
73 specific_options = basf2_mva.ONNXOptions()
75 save_onnx(model, general_options, specific_options, filename)
76 with open(filename) as f:
77 xml_new = f.read()
78 ref_path = Path(basf2.find_file("mva/methods/tests")) / filename
79 try:
80 with open(ref_path) as f:
81 xml_ref = f.read()
82 except FileNotFoundError:
83 # if the file does not exist, recreate it, but still fail the test
84 with open(ref_path, "w") as f:
85 f.write(xml_new)
86 raise Exception(f"Wrote new reference file {str(ref_path)}")
87 self.assertEqual(xml_new, xml_ref)
88
clean_working_directory(): Definition at line 194 of file __init__.py.

◆ test_multiclass()

test_multiclass ( self)
Write example for multiclass outputs

Definition at line 105 of file test_write_onnx.py.

105 def test_multiclass(self):
106 """
107 Write example for multiclass outputs
108 """
109 self.create_and_save(
110 2,
111 "ONNX_multiclass.xml",
112 {
113 "weight": torch.tensor(
114 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
115 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
116 [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
117 -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
118 ),
119 "bias": torch.tensor([0.1930, 0.0416])
120 }
121 )
122
123

◆ test_singleclass()

test_singleclass ( self)
Write example for single output

Definition at line 89 of file test_write_onnx.py.

89 def test_singleclass(self):
90 """
91 Write example for single output
92 """
93 self.create_and_save(
94 1,
95 "ONNX.xml",
96 {
97 "weight": torch.tensor(
98 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
99 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
100 ),
101 "bias": torch.tensor([0.1927])
102 }
103 )
104

The documentation for this class was generated from the following file: