# NOTE(review): extraction-damaged fragment of a test method. Each line is
# prefixed with its ORIGINAL file line number (32, 35, 36, ...); gaps in that
# numbering (33-34, 38) mean whole source lines were dropped by the extraction,
# including the method's `def` header, the docstring delimiters, and the
# closing `)` of the apply_expert(...) call. Code is kept byte-identical below;
# restore the full method from version control rather than editing this text.
32 Test if we can load legacy (pickled) torch model using a KLMMuonIDDNNWeightFile as example
# Last line of the method's docstring: exercises loading a legacy (pickled)
# torch weight file (KLMMuonIDDNNWeightFile) through the MVA interface.
35 write_dummy_file(list(method.general_options.m_variables))
# `method` is presumably constructed on one of the missing lines (33-34) —
# TODO confirm against the original file.
36 out1, out2 = method.apply_expert(
37 method.general_options.m_datafiles, method.general_options.m_treename
# NOTE(review): the `)` closing this call sat on missing original line 38.
39 self.assertEqual(out1.shape, (10,))
40 self.assertEqual(out2.shape, (10,))
# Expects expert output over the 10 rows written by write_dummy_file above.
# NOTE(review): extraction-damaged fragment of a second test method. The `def`
# header (original lines ~42-43) is outside this excerpt, indentation is lost,
# and single statements are split mid-assignment across physical lines
# (e.g. `m_identifier =` / `"Simple.xml"`) — syntactically invalid as shown.
# Code is kept byte-identical; annotations below only describe what is visible.
44 Test if writing new files (weights only) and reloading them works
# Docstring line: round-trips a freshly trained torch weight file.
46 variables = [
"var1",
"var2"]
# General MVA options: input file, output identifier, tree, variables, target.
47 general_options = basf2_mva.GeneralOptions()
48 general_options.m_datafiles = basf2_mva.vector(
"dummy.root")
49 general_options.m_identifier =
"Simple.xml"
50 general_options.m_treename =
"tree"
51 general_options.m_variables = basf2_mva.vector(*variables)
52 general_options.m_target_variable =
"target"
# 100 rows so the expert output shapes asserted at the end are (100,).
54 write_dummy_file(variables, size=100, target_variable=general_options.m_target_variable)
# Python-framework-specific options: torch backend driven by dummy.py below.
56 specific_options = basf2_mva.PythonOptions()
57 specific_options.m_framework =
"torch"
58 specific_options.m_steering_file =
"dummy.py"
59 specific_options.m_nIterations = 5
60 specific_options.m_mini_batch_size = 8
61 specific_options.m_config = json.dumps({
"learning_rate": 1e-2})
62 specific_options.m_training_fraction = 0.8
63 specific_options.m_normalise =
False
# Writes the steering file consumed via m_steering_file above. The
# `f.write("""...""")` wrapper around the model code sat on missing original
# lines (66-71), so the Model/get_model lines below are really the CONTENT of
# the generated dummy.py, not top-level code — TODO confirm against original.
65 with open(
"dummy.py",
"w")
as f:
72 class Model(nn.Module):
73 def __init__(self, number_of_features):
# Missing line 74 presumably held `super().__init__()` — verify.
75 self.linear = nn.Linear(number_of_features, 1)
# Single linear layer; forward (header on missing lines 76-77) squashes
# the logit through a sigmoid for BCE training.
78 return self.linear(x).sigmoid()
# Steering-file hook called by the torch MVA interface to build the State.
81 def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
82 state = State(Model(number_of_features).to("cpu"), number_of_features=number_of_features)
# NOTE(review): SGD is constructed without an explicit `lr`, yet m_config
# above carries {"learning_rate": 1e-2}. torch.optim.SGD required `lr`
# before torch 2.2 (it only gained a default of 1e-3 then) — confirm the
# intended torch version and whether parameters["learning_rate"] should be
# passed here.
83 state.optimizer = torch.optim.SGD(state.model.parameters())
# NOTE(review): binds the BCELoss *class*, not an instance — presumably the
# framework instantiates/calls it later; verify.
84 state.loss_fn = nn.BCELoss
92 basf2_mva.teacher(general_options, specific_options)
# `method` is presumably built from general_options.m_identifier on missing
# lines 93-94; the `)` closing apply_expert sat on missing line 97.
95 out1, out2 = method.apply_expert(
96 method.general_options.m_datafiles, method.general_options.m_treename
98 self.assertEqual(out1.shape, (100,))
99 self.assertEqual(out2.shape, (100,))