44 Setup some example options for an ONNX MVA weightfile, save it and compare to reference
60 "daughter(0, chiProb)",
61 "daughter(1, chiProb)",
62 "daughter(0, kaonID)",
63 "daughter(0, pionID)",
64 "daughterAngle(0, 1)",
66 model = nn.Linear(len(variables), n_outputs)
67 model.load_state_dict(weights)
68 general_options = basf2_mva.GeneralOptions()
69 general_options.m_datafiles = basf2_mva.vector(
"dummy")
70 general_options.m_identifier =
"Simple"
71 general_options.m_treename =
"tree"
72 general_options.m_variables = basf2_mva.vector(*variables)
74 general_options.m_nClasses = n_outputs
75 specific_options = basf2_mva.ONNXOptions()
77 save_onnx(model, general_options, specific_options, filename)
78 with open(filename)
as f:
80 ref_path = Path(basf2.find_file(
"mva/methods/tests")) / filename
82 with open(ref_path)
as f:
84 except FileNotFoundError:
86 with open(ref_path,
"w")
as f:
88 raise Exception(f
"Wrote new reference file {str(ref_path)}")
89 self.assertEqual(xml_new, xml_ref)
93 Write example for single output
99 "weight": torch.tensor(
100 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
101 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
103 "bias": torch.tensor([0.1927])
109 Write example for multiclass outputs
113 "ONNX_multiclass.xml",
115 "weight": torch.tensor(
116 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
117 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
118 [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
119 -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
121 "bias": torch.tensor([0.1930, 0.0416])
127 Write example for 3-class outputs
131 "ONNX_multiclass_3.xml",
133 "weight": torch.tensor(
134 [[-0.1648, 0.2103, 0.0204, -0.1267, -0.0719, -0.2464, -0.1342, -0.0418,
135 -0.0362, -0.0801, 0.0587, -0.1121, -0.1560, -0.1602, 0.1597, 0.1568],
136 [-0.2297, -0.1780, -0.0301, -0.2094, -0.1600, 0.1508, 0.1964, 0.1261,
137 -0.0792, 0.0605, -0.0064, 0.0450, 0.0671, -0.2036, 0.0768, 0.0442],
138 [-0.1490, -0.2286, 0.2232, -0.1404, 0.2207, -0.0696, -0.2392, 0.1917,
139 0.0795, -0.1850, 0.0989, -0.0802, 0.0483, 0.0772, 0.1347, -0.1316]]
141 "bias": torch.tensor([-0.1484, 0.1209, -0.0164])