Belle II Software development
SmartBKGTrain.py
import torch
import awkward as ak
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader

from smartBKG.models.gatgap import GATGAPModel
from smartBKG.utils.dataset import ArrayDataset
from smartBKG import MODEL_CONFIG
from smartBKG.utils.metrics import speedup

file_path = Path("./")
device = torch.device("cpu")
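# the whole script runs on CPU by default; switch to torch.device("cuda") to use a GPU if available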


def load_events(
    filenames,
    balanced=True,
    max_events=None,
):
    """
    Load events from Parquet files.

    Args:
        filenames (list): List of file paths.
        balanced (bool): Whether to balance the number of pass and fail events.
            Should be set to `True` for the training set and `False` for the test set.
        max_events (int): Maximum number of events to load.

    Returns:
        tuple: Tuple containing the loaded arrays and meta information.
    """
    arrays = []
    info = {
        "total_udst": 0,
        "total_udst_fail": 0,
        "total_loaded_pass": 0,
        "total_loaded_fail": 0
    }
    for filename in filenames:
        ar = ak.from_parquet(filename)
        n_udst_pass = ak.sum(ar.label)
        n_udst_fail = ak.sum(~ar.label)
        n_total = n_udst_pass + n_udst_fail
        info["total_udst"] += n_total
        info["total_udst_fail"] += n_udst_fail

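        # keep all pass events and at most an equal number of fail events,
        # so the loaded training sample is roughly class-balanced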
        if balanced:
            ar = ak.packed(
                ak.concatenate(
                    [
                        ar[ar.label],
                        ar[~ar.label][:n_udst_pass]
                    ]
                )
            )
        info["total_loaded_pass"] += n_udst_pass
        info["total_loaded_fail"] += len(ar) - n_udst_pass

        arrays.append(ar)
        if max_events and len(ar) >= max_events:
            break
    return ak.with_field(ak.partitioned(arrays), True, "is_udst"), info


def get_loss(pred, label, retention_rate):
    """
    Compute the loss with the retention rate applied. Can be extended for other losses.

    Args:
        pred (torch.Tensor): Predicted logits from the model.
        label (torch.Tensor): True labels (ground truth).
        retention_rate (float): The rate at which events are retained by the filter.

    Returns:
        torch.Tensor: The speedup loss achieved by the filtering method.
    """
    return speedup(
        filter_prob=pred,
        y_true=label,
        retention_rate=retention_rate
    )


def fit(model, name, ds_train, ds_val, retention_rate, device=device,
        min_epochs=1, patience=12, lr_start=1e-3, lr_end=1e-4, epochs=1000):
    """
    Train the model with a dynamic learning rate.

    Args:
        model (torch.nn.Module): The model used for the training.
        name (str): File name under which to save the best model.
        ds_train (ArrayDataset): Dataset created with ArrayDataset for training.
        ds_val (ArrayDataset): Dataset created with ArrayDataset for validation.
        retention_rate (float): The rate at which events are retained by the filter.
        device (torch.device): The device (`cpu` or `cuda`) holding the dataset and model during processing.
        min_epochs (int): The minimal number of epochs for the training.
        patience (int): The maximal number of consecutive epochs allowed without improvement.
        lr_start (float): The learning rate at the beginning of the training.
        lr_end (float): The minimal learning rate for the training. The dynamic learning rate strategy
            is hard-coded with the ReduceLROnPlateau scheduler and a factor of 1/2; this can be changed
            manually according to the results of hyper-parameter fine-tuning.
        epochs (int): The maximal number of epochs for the training.

    Returns:
        dict: The histories of loss and accuracy for training and validation.
    """
    history = {"loss": [], "val_loss": [], "acc": [], "val_acc": []}
    lr = lr_start
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=1/2, patience=patience, min_lr=lr_end)
    nan_count = 0
    for epoch in range(epochs):
        updated = False
        current_lr = [par["lr"] for par in opt.param_groups][0]
        # training
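        # ArrayDataset builds the batches itself (batch_size is set on the dataset),
        # so the collate function only unwraps the single-element list from the DataLoader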
        dl_train = DataLoader(ds_train, collate_fn=lambda x: x[0], num_workers=0)
        train_losses = []
        train_acces = []
        for x, y in dl_train:
            logits = model(x.to(device))
            labels = y.to(device)
            loss = get_loss(logits, labels, retention_rate)
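            # a logit above 0 corresponds to a sigmoid probability above 0.5,
            # so thresholding the logits at 0 gives the predicted class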
            acc = ((logits > 0) == labels).float().mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            train_losses.append(float(loss.detach().cpu().numpy()))
            train_acces.append(float(acc.detach().cpu().numpy()))
        # change the seed so that we get newly shuffled batches in the next iteration
        ds_train.seed += 1

        # validation
        with torch.no_grad():
            dl_val = DataLoader(ds_val, collate_fn=lambda x: x[0], num_workers=0)
            val_losses = []
            val_acces = []
            for x, y in dl_val:
                logits = model(x.to(device))
                labels = y.to(device)
                loss = get_loss(logits, labels, retention_rate)
                acc = ((logits > 0) == labels).float().mean()
                val_losses.append(float(loss.detach().cpu().numpy()))
                val_acces.append(float(acc.detach().cpu().numpy()))

        history["val_loss"].append(np.mean(val_losses))
        history["loss"].append(np.mean(train_losses))
        history["val_acc"].append(np.mean(val_acces))
        history["acc"].append(np.mean(train_acces))

        min_idx = np.argmin(history["val_loss"])
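        # if the loss turned NaN, reload the best checkpoint, step the scheduler and
        # retry the epoch; abort after int(patience/2) NaN occurrences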
        if torch.isnan(loss.detach()):
            model.load_state_dict(torch.load(name)['model_state_dict'])
            nan_count += 1
            if nan_count == int(patience/2):
                break
            else:
                scheduler.step(history["val_loss"][-1])
                print(f'NaN detected, retry {nan_count}/{int(patience/2)}, lr {current_lr}')
                continue
        if min_idx == (len(history["val_loss"]) - 1):
            updated = True
            torch.save(
                {'model_state_dict': model.state_dict()},
                name
            )
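        # step the scheduler on the validation loss; if it just halved the learning
        # rate, continue the next epoch from the best checkpoint saved so far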
        scheduler.step(history["val_loss"][-1])
        if [par["lr"] for par in opt.param_groups][0] < current_lr:
            model.load_state_dict(torch.load(name)['model_state_dict'])

        print(f'epoch {epoch}, updated: {updated}, patience: {epoch-min_idx}/{patience}, lr: {current_lr}')
        print(f'train loss {history["loss"][-1]}, acc {history["acc"][-1]}')
        print(f'val loss {history["val_loss"][-1]}, acc {history["val_acc"][-1]}')

        if epoch < min_epochs:
            continue

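        # stop once the validation loss has not improved for `patience` epochs and the
        # learning rate cannot be halved further, or after 3*patience stale epochs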
        if ((patience <= epoch - min_idx) and (current_lr/2 < lr_end)) or (3*patience <= epoch - min_idx):
            model.load_state_dict(torch.load(name)['model_state_dict'])
            break
    return history


def test(model, ds_test, device=device):
    """
    Test the trained model.

    Args:
        model (torch.nn.Module): The trained model to evaluate.
        ds_test (ArrayDataset): Dataset created with ArrayDataset for testing.
        device (torch.device): The device (`cpu` or `cuda`) holding the dataset and model during processing.

    Returns:
        dict: The test performance in terms of accuracy, together with the labels and predictions
            for all the events in the dataset.
    """
    record = {"acc": [], "label": [], "pred": []}
    with torch.no_grad():
        dl_test = DataLoader(ds_test, collate_fn=lambda x: x[0], num_workers=0)
        acces = []
        truth = []
        preds = []
        for x, y in dl_test:
            logits = model(x.to(device))
            labels = y.to(device)
            acc = ((logits > 0) == labels).float().mean()
            acces.append(float(acc.detach().cpu().numpy()))
            truth.extend(labels.cpu().numpy())
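            # apply a sigmoid so the stored predictions are probabilities in [0, 1]
            # rather than raw logits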
            preds.extend(torch.sigmoid(logits).detach().cpu().numpy())
    record["acc"] = np.mean(acces)
    record["label"] = np.concatenate(truth)
    record["pred"] = np.concatenate(preds)
    return record


if __name__ == "__main__":
    # training process
    train_udst, info = load_events(
        filenames=file_path.glob('preprocessed*.parquet'),
        balanced=True
    )
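    # fraction of events that pass the filter among all uDST events; this is the
    # retention rate entering the speedup loss (see get_loss)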
    retention_rate = info["total_loaded_pass"] / info["total_udst"]
    max_events = info["total_loaded_pass"] + info["total_loaded_fail"]
    train_val_ratio = 0.8

    # train_val_ratio and batch_size should be adjusted for small datasets
    # to make sure that no batch contains only pass or only fail events
    ds_train = ArrayDataset(train_udst[:int(train_val_ratio*max_events)], batch_size=128, shuffle=True)
    ds_val = ArrayDataset(train_udst[int(train_val_ratio*max_events):], batch_size=1024, shuffle=True)

    model_name = "GATGAP.pth"
    model = GATGAPModel(**MODEL_CONFIG).to(device)
    history = fit(model, model_name, ds_train, ds_val, retention_rate=retention_rate)

    # test process
    test_udst, info = load_events(
        filenames=file_path.glob('preprocessed*.parquet'),
        balanced=False
    )
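    # for testing, the full unbalanced sample is loaded and shuffling is disabled,
    # so the predictions stay aligned with the input events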
    ds_test = ArrayDataset(test_udst, batch_size=1024, shuffle=False)
    record = test(model, ds_test)