Belle II Software light-2406-ragdoll
train_model.py
#!/usr/bin/env python


import logging
import sys
from pathlib import Path
import click
import torch
import ignite as ig
import numpy as np
import random
from grafei.model.config import load_config
from grafei.model.dataset_split import create_dataloader_mode_tags
from grafei.model.geometric_network import GraFEIModel
from grafei.model.multiTrain import MultiTrainLoss
from grafei.model.create_trainer import GraFEIIgniteTrainer


@click.command()
@click.option(
    "-c",
    "--config",
    "cfg_path",
    required=True,
    type=click.Path(exists=True),
    help="path to config file",
)
@click.option(
    "-n",
    "--n_samples",
    "n_samples",
    required=False,
    type=int,
    help="Number of samples to train on",
)
@click.option("--quiet", "log_level", flag_value=logging.WARNING, default=True)
@click.option("-v", "--verbose", "log_level", flag_value=logging.INFO)
@click.option("-vv", "--very-verbose", "log_level", flag_value=logging.DEBUG)
def main(
    cfg_path: Path,
    n_samples: int,
    log_level: int,
):
52 """"""
    logging.basicConfig(
        stream=sys.stdout,
        level=log_level,
        datefmt="%Y-%m-%d %H:%M",
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # First figure out which device all this is running on
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {str(device).upper()} device\n")

    # Load configs
    configs, tags = load_config(
        Path(cfg_path).resolve(),
        samples=n_samples,
    )

    # Random seed
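    # (fixing it seeds Python, NumPy and PyTorch on both CPU and CUDA for reproducibility)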
    if configs["train"]["seed"]:
        seed = configs["train"]["seed"]
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Load datasets
    mode_tags = create_dataloader_mode_tags(configs, tags)

    # Find out if we are reconstructing B or Ups
    B_reco = mode_tags["Training"][1].B_reco

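    # The LCAS depth is 5 for B and 6 for Upsilon(4S) reconstruction (see printout below),
    # so the edge classifier needs 6 or 7 classes respectively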
    configs["model"].update({"edge_classes": 6 if B_reco else 7, "B_reco": B_reco})

    # Now build the model
    # Extract the number of features, assuming last dim is features
    n_infeatures = mode_tags["Training"][1][0].x.shape[-1]
    e_infeatures = mode_tags["Training"][1][0].edge_attr.shape[-1]
    g_infeatures = mode_tags["Training"][1][0].u.shape[-1]

    base_model = GraFEIModel(
        nfeat_in_dim=n_infeatures,
        efeat_in_dim=e_infeatures,
        gfeat_in_dim=g_infeatures,
        **configs["model"],
    )

    print(f"Model: {base_model}\n")
    print(
        f"Number of model parameters: {sum(p.numel() for p in base_model.parameters() if p.requires_grad)}\n"
    )
    print(f"Using LCAS format, max depth of {5 if B_reco else 6} corresponding to {'B' if B_reco else 'Upsilon(4S)'}\n")

    # Compile the model (requires PyTorch >= 2.0.0)
    if int(torch.__version__.split(".")[0]) >= 2 and configs["train"]["compile_model"]:
        print("Compiling the model!")
        model = torch.compile(base_model)
    else:
        model = base_model

    # Send the model to specific device
    model.to(device)

    # Set the loss
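    # Combined cross-entropy over the edge (LCAS) and node (mass hypothesis) predictions:
    # alpha_mass weights the mass term and target labels of -1 are ignored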
    loss_fn = MultiTrainLoss(
        ignore_index=-1,
        reduction=configs["model"]["loss_reduction"],
        alpha_mass=configs["model"]["alpha_mass"],
    )

    # Set the optimiser
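    # (Adam with the configured learning rate; weight decay and AMSGrad off, eps loosened to 1e-3)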
    optimizer = torch.optim.Adam(
        model.parameters(),
        configs["train"]["learning_rate"],
        weight_decay=0,
        amsgrad=False,
        eps=0.001,
    )

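    # Optionally anneal the learning rate with a cosine schedule (T_max=10 scheduler steps)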
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=10,
    ) if configs["train"]["lr_scheduler"] else None

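    # Bundle model, optimiser, loss and scheduler into the ignite-based trainer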
    grafei_ignite_trainer = GraFEIIgniteTrainer(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
        configs=configs,
        tags=list(mode_tags.keys()),
        scheduler=scheduler,
        ignore_index=-1,
    )

    # Set up the actual checkpoints and save the configs if requested
    grafei_ignite_trainer.setup_handlers(
        cfg_filename=Path(cfg_path).name,
    )

    # Add callback to run evaluation after each epoch
    grafei_ignite_trainer.trainer.add_event_handler(
        ig.engine.Events.EPOCH_COMPLETED,
        grafei_ignite_trainer.log_results,
        mode_tags,
    )

    # Run the training (mode_tags["Training"][2] is the training dataloader)
    train_steps = configs["train"]["steps"] if "steps" in configs["train"] else None
    grafei_ignite_trainer.trainer.run(
        mode_tags["Training"][2],
        max_epochs=configs["train"]["epochs"],
        epoch_length=train_steps,
    )


if __name__ == "__main__":
    main()
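
# Example invocation (the config path and sample count are illustrative):
#   python3 train_model.py -c config.yaml -n 100000 --verbose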