Belle II Software development
simple.py
1#!/usr/bin/env python3
2
3
10
11import os
12from tempfile import TemporaryDirectory
13
14import basf2 # noqa
15import torch
16from torch import nn
17import numpy as np
18import uproot
19
20import ROOT
21
22
class Model(nn.Module):
    """
    Fully connected binary classifier.

    Two ReLU-activated hidden layers of 128 units each feed a single output
    unit, which is squashed by a sigmoid into the open interval (0, 1).
    """

    def __init__(self, number_of_features):
        """
        Parameters:
            number_of_features: size of the last dimension of the input tensor
        """
        super().__init__()

        width = 128
        layers = [
            nn.Linear(number_of_features, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
            nn.Sigmoid(),
        ]
        #: dense network: two hidden layers (128 units, ReLU) and a sigmoid output
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """
        Apply the network to ``x``.

        The last dimension of the result has size 1 and holds the sigmoid output.
        """
        return self.network(x)
51
52
def fit(model, filename, treename, variables, target_variable, *, epochs=50, batch_size=256):
    """
    Train ``model`` with binary cross-entropy on the events of a ROOT tree.

    Parameters:
        model: torch module whose output is compared against the target with
            ``binary_cross_entropy`` (so it must produce values in [0, 1] with
            shape (batch, 1))
        filename: path of the ROOT file holding the training sample
        treename: name of the tree inside ``filename``
        variables: basf2 variable names used as input features; each is passed
            through ``MakeROOTCompatible.makeROOTCompatible`` to obtain the
            branch name that is read
        target_variable: name of the branch used as training target
        epochs: number of passes over the training sample (default 50)
        batch_size: mini-batch size of the shuffled data loader (default 256)
    """
    # Materialise the branch names once instead of handing uproot a one-shot
    # ``map`` iterator.
    branches = [
        ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(variable)
        for variable in variables
    ]
    with uproot.open({filename: treename}) as tree:
        X = tree.arrays(branches, library="pd").to_numpy()
        y = tree[target_variable].array(library="np")

    dataset = torch.utils.data.TensorDataset(
        torch.tensor(X, dtype=torch.float32),
        # Add a trailing axis so targets have shape (batch, 1), matching the
        # shape binary_cross_entropy receives from the model.
        torch.tensor(y, dtype=torch.float32)[:, np.newaxis],
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters())

    # Make sure layers such as dropout/batch-norm (if any) are in training mode.
    model.train()
    for epoch in range(epochs):
        print(f"Epoch {epoch}", end=", ")
        losses = []
        for bx, by in loader:
            optimizer.zero_grad()
            loss = torch.nn.functional.binary_cross_entropy(model(bx), by)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        # An empty loader would make np.mean warn; report NaN explicitly instead.
        mean_loss = np.mean(losses) if losses else float("nan")
        # "\r" lets the next epoch overwrite this progress line.
        print(f"Loss = {mean_loss}", end="\r")
    print()
78
79
def save_onnx_to_database(model, general_options, specific_options, identifier):
    """
    Export ``model`` to ONNX and store it as an MVA weightfile in the database.

    Parameters:
        model: trained torch module to export
        general_options: basf2 MVA general options; ``m_variables`` fixes the
            input width and ``m_method`` is overwritten with the method name
            taken from ``specific_options``
        specific_options: method-specific options (e.g. ``ONNXOptions``)
        identifier: database identifier under which the weightfile is saved
    """
    with TemporaryDirectory() as tempdir:
        cwd = os.getcwd()
        # Work inside the temporary directory so "model.onnx" is created there
        # and cleaned up together with it.
        os.chdir(tempdir)
        try:
            print("convert to onnx")
            # The dummy input only drives tracing. Without ``dynamic_axes`` the
            # exported graph keeps exactly this input shape: (1, n_variables).
            torch.onnx.export(
                model,
                (torch.randn(1, len(general_options.m_variables)),),
                "model.onnx",
                input_names=["input"],
                output_names=["output"],
            )
            wf = ROOT.Belle2.MVA.Weightfile()
            general_options.m_method = specific_options.getMethod()
            wf.addOptions(general_options)
            wf.addOptions(specific_options)
            wf.addFile("ONNX_Modelfile", "model.onnx")
        finally:
            # Always restore the caller's working directory. Otherwise a failed
            # export would leave the process inside a directory that
            # TemporaryDirectory is about to delete.
            os.chdir(cwd)
        print("save to database")
        ROOT.Belle2.MVA.Weightfile.saveToDatabase(wf, identifier)
100
101
if __name__ == "__main__":
    import time

    import basf2_mva
    import basf2_mva_util
    from basf2 import conditions
    from basf2 import find_file

    # Local testing-payload database; presumably where saveToDatabase below
    # writes the weightfile that the expert reads back — confirm.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    test_file = find_file("mva/test_D0toKpipi.root", "examples")

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(train_file)
    general_options.m_identifier = "Simple"
    general_options.m_treename = "tree"
    variables = ['M', 'p', 'pt', 'pz',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
                 'chiProb', 'dr', 'dz',
                 'daughter(0, dr)', 'daughter(1, dr)',
                 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)',
                 'daughterInvM(0, 1)', 'daughterInvM(0, 2)', 'daughterInvM(1, 2)']
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "isSignal"

    specific_options = basf2_mva.ONNXOptions()

    # Train in plain PyTorch, then hand the exported ONNX model to basf2.
    model = Model(len(variables))
    fit(
        model,
        train_file,
        general_options.m_treename,
        variables,
        general_options.m_target_variable,
    )
    save_onnx_to_database(model, general_options, specific_options, "ONNXTest")

    # Apply the stored expert to the test sample and time the inference.
    method = basf2_mva_util.Method("ONNXTest")
    inference_start = time.perf_counter()
    test_data = [test_file]
    p, t = method.apply_expert(basf2_mva.vector(*test_data), general_options.m_treename)
    inference_stop = time.perf_counter()
    inference_time = inference_stop - inference_start
    # ``auc`` was previously printed without ever being assigned (NameError).
    auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
    print("ONNX", inference_time, auc)
calculate_auc_efficiency_vs_background_retention(p, t, w=None)
STL class.
STL class.
network
a dense model with two hidden layers
Definition simple.py:36
__init__(self, number_of_features)
Definition simple.py:28
forward(self, x)
Definition simple.py:45