Belle II Software prerelease-10-00-00a
test_pytorch.py
#!/usr/bin/env python3

from textwrap import dedent
import json
import unittest

import basf2
import basf2_mva
import basf2_mva_util

import numpy as np
import pandas as pd
import uproot

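# Helper that writes 'dummy.root' with normally distributed feature columns and a
# boolean target branch, so the tests do not depend on any external input data.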
def write_dummy_file(variables, size=10, target_variable="mcPDG"):
    data = np.random.normal(size=[size, len(variables) + 1])
    tree = {}
    for i, name in enumerate(variables):
        tree[name] = data[:, i]
    tree[target_variable] = data[:, -1] > 0.5
    with uproot.recreate('dummy.root') as outfile:
        outfile['tree'] = pd.DataFrame(tree, dtype=np.float64)


class TestPyTorch(unittest.TestCase):
    """
    Tests for the torch python mva method
    """

    def test_load_legacy_weightfile(self):  # (test name assumed)
        """
        Test if we can load a legacy (pickled) torch model, using KLMMuonIDDNNWeightFile as an example
        """
        method = basf2_mva_util.Method(basf2.find_file("mva/methods/tests/KLMMuonIDDNNWeightFile.xml"))
        write_dummy_file(list(method.general_options.m_variables))
        out1, out2 = method.apply_expert(
            method.general_options.m_datafiles, method.general_options.m_treename
        )
        self.assertEqual(out1.shape, (10,))
        self.assertEqual(out2.shape, (10,))

43 """
44 Test if writing new files (weights only) and reloading them works
45 """
        variables = ["var1", "var2"]
        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector("dummy.root")
        general_options.m_identifier = "Simple.xml"
        general_options.m_treename = "tree"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "target"

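        # Create 100 dummy events containing the two input variables and the 'target' branch.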
        write_dummy_file(variables, size=100, target_variable=general_options.m_target_variable)

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "torch"
        specific_options.m_steering_file = "dummy.py"
        specific_options.m_nIterations = 5
        specific_options.m_mini_batch_size = 8
        specific_options.m_config = json.dumps({"learning_rate": 1e-2})
        specific_options.m_training_fraction = 0.8
        specific_options.m_normalise = False

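        # Write a minimal steering file; the torch MVA framework loads it and calls
        # get_model() to obtain the State holding model, optimizer and loss function.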
        with open("dummy.py", "w") as f:
            f.write(
                dedent(
                    """
                    import torch
                    from torch import nn

                    class Model(nn.Module):
                        def __init__(self, number_of_features):
                            super().__init__()
                            self.linear = nn.Linear(number_of_features, 1)

                        def forward(self, x):
                            return self.linear(x).sigmoid()


                    def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
                        # 'State' is injected by the basf2 torch MVA framework when it loads this steering file
                        state = State(Model(number_of_features).to("cpu"), number_of_features=number_of_features)
                        state.optimizer = torch.optim.SGD(state.model.parameters())
                        state.loss_fn = nn.BCELoss
                        state.epoch = 0
                        state.avg_costs = []
                        return state
                    """
                )
            )

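        # Train with the torch framework; this produces the weightfile 'Simple.xml'
        # in the new weights-only (non-pickled) format.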
        basf2_mva.teacher(general_options, specific_options)

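        # Reload the freshly trained weightfile and apply it to the same dummy data.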
        method = basf2_mva_util.Method(general_options.m_identifier)
        out1, out2 = method.apply_expert(
            method.general_options.m_datafiles, method.general_options.m_treename
        )
        self.assertEqual(out1.shape, (100,))
        self.assertEqual(out2.shape, (100,))


if __name__ == '__main__':
    import b2test_utils
    with b2test_utils.clean_working_directory():
        unittest.main()