from pathlib import Path

import awkward as ak
import numpy as np
import torch
from torch.utils.data import DataLoader

from smartBKG.models.gatgap import GATGAPModel
from smartBKG.utils.dataset import ArrayDataset
from smartBKG import MODEL_CONFIG
from smartBKG.utils.metrics import speedup
device = torch.device("cpu")
def load_events(filenames, balanced=True, max_events=None):
    """
    Load events from Parquet files.

    Args:
        filenames (list): List of file paths.
        balanced (bool): Whether to balance the number of pass and fail events.
            Should be set to `True` for the training set and `False` for the test set.
        max_events (int): Maximum number of events to load.

    Returns:
        tuple: Tuple containing the loaded arrays and meta information.
    """
    arrays = []
    info = {
        "total_udst": 0,
        "total_udst_fail": 0,
        "total_loaded_pass": 0,
        "total_loaded_fail": 0
    }
    for filename in filenames:
        ar = ak.from_parquet(filename)
        n_udst_pass = ak.sum(ar.label)
        n_udst_fail = ak.sum(~ar.label)
        n_total = n_udst_pass + n_udst_fail
        info["total_udst"] += n_total
        info["total_udst_fail"] += n_udst_fail
        if balanced:
            # Keep only as many fail events as there are pass events
            ar = ak.concatenate([
                ar[ar.label],
                ar[~ar.label][:n_udst_pass]
            ])
        arrays.append(ar)
        info["total_loaded_pass"] += n_udst_pass
        info["total_loaded_fail"] += len(ar) - n_udst_pass
        if max_events is not None and len(ar) >= max_events:
            break
    return ak.with_field(ak.partitioned(arrays), True, "is_udst"), info
def get_loss(pred, label, retention_rate):
    """
    Compute the loss with the retention rate applied. Can be extended to other losses.

    Args:
        pred (torch.Tensor): Predicted logits from the model.
        label (torch.Tensor): True labels (ground truth).
        retention_rate (float): The rate at which events are retained by the filter.

    Returns:
        torch.Tensor: The speedup loss achieved by the filtering method.
    """
    return speedup(
        pred,
        label,
        retention_rate=retention_rate
    )
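
# Illustrative call (a sketch; the retention rate would normally be derived
# from the meta information returned by `load_events`, as in __main__ below):
#
#     loss = get_loss(logits, labels, retention_rate=0.05)
#     loss.backward()
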
def fit(model, name, ds_train, ds_val, retention_rate, device=device,
        min_epochs=1, patience=12, lr_start=1e-3, lr_end=1e-4, epochs=1000):
    """
    Train the model with a dynamic learning rate.

    Args:
        model (torch.nn.Module): The model used for the training.
        name (str): File name to save the model.
        ds_train (ArrayDataset): Dataset created with ArrayDataset for training.
        ds_val (ArrayDataset): Dataset created with ArrayDataset for validation.
        retention_rate (float): The rate at which events are retained by the filter.
        device (torch.device): The device (`cpu` or `cuda`) on which dataset and
            model are kept during processing.
        min_epochs (int): The minimal number of epochs for the training.
        patience (int): The maximal number of consecutive epochs allowed without an update.
        lr_start (float): The learning rate at the beginning of the training.
        lr_end (float): The minimal learning rate for the training. The dynamic
            learning rate strategy is hard-coded with the ReduceLROnPlateau scheduler
            and a factor of 1/2. This can be changed manually according to the result
            of hyper-parameter fine-tuning.
        epochs (int): The maximal number of epochs for the training.

    Returns:
        dict: The histories of loss and accuracy for training and validation.
    """
    history = {"loss": [], "val_loss": [], "acc": [], "val_acc": []}
    opt = torch.optim.Adam(model.parameters(), lr=lr_start)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=1/2, patience=patience, min_lr=lr_end)
    nan_count = 0
    for epoch in range(epochs):
        # Learning rate currently set by the scheduler
        current_lr = [par["lr"] for par in opt.param_groups][0]

        # ArrayDataset already yields whole batches, so the DataLoader only unwraps them
        dl_train = DataLoader(ds_train, collate_fn=lambda x: x[0], num_workers=0)
        train_losses, train_acces = [], []
        model.train()
        for x, y in dl_train:
            logits = model(x.to(device))
            labels = y.to(device)
            loss = get_loss(logits, labels, retention_rate)
            acc = ((logits > 0) == labels).float().mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            train_losses.append(float(loss.detach().cpu().numpy()))
            train_acces.append(float(acc.detach().cpu().numpy()))
        model.eval()
        val_losses, val_acces = [], []
        with torch.no_grad():
            dl_val = DataLoader(ds_val, collate_fn=lambda x: x[0], num_workers=0)
            for x, y in dl_val:
                logits = model(x.to(device))
                labels = y.to(device)
                loss = get_loss(logits, labels, retention_rate)
                acc = ((logits > 0) == labels).float().mean()
                val_losses.append(float(loss.detach().cpu().numpy()))
                val_acces.append(float(acc.detach().cpu().numpy()))
        history["val_loss"].append(np.mean(val_losses))
        history["loss"].append(np.mean(train_losses))
        history["val_acc"].append(np.mean(val_acces))
        history["acc"].append(np.mean(train_acces))
        min_idx = np.argmin(history["val_loss"])
        if torch.isnan(loss.detach()):
            # Recover from a NaN loss by restarting from the last saved checkpoint
            model.load_state_dict(torch.load(name)['model_state_dict'])
            nan_count += 1
            if nan_count == int(patience/2):
                break
            scheduler.step(history["val_loss"][-1])
            print(f'Nan detected, retry {nan_count}/{int(patience/2)}, lr {current_lr}')
            continue
        if min_idx == (len(history["val_loss"]) - 1):
            # Validation loss improved, so save a checkpoint
            updated = True
            torch.save(
                {'model_state_dict': model.state_dict()},
                name
            )
        else:
            updated = False

        scheduler.step(history["val_loss"][-1])
        if [par["lr"] for par in opt.param_groups][0] < current_lr:
            # The learning rate was reduced, so restart from the best checkpoint
            model.load_state_dict(torch.load(name)['model_state_dict'])
        print(f'epoch {epoch}, updated: {updated}, patience: {epoch-min_idx}/{patience}, lr: {current_lr}')
        print(f'train loss {history["loss"][-1]}, acc {history["acc"][-1]}')
        print(f'val loss {history["val_loss"][-1]}, acc {history["val_acc"][-1]}')

        if epoch < min_epochs:
            continue

        # Stop when no improvement has been seen for `patience` epochs and the
        # learning rate cannot be reduced further, or after 3 * patience epochs
        if ((patience <= epoch - min_idx) and (current_lr/2 < lr_end)) or (3*patience <= epoch - min_idx):
            model.load_state_dict(torch.load(name)['model_state_dict'])
            break

    return history
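
# Usage sketch for `fit` (illustrative; mirrors the __main__ block below,
# the retention rate value is an assumption):
#
#     history = fit(model, "GATGAP.pth", ds_train, ds_val, retention_rate=0.05)
#     best_epoch = int(np.argmin(history["val_loss"]))
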
def test(model, ds_test, device=device):
    """
    Test the trained model.

    Args:
        model (torch.nn.Module): The trained model.
        ds_test (ArrayDataset): Dataset created with ArrayDataset for testing.
        device (torch.device): The device (`cpu` or `cuda`) on which dataset and
            model are kept during processing.

    Returns:
        dict: The test performance in terms of accuracy, together with the labels
            and predictions for all events in the dataset.
    """
    record = {"acc": [], "label": [], "pred": []}
    acces, truth, preds = [], [], []
    model.eval()
    with torch.no_grad():
        dl_test = DataLoader(ds_test, collate_fn=lambda x: x[0], num_workers=0)
        for x, y in dl_test:
            logits = model(x.to(device))
            labels = y.to(device)
            acc = ((logits > 0) == labels).float().mean()
            acces.append(float(acc.detach().cpu().numpy()))
            truth.extend(labels.cpu().numpy())
            preds.extend(torch.sigmoid(logits).detach().cpu().numpy())
    record["acc"] = np.mean(acces)
    record["label"] = np.concatenate(truth)
    record["pred"] = np.concatenate(preds)
    return record
if __name__ == "__main__":
    # Placeholder path: adjust to the directory containing the preprocessed Parquet files
    file_path = Path(".")

    train_udst, info = load_events(
        filenames=file_path.glob('preprocessed*.parquet'),
        balanced=True
    )

    # Fraction of all produced events that survive the filter
    retention_rate = info["total_loaded_pass"] / info["total_udst"]
    max_events = info["total_loaded_pass"] + info["total_loaded_fail"]
    train_val_ratio = 0.8

    ds_train = ArrayDataset(train_udst[:int(train_val_ratio*max_events)], batch_size=128, shuffle=True)
    ds_val = ArrayDataset(train_udst[int(train_val_ratio*max_events):], batch_size=1024, shuffle=True)
    model_name = "GATGAP.pth"
    model = GATGAPModel(**MODEL_CONFIG).to(device)
    history = fit(model, model_name, ds_train, ds_val, retention_rate=retention_rate)
    test_udst, info = load_events(
        filenames=file_path.glob('preprocessed*.parquet'),
        balanced=False
    )
    ds_test = ArrayDataset(test_udst, batch_size=1024, shuffle=False)
    record = test(model, ds_test)
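
    # Optional summary of the run (illustrative): best validation loss seen
    # during training and the accuracy on the unbalanced test set.
    print(f"best val loss: {min(history['val_loss'])}")
    print(f"test accuracy: {record['acc']}")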