14 from pathlib
import Path
20 from grafei.model.config
import load_config
21 from grafei.model.dataset_split
import create_dataloader_mode_tags
22 from grafei.model.geometric_network
import GraFEIModel
23 from grafei.model.multiTrain
import MultiTrainLoss
24 from grafei.model.create_trainer
import GraFEIIgniteTrainer
# NOTE(review): This chunk is a garbled extract of a click-based GraFEI
# training script. Statements are wrapped mid-line, many original source
# lines are missing (including `@click.command()` and the decorated
# `def main(cfg_path, n_samples, log_level)` signature), and the stray
# leading integers (28, 29, 31, ...) are remnants of the original file's
# line numbers fused into the text. Code is kept byte-identical below;
# comments describe what each fragment appears to do. Recover the intact
# file before attempting any behavioral edit.

# CLI option: -c/--config — required path to the training config file.
28 @click.option(
"-c",
"--config",
"cfg_path",
required=True,
type=click.Path(exists=True),
29 help=
"path to config file",
# CLI option: -n/--n_samples — optional cap on the number of training samples.
31 @click.option(
"-n",
"--n_samples",
"n_samples",
required=False,
type=int,
help="Number of samples to train on",
)
# Verbosity flags sharing one `log_level` destination; default is WARNING.
32 @click.option("--quiet", "log_level", flag_value=logging.WARNING, default=True)
33 @click.option("-v", "--verbose", "log_level", flag_value=logging.INFO)
34 @click.option("-vv", "--very-verbose", "log_level", flag_value=logging.DEBUG)
# Fragment of a logging.basicConfig(...) call — only the datefmt/format
# keyword arguments survived the extraction.
44 datefmt=
"%Y-%m-%d %H:%M",
45 format=
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
# Select CUDA when available, otherwise fall back to CPU, and report it.
49 device = torch.device(
"cuda" if torch.cuda.is_available()
else "cpu")
50 print(f
"Using {str(device).upper()} device\n")
# Load the configuration dict and dataset tags from the resolved config path
# (further load_config arguments from the original are missing here).
53 configs, tags = load_config(
54 Path(cfg_path).resolve(),
# Optional deterministic seeding: if configs["train"]["seed"] is truthy,
# seed torch on CPU and on all CUDA devices.
59 if configs[
"train"][
"seed"]:
60 seed = configs[
"train"][
"seed"]
63 torch.manual_seed(seed)
64 torch.cuda.manual_seed_all(seed)
# Build the per-mode dataloaders/datasets keyed by tag (e.g. "Training").
67 mode_tags = create_dataloader_mode_tags(configs, tags)
# Read the B_reco flag from the Training dataset (mode_tags["Training"][1]).
70 B_reco = mode_tags[
"Training"][1].B_reco
# Inject derived model settings: 6 edge classes for B reconstruction,
# 7 otherwise, plus the B_reco flag itself.
72 configs[
"model"].update({
"edge_classes": 6
if B_reco
else 7,
"B_reco": B_reco})
# Infer input feature dimensions from the first Training sample.
# NOTE(review): .x / .edge_attr / .u look like torch_geometric-style
# node/edge/global attributes — confirm against the dataset class.
76 n_infeatures = mode_tags[
"Training"][1][0].x.shape[-1]
77 e_infeatures = mode_tags[
"Training"][1][0].edge_attr.shape[-1]
78 g_infeatures = mode_tags[
"Training"][1][0].u.shape[-1]
# Instantiate the model with the inferred input dims (the remaining
# constructor keywords from the original are not visible here).
80 base_model = GraFEIModel(
81 nfeat_in_dim=n_infeatures,
82 efeat_in_dim=e_infeatures,
83 gfeat_in_dim=g_infeatures,
# Print the model, its trainable-parameter count, and the LCAS depth note.
87 print(f
"Model: {base_model}\n")
89 f
"Number of model parameters: {sum(p.numel() for p in base_model.parameters() if p.requires_grad)}\n"
91 print(f
"Using LCAS format, max depth of {5 if B_reco else 6} corresponding to {'B' if B_reco else 'Upsilon(4S)'}\n")
# Optionally torch.compile the model on PyTorch >= 2.
# NOTE(review): float(torch.__version__[0]) parses only the FIRST character
# of the version string — it would misread a future "10.x" as 1. Consider
# parsing up to the first dot instead; left unchanged in this garbled copy.
94 if float(torch.__version__[0]) >= 2
and configs[
"train"][
"compile_model"]:
95 print(
"Compiling the model!")
96 model = torch.compile(base_model)
# Multi-task loss, configured from the "model" section of the config
# (other constructor arguments are missing from this extract).
104 loss_fn = MultiTrainLoss(
106 reduction=configs[
"model"][
"loss_reduction"],
107 alpha_mass=configs[
"model"][
"alpha_mass"],
# Adam optimizer with the configured learning rate (the parameters
# argument on the original line 112 is not visible here).
111 optimizer = torch.optim.Adam(
113 configs[
"train"][
"learning_rate"],
# Optional cosine-annealing LR scheduler: built only when
# configs["train"]["lr_scheduler"] is truthy, else None
# (fragment of a conditional expression; its arguments are missing).
119 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
122 )
if configs[
"train"][
"lr_scheduler"]
else None
# Wrap everything in the project's ignite trainer helper.
124 grafei_ignite_trainer = GraFEIIgniteTrainer(
130 tags=list(mode_tags.keys()),
# Register checkpointing/handlers, tagged with the config file's name.
136 grafei_ignite_trainer.setup_handlers(
137 cfg_filename=Path(cfg_path).name,
# Log metrics at the end of every epoch via pytorch-ignite events.
141 grafei_ignite_trainer.trainer.add_event_handler(
142 ig.engine.Events.EPOCH_COMPLETED,
143 grafei_ignite_trainer.log_results,
# Optional fixed steps-per-epoch; None lets ignite use the loader length.
148 train_steps = configs[
"train"][
"steps"]
if "steps" in configs[
"train"]
else None
# Launch training on the Training dataloader (mode_tags["Training"][2])
# for the configured number of epochs.
149 grafei_ignite_trainer.trainer.run(
150 mode_tags[
"Training"][2],
151 max_epochs=configs[
"train"][
"epochs"],
152 epoch_length=train_steps,
# Script entry guard — the `main()` call on the following original line
# (157) is not visible in this extract.
156 if __name__ ==
"__main__":
158
int main(int argc, char **argv)
Run all tests.