from pathlib import Path

# Project-local (grafei) imports; stdlib/third-party imports (click, logging,
# torch, ignite) live elsewhere in this file, outside this span.
from grafei.model.config import load_config
from grafei.model.create_trainer import GraFEIIgniteTrainer
from grafei.model.dataset_split import create_dataloader_mode_tags
from grafei.model.geometric_network import GraFEIModel
from grafei.model.multiTrain import MultiTrainLoss
33 type=click.Path(exists=
True),
34 help=
"path to config file",
42 help=
"Number of samples to train on",
44@click.option("--quiet", "log_level", flag_value=logging.WARNING, default=True)
45@click.option("-v", "--verbose", "log_level", flag_value=logging.INFO)
46@click.option("-vv", "--very-verbose", "log_level", flag_value=logging.DEBUG)
56 datefmt=
"%Y-%m-%d %H:%M",
57 format=
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
# Select the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {str(device).upper()} device\n")
65 configs, tags = load_config(
66 Path(cfg_path).resolve(),
# NOTE(review): a truthy seed in the config overrides the default set elsewhere;
# a seed of 0 is treated the same as "unset" here — confirm that is intended.
if configs["train"]["seed"]:
    seed = configs["train"]["seed"]
# Seed the CPU RNG and every CUDA device's RNG for reproducibility.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Build the datasets/dataloaders keyed by mode tag (the "Training" entry is read below).
mode_tags = create_dataloader_mode_tags(configs, tags)
# Whether the samples are reconstructed at the B level rather than Upsilon(4S).
B_reco = mode_tags["Training"][1].B_reco
# Number of LCAS edge classes depends on the reconstruction target.
configs["model"].update({"edge_classes": 6 if B_reco else 7, "B_reco": B_reco})
# Infer the input feature dimensions from the first training example
# (hoist the repeated mode_tags["Training"][1][0] lookup into one variable).
first_example = mode_tags["Training"][1][0]
n_infeatures = first_example.x.shape[-1]          # node feature width
e_infeatures = first_example.edge_attr.shape[-1]  # edge feature width
g_infeatures = first_example.u.shape[-1]          # global feature width
92 base_model = GraFEIModel(
93 nfeat_in_dim=n_infeatures,
94 efeat_in_dim=e_infeatures,
95 gfeat_in_dim=g_infeatures,
print(f"Model: {base_model}\n")
101 f
"Number of model parameters: {sum(p.numel() for p in base_model.parameters() if p.requires_grad)}\n"
# LCAS max depth mirrors the edge-class choice above: 5 for B, 6 for Upsilon(4S).
print(f"Using LCAS format, max depth of {5 if B_reco else 6} corresponding to {'B' if B_reco else 'Upsilon(4S)'}\n")
# torch.compile requires PyTorch >= 2. Parse the full major version instead of
# the original float(torch.__version__[0]), which read only the first character
# and would misparse any future two-digit major release.
if int(torch.__version__.split(".")[0]) >= 2 and configs["train"]["compile_model"]:
    print("Compiling the model!")
    model = torch.compile(base_model)
116 loss_fn = MultiTrainLoss(
118 reduction=configs[
"model"][
"loss_reduction"],
119 alpha_mass=configs[
"model"][
"alpha_mass"],
123 optimizer = torch.optim.Adam(
125 configs[
"train"][
"learning_rate"],
131 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
134 )
if configs[
"train"][
"lr_scheduler"]
else None
136 grafei_ignite_trainer = GraFEIIgniteTrainer(
142 tags=list(mode_tags.keys()),
148 grafei_ignite_trainer.setup_handlers(
149 cfg_filename=Path(cfg_path).name,
153 grafei_ignite_trainer.trainer.add_event_handler(
154 ig.engine.Events.EPOCH_COMPLETED,
155 grafei_ignite_trainer.log_results,
# Optional cap on batches per epoch; dict.get returns None when "steps" is
# absent, exactly like the original `x if "steps" in d else None` form.
train_steps = configs["train"].get("steps")
161 grafei_ignite_trainer.trainer.run(
162 mode_tags[
"Training"][2],
163 max_epochs=configs[
"train"][
"epochs"],
164 epoch_length=train_steps,
168if __name__ ==
"__main__":