# Belle II Software
# test_pytorch.py
#!/usr/bin/env python3

from textwrap import dedent
import json
import unittest

import basf2
import basf2_mva
import basf2_mva_util

import numpy as np
import pandas as pd
import uproot
15
def write_dummy_file(variables, size=10, target_variable="mcPDG", filename="dummy.root"):
    """
    Write a small ROOT file filled with reproducible random data.

    Parameters:
        variables (list[str]): branch names to fill with standard-normal draws.
        size (int): number of entries (rows) to generate.
        target_variable (str): name of the boolean target branch.
        filename (str): output ROOT file name; defaults to 'dummy.root' so
            existing callers are unaffected.
    """
    # Fixed seed so every test run sees identical pseudo-data.
    rng = np.random.default_rng(42)
    # One extra column supplies the values from which the target is derived.
    data = rng.normal(size=[size, len(variables) + 1])
    tree = {name: data[:, i] for i, name in enumerate(variables)}
    # Boolean target derived from the spare column (True where draw > 0.5).
    tree[target_variable] = data[:, -1] > 0.5
    with uproot.recreate(filename) as outfile:
        # Everything (including the boolean target) is stored as float64.
        outfile['tree'] = pd.DataFrame(tree, dtype=np.float64)
26
27
class TestPyTorch(unittest.TestCase):
    """
    Tests for the torch python mva method
    """

    # NOTE(review): the two `def` lines below were missing from the garbled
    # source listing (only the docstrings survived); the method names are
    # reconstructed and should be confirmed against the original file.

    def test_load_legacy(self):
        """
        Test if we can load legacy (pickled) torch model using a KLMMuonIDDNNWeightFile as example
        """
        method = basf2_mva_util.Method(basf2.find_file("mva/methods/tests/KLMMuonIDDNNWeightFile.xml"))
        # The weightfile declares its input variables; generate matching dummy data.
        write_dummy_file(list(method.general_options.m_variables))
        out1, out2 = method.apply_expert(
            method.general_options.m_datafiles, method.general_options.m_treename
        )
        # write_dummy_file defaults to 10 entries -> one prediction per entry.
        self.assertEqual(out1.shape, (10,))
        self.assertEqual(out2.shape, (10,))

    def test_train_and_reload(self):
        """
        Test if writing new files (weights only) and reloading them works
        """
        variables = ["var1", "var2"]
        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector("dummy.root")
        general_options.m_identifier = "Simple.xml"
        general_options.m_treename = "tree"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "target"

        write_dummy_file(variables, size=100, target_variable=general_options.m_target_variable)

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "torch"
        specific_options.m_steering_file = "dummy.py"
        specific_options.m_nIterations = 5
        specific_options.m_mini_batch_size = 8
        specific_options.m_config = json.dumps({"learning_rate": 1e-2})
        specific_options.m_training_fraction = 0.8
        specific_options.m_normalise = False

        # Steering file defining the model.  `State` is presumably injected
        # into the steering namespace by the mva torch framework — confirm.
        # Bug fix: the configured learning_rate is now forwarded to the
        # optimizer via `parameters` (the parsed m_config); previously
        # torch.optim.SGD was called without the (historically required)
        # `lr` argument and the configured value was silently ignored.
        with open("dummy.py", "w") as f:
            f.write(
                dedent(
                    """
                    import torch
                    from torch import nn

                    class Model(nn.Module):
                        def __init__(self, number_of_features):
                            super().__init__()
                            self.linear = nn.Linear(number_of_features, 1)

                        def forward(self, x):
                            return self.linear(x).sigmoid()


                    def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
                        state = State(Model(number_of_features).to("cpu"), number_of_features=number_of_features)
                        state.optimizer = torch.optim.SGD(
                            state.model.parameters(), lr=parameters.get("learning_rate", 0.01)
                        )
                        state.loss_fn = nn.BCELoss
                        state.epoch = 0
                        state.avg_costs = []
                        return state
                    """
                )
            )

        basf2_mva.teacher(general_options, specific_options)

        # Reload the freshly trained weightfile and check the expert
        # produces one output per input entry.
        method = basf2_mva_util.Method(general_options.m_identifier)
        out1, out2 = method.apply_expert(
            method.general_options.m_datafiles, method.general_options.m_treename
        )
        self.assertEqual(out1.shape, (100,))
        self.assertEqual(out2.shape, (100,))
102
103
if __name__ == '__main__':
    import b2test_utils

    # Run inside a temporary working directory so the files written by the
    # tests ('dummy.root', 'dummy.py', 'Simple.xml') do not pollute the
    # source tree.  NOTE(review): this `with` line was dropped from the
    # garbled listing; it is reconstructed from the stray
    # `clean_working_directory()` residue that survived at the end of the
    # file — confirm against the original.
    with b2test_utils.clean_working_directory():
        unittest.main()