"""
Test if we can load legacy (pickled) torch model using a KLMMuonIDDNNWeightFile as example
"""
# Regenerate a small dummy input file containing exactly the variables
# the stored method was trained on.
write_dummy_file(list(method.general_options.m_variables))
# Apply the reloaded expert to the dummy data file; returns two arrays
# (predictions and targets), one entry per event.
out1, out2 = method.apply_expert(
    method.general_options.m_datafiles, method.general_options.m_treename
)
# NOTE(review): assumes write_dummy_file produces 10 events by default — confirm.
self.assertEqual(out1.shape, (10,))
self.assertEqual(out2.shape, (10,))
"""
Test if writing new files (weights only) and reloading them works
"""
variables = ["var1", "var2"]

# Minimal training configuration over the dummy data file.
general_options = basf2_mva.GeneralOptions()
general_options.m_datafiles = basf2_mva.vector("dummy.root")
general_options.m_identifier = "Simple.xml"
general_options.m_treename = "tree"
general_options.m_variables = basf2_mva.vector(*variables)
general_options.m_target_variable = "target"

write_dummy_file(variables, size=100, target_variable=general_options.m_target_variable)

# Python/torch backend options: short training, small batches, no normalisation.
specific_options = basf2_mva.PythonOptions()
specific_options.m_framework = "torch"
specific_options.m_steering_file = "dummy.py"
specific_options.m_nIterations = 5
specific_options.m_mini_batch_size = 8
specific_options.m_config = json.dumps({"learning_rate": 1e-2})
specific_options.m_training_fraction = 0.8
specific_options.m_normalise = False

# Write the steering file that the mva framework executes to build the model.
# NOTE(review): the import lines and the forward()/return-state glue inside this
# string were not visible in the garbled source and are reconstructed — confirm
# against the original test file.
with open("dummy.py", "w") as f:
    f.write("""
import torch
import torch.nn as nn
from basf2_mva_python_interface.torch import State


class Model(nn.Module):
    def __init__(self, number_of_features):
        super().__init__()
        self.linear = nn.Linear(number_of_features, 1)

    def forward(self, x):
        return self.linear(x).sigmoid()


def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    state = State(Model(number_of_features).to("cpu"), number_of_features=number_of_features)
    # NOTE(review): SGD is constructed without an explicit lr, so the configured
    # "learning_rate" in parameters is not applied here; older torch versions
    # require lr explicitly — confirm intent.
    state.optimizer = torch.optim.SGD(state.model.parameters())
    # The framework instantiates the loss later, so the class (not an instance)
    # is stored here.
    state.loss_fn = nn.BCELoss
    return state
""")

# Train; this writes the weight file under general_options.m_identifier.
basf2_mva.teacher(general_options, specific_options)

# NOTE(review): reconstructed line (not visible in the garbled source) — the
# expert wrapper is rebuilt from the freshly written identifier to prove the
# weights-only file can be reloaded.
method = basf2_mva_util.Method(general_options.m_identifier)
out1, out2 = method.apply_expert(
    method.general_options.m_datafiles, method.general_options.m_treename
)
# One prediction and one target per event in the 100-event dummy file.
self.assertEqual(out1.shape, (100,))
self.assertEqual(out2.shape, (100,))