#!/usr/bin/env python
# Belle II Software — light-2403-persian
# train_model.py
12 import logging
13 import sys
14 from pathlib import Path
15 import click
16 import torch
17 import ignite as ig
18 import numpy as np
19 import random
20 from grafei.model.config import load_config
21 from grafei.model.dataset_split import create_dataloader_mode_tags
22 from grafei.model.geometric_network import GraFEIModel
23 from grafei.model.multiTrain import MultiTrainLoss
24 from grafei.model.create_trainer import GraFEIIgniteTrainer
25 
26 
@click.command()
@click.option(
    "-c",
    "--config",
    "cfg_path",
    required=True,
    type=click.Path(exists=True),
    help="path to config file",
)
@click.option(
    "-n",
    "--n_samples",
    "n_samples",
    required=False,
    type=int,
    help="Number of samples to train on",
)
@click.option("--quiet", "log_level", flag_value=logging.WARNING, default=True)
@click.option("-v", "--verbose", "log_level", flag_value=logging.INFO)
@click.option("-vv", "--very-verbose", "log_level", flag_value=logging.DEBUG)
def main(
    cfg_path: Path,
    n_samples: int,
    log_level: int,
):
    """Train the graFEI model described by the given YAML config file.

    Loads the config and datasets, builds a :class:`GraFEIModel`, optionally
    compiles it (PyTorch >= 2), and runs the ignite training loop with an
    evaluation pass after every epoch.

    Parameters (via click options):
        cfg_path: Path to the config file (must exist).
        n_samples: Optional cap on the number of training samples.
        log_level: Logging verbosity selected by --quiet/-v/-vv.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=log_level,
        datefmt="%Y-%m-%d %H:%M",
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # First figure out which device all this is running on
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {str(device).upper()} device\n")

    # Load configs
    configs, tags = load_config(
        Path(cfg_path).resolve(),
        samples=n_samples,
    )

    # Random seed for python, numpy and torch (CPU and all GPUs).
    # NOTE(review): a seed of 0 is falsy and is therefore NOT applied here —
    # confirm whether 0 is meant as "no seeding" or is a legitimate seed value.
    if configs["train"]["seed"]:
        seed = configs["train"]["seed"]
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Load datasets
    mode_tags = create_dataloader_mode_tags(configs, tags)

    # Find out if we are reconstructing B or Ups
    B_reco = mode_tags["Training"][1].B_reco

    # B reconstruction uses 6 LCAS edge classes, Upsilon(4S) uses 7.
    configs["model"].update({"edge_classes": 6 if B_reco else 7, "B_reco": B_reco})

    # Now build the model
    # Extract the number of features, assuming last dim is features
    n_infeatures = mode_tags["Training"][1][0].x.shape[-1]
    e_infeatures = mode_tags["Training"][1][0].edge_attr.shape[-1]
    g_infeatures = mode_tags["Training"][1][0].u.shape[-1]

    base_model = GraFEIModel(
        nfeat_in_dim=n_infeatures,
        efeat_in_dim=e_infeatures,
        gfeat_in_dim=g_infeatures,
        **configs["model"],
    )

    print(f"Model: {base_model}\n")
    print(
        f"Number of model parameters: {sum(p.numel() for p in base_model.parameters() if p.requires_grad)}\n"
    )
    print(f"Using LCAS format, max depth of {5 if B_reco else 6} corresponding to {'B' if B_reco else 'Upsilon(4S)'}\n")

    # Compile the model (requires PyTorch >= 2.0.0).
    # Parse the major version robustly: float(torch.__version__[0]) reads a
    # single character and would break for any two-digit major release.
    torch_major = int(torch.__version__.split(".")[0])
    if torch_major >= 2 and configs["train"]["compile_model"]:
        print("Compiling the model!")
        model = torch.compile(base_model)
    else:
        model = base_model

    # Send the model to specific device
    model.to(device)

    # Set the loss
    loss_fn = MultiTrainLoss(
        ignore_index=-1,
        reduction=configs["model"]["loss_reduction"],
        alpha_mass=configs["model"]["alpha_mass"],
    )

    # Set the optimiser
    optimizer = torch.optim.Adam(
        model.parameters(),
        configs["train"]["learning_rate"],
        weight_decay=0,
        amsgrad=False,
        eps=0.001,
    )

    # Optional cosine-annealing LR schedule over a fixed 10-step period.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=10,
    ) if configs["train"]["lr_scheduler"] else None

    grafei_ignite_trainer = GraFEIIgniteTrainer(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
        configs=configs,
        tags=list(mode_tags.keys()),
        scheduler=scheduler,
        ignore_index=-1,
    )

    # Set up the actual checkpoints and save the configs if requested
    grafei_ignite_trainer.setup_handlers(
        cfg_filename=Path(cfg_path).name,
    )

    # Add callback to run evaluation after each epoch
    grafei_ignite_trainer.trainer.add_event_handler(
        ig.engine.Events.EPOCH_COMPLETED,
        grafei_ignite_trainer.log_results,
        mode_tags,
    )

    # Actually run the training, mode_tags calls the train_loader.
    # "steps" is optional; absent means a full pass over the loader per epoch.
    train_steps = configs["train"].get("steps")
    grafei_ignite_trainer.trainer.run(
        mode_tags["Training"][2],
        max_epochs=configs["train"]["epochs"],
        epoch_length=train_steps,
    )
154 
155 
# Script entry point: invoke the click command when run directly.
if __name__ == "__main__":
    main()
158 
# NOTE(review): the following lines are cross-reference residue from a
# documentation index (they refer to C++ sources, not this script); kept as
# comments so the file stays valid Python.
# Definition: main.py:1
# int main(int argc, char **argv)
# Run all tests.
# Definition: test_main.cc:91