44 Setup some example options for an ONNX MVA weightfile, save it and compare to reference
60 "daughter(0, chiProb)",
61 "daughter(1, chiProb)",
62 "daughter(0, kaonID)",
63 "daughter(0, pionID)",
64 "daughterAngle(0, 1)",
66 model = nn.Linear(len(variables), n_outputs)
67 model.load_state_dict(weights)
68 general_options = basf2_mva.GeneralOptions()
69 general_options.m_datafiles = basf2_mva.vector(
"dummy")
70 general_options.m_identifier =
"Simple"
71 general_options.m_treename =
"tree"
72 general_options.m_variables = basf2_mva.vector(*variables)
73 specific_options = basf2_mva.ONNXOptions()
75 save_onnx(model, general_options, specific_options, filename)
76 with open(filename)
as f:
78 ref_path = Path(basf2.find_file(
"mva/methods/tests")) / filename
80 with open(ref_path)
as f:
82 except FileNotFoundError:
84 with open(ref_path,
"w")
as f:
86 raise Exception(f
"Wrote new reference file {str(ref_path)}")
87 self.assertEqual(xml_new, xml_ref)
91 Write example for single output
97 "weight": torch.tensor(
98 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
99 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353]]
101 "bias": torch.tensor([0.1927])
107 Write example for multiclass outputs
111 "ONNX_multiclass.xml",
113 "weight": torch.tensor(
114 [[0.1911, 0.2075, -0.0586, 0.2297, -0.0548, 0.0504, -0.1217, 0.1468,
115 0.2204, -0.1834, 0.2173, 0.0468, 0.1847, 0.0339, 0.1205, -0.0353],
116 [0.1927, 0.0370, -0.1167, 0.0637, -0.1152, -0.0293, -0.1015, 0.1658,
117 -0.1973, -0.1153, -0.0706, -0.1503, 0.0236, -0.2469, 0.2258, -0.2124]]
119 "bias": torch.tensor([0.1930, 0.0416])