43 Setup some example options for an ONNX MVA weightfile, save it and compare to reference
59 "daughter(0, chiProb)",
60 "daughter(1, chiProb)",
61 "daughter(0, kaonID)",
62 "daughter(0, pionID)",
63 "daughterAngle(0, 1)",
65 model = nn.Linear(len(variables), n_outputs)
66 model.load_state_dict(weights)
81 with open(filename)
as f:
83 ref_path = Path(basf2.find_file(
"mva/methods/tests")) / filename
85 with open(ref_path)
as f:
87 except FileNotFoundError:
93 with open(ref_path,
"w")
as f:
95 raise Exception(f
"Wrote new reference file {str(ref_path)}")
96 self.assertEqual(xml_new, xml_ref)
100 Write example for single output
106 "weight": torch.tensor(
107 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
108 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
110 "bias": torch.tensor([0.1927])
116 Write example for multiclass outputs
120 "ONNX_multiclass.xml",
122 "weight": torch.tensor(
123 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
124 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
125 [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
126 -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
128 "bias": torch.tensor([0.1930, 0.0416])
134 Write example for 3-class outputs
138 "ONNX_multiclass_3.xml",
140 "weight": torch.tensor(
141 [[-0.1648, 0.2103, 0.0204, -0.1267, -0.0719, -0.2464, -0.1342, -0.0418,
142 -0.0362, -0.0801, 0.0587, -0.1121, -0.1560, -0.1602, 0.1597, 0.1568],
143 [-0.2297, -0.1780, -0.0301, -0.2094, -0.1600, 0.1508, 0.1964, 0.1261,
144 -0.0792, 0.0605, -0.0064, 0.0450, 0.0671, -0.2036, 0.0768, 0.0442],
145 [-0.1490, -0.2286, 0.2232, -0.1404, 0.2207, -0.0696, -0.2392, 0.1917,
146 0.0795, -0.1850, 0.0989, -0.0802, 0.0483, 0.0772, 0.1347, -0.1316]]
148 "bias": torch.tensor([-0.1484, 0.1209, -0.0164])