def test_write_new_torch(self):
    """
    Test if writing new files (weights only) and reloading them works.

    Steps:
      1. Write a dummy ROOT file with 100 events, two input variables and
         a target column.
      2. Write a minimal torch steering file (``dummy.py``) defining a
         single-layer logistic model.
      3. Train it through ``basf2_mva.teacher`` with the ``torch`` Python
         framework, producing the weightfile ``Simple.xml``.
      4. Reload the weightfile with ``basf2_mva_util.Method``, apply the
         expert to the same data file, and check that one output value
         per event is returned.
    """
    # Local import so that constructing the Method below never depends on a
    # file-level import being present; harmless if it is already imported.
    import basf2_mva_util

    variables = ["var1", "var2"]
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector("dummy.root")
    general_options.m_identifier = "Simple.xml"
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "target"

    write_dummy_file(variables, size=100, target_variable=general_options.m_target_variable)

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "torch"
    specific_options.m_steering_file = "dummy.py"
    specific_options.m_nIterations = 5
    specific_options.m_mini_batch_size = 8
    specific_options.m_config = json.dumps({"learning_rate": 1e-2})
    specific_options.m_training_fraction = 0.8
    specific_options.m_normalise = False

    # Steering file executed by the torch framework. The optimizer takes its
    # learning rate from the ``parameters`` argument of ``get_model`` (here
    # the "learning_rate" key of m_config above). Without an explicit ``lr``
    # that config value would be ignored, and torch.optim.SGD rejects the
    # call on PyTorch versions where ``lr`` is a required argument.
    # ``state.loss_fn`` is stored as the BCELoss *class*, not an instance;
    # presumably the framework instantiates it per batch — verify against
    # the torch framework module before changing it.
    # ``State`` is not imported in the steering code; it is assumed to be
    # provided by the namespace the framework executes the file in.
    with open("dummy.py", "w", encoding="utf-8") as f:
        f.write(
            dedent(
                """
                import torch
                from torch import nn

                class Model(nn.Module):
                    def __init__(self, number_of_features):
                        super().__init__()
                        self.linear = nn.Linear(number_of_features, 1)

                    def forward(self, x):
                        return self.linear(x).sigmoid()


                def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
                    state = State(Model(number_of_features).to("cpu"), number_of_features=number_of_features)
                    state.optimizer = torch.optim.SGD(
                        state.model.parameters(), lr=parameters.get("learning_rate", 1e-3)
                    )
                    state.loss_fn = nn.BCELoss
                    state.epoch = 0
                    state.avg_costs = []
                    return state
                """
            )
        )

    basf2_mva.teacher(general_options, specific_options)

    # Reload the freshly written weightfile; previously ``method`` was used
    # below without ever being assigned, so the test raised NameError here.
    method = basf2_mva_util.Method(general_options.m_identifier)
    out1, out2 = method.apply_expert(
        method.general_options.m_datafiles, method.general_options.m_treename
    )
    # One value per event in the 100-event dummy file for both outputs.
    self.assertEqual(out1.shape, (100,))
    self.assertEqual(out2.shape, (100,))
102
103