Belle II Software prerelease-10-00-00a
TestPyTorch Class Reference
Inheritance diagram for TestPyTorch:
Collaboration diagram for TestPyTorch:

Public Member Functions

 test_load_and_apply_existing_torch (self)
 
 test_write_new_torch (self)
 

Detailed Description

Tests for the Python MVA method using the PyTorch (`torch`) framework.

Definition at line 26 of file test_pytorch.py.

Member Function Documentation

◆ test_load_and_apply_existing_torch()

test_load_and_apply_existing_torch ( self)
Test whether a legacy (pickled) torch model can be loaded and applied, using KLMMuonIDDNNWeightFile as an example.

Definition at line 30 of file test_pytorch.py.

30 def test_load_and_apply_existing_torch(self):
31 """
32 Test if we can load legacy (pickled) torch model using a KLMMuonIDDNNWeightFile as example
33 """
34 method = basf2_mva_util.Method(basf2.find_file("mva/methods/tests/KLMMuonIDDNNWeightFile.xml"))
35 write_dummy_file(list(method.general_options.m_variables))
36 out1, out2 = method.apply_expert(
37 method.general_options.m_datafiles, method.general_options.m_treename
38 )
39 self.assertEqual(out1.shape, (10,))
40 self.assertEqual(out2.shape, (10,))
41

◆ test_write_new_torch()

test_write_new_torch ( self)
Test whether writing new files (weights only), reloading them, and applying the reloaded model works.

Definition at line 42 of file test_pytorch.py.

42 def test_write_new_torch(self):
43 """
44 Test if writing new files (weights only) and reloading them works
45 """
46 variables = ["var1", "var2"]
47 general_options = basf2_mva.GeneralOptions()
48 general_options.m_datafiles = basf2_mva.vector("dummy.root")
49 general_options.m_identifier = "Simple.xml"
50 general_options.m_treename = "tree"
51 general_options.m_variables = basf2_mva.vector(*variables)
52 general_options.m_target_variable = "target"
53
54 write_dummy_file(variables, size=100, target_variable=general_options.m_target_variable)
55
56 specific_options = basf2_mva.PythonOptions()
57 specific_options.m_framework = "torch"
58 specific_options.m_steering_file = "dummy.py"
59 specific_options.m_nIterations = 5
60 specific_options.m_mini_batch_size = 8
61 specific_options.m_config = json.dumps({"learning_rate": 1e-2})
62 specific_options.m_training_fraction = 0.8
63 specific_options.m_normalise = False
64
65 with open("dummy.py", "w") as f:
66 f.write(
67 dedent(
68 """
69 import torch
70 from torch import nn
71
72 class Model(nn.Module):
73 def __init__(self, number_of_features):
74 super().__init__()
75 self.linear = nn.Linear(number_of_features, 1)
76
77 def forward(self, x):
78 return self.linear(x).sigmoid()
79
80
81 def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
82 state = State(Model(number_of_features).to("cpu"), number_of_features=number_of_features)
83 state.optimizer = torch.optim.SGD(state.model.parameters())
84 state.loss_fn = nn.BCELoss
85 state.epoch = 0
86 state.avg_costs = []
87 return state
88 """
89 )
90 )
91
92 basf2_mva.teacher(general_options, specific_options)
93
94 method = basf2_mva_util.Method(general_options.m_identifier)
95 out1, out2 = method.apply_expert(
96 method.general_options.m_datafiles, method.general_options.m_treename
97 )
98 self.assertEqual(out1.shape, (100,))
99 self.assertEqual(out2.shape, (100,))
100
101

The documentation for this class was generated from the following file: