Skip to content

Commit 9bfec2b

Browse files
authored
Merge pull request #314 from geometric-intelligence/guille/clean_run.py
Clean run.py
2 parents f905814 + c75fa2e commit 9bfec2b

2 files changed

Lines changed: 58 additions & 60 deletions

File tree

topobench/run.py

Lines changed: 4 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from lightning.pytorch.callbacks import ModelCheckpoint
1414
from lightning.pytorch.loggers import Logger
1515
from lightning.pytorch.loggers.wandb import WandbLogger
16-
from omegaconf import DictConfig, OmegaConf
16+
from omegaconf import DictConfig
1717

1818
from topobench.data.preprocessor import PreProcessor
1919
from topobench.dataloader import TBDataloader
@@ -26,20 +26,7 @@
2626
log_hyperparameters,
2727
task_wrapper,
2828
)
29-
from topobench.utils.config_resolvers import (
30-
define_task_level,
31-
get_default_metrics,
32-
get_default_trainer,
33-
get_default_transform,
34-
get_flattened_channels,
35-
get_monitor_metric,
36-
get_monitor_mode,
37-
get_non_relational_out_channels,
38-
get_required_lifting,
39-
infer_in_channels,
40-
infer_num_cell_dimensions,
41-
infer_topotune_num_cell_dimensions,
42-
)
29+
from topobench.utils.config_resolvers import register_all_resolvers
4330

4431
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
4532
# ------------------------------------------------------------------------------------ #
@@ -60,51 +47,8 @@
6047
# ------------------------------------------------------------------------------------ #
6148

6249

63-
OmegaConf.register_new_resolver(
64-
"define_task_level", define_task_level, replace=True
65-
)
66-
OmegaConf.register_new_resolver(
67-
"get_default_metrics", get_default_metrics, replace=True
68-
)
69-
OmegaConf.register_new_resolver(
70-
"get_default_trainer", get_default_trainer, replace=True
71-
)
72-
OmegaConf.register_new_resolver(
73-
"get_default_transform", get_default_transform, replace=True
74-
)
75-
OmegaConf.register_new_resolver(
76-
"get_flattened_channels",
77-
get_flattened_channels,
78-
replace=True,
79-
)
80-
OmegaConf.register_new_resolver(
81-
"get_required_lifting", get_required_lifting, replace=True
82-
)
83-
OmegaConf.register_new_resolver(
84-
"get_monitor_metric", get_monitor_metric, replace=True
85-
)
86-
OmegaConf.register_new_resolver(
87-
"get_monitor_mode", get_monitor_mode, replace=True
88-
)
89-
OmegaConf.register_new_resolver(
90-
"get_non_relational_out_channels",
91-
get_non_relational_out_channels,
92-
replace=True,
93-
)
94-
OmegaConf.register_new_resolver(
95-
"infer_in_channels", infer_in_channels, replace=True
96-
)
97-
OmegaConf.register_new_resolver(
98-
"infer_num_cell_dimensions", infer_num_cell_dimensions, replace=True
99-
)
100-
OmegaConf.register_new_resolver(
101-
"infer_topotune_num_cell_dimensions",
102-
infer_topotune_num_cell_dimensions,
103-
replace=True,
104-
)
105-
OmegaConf.register_new_resolver(
106-
"parameter_multiplication", lambda x, y: int(int(x) * int(y)), replace=True
107-
)
50+
# Register custom resolvers before Hydra initialization
51+
register_all_resolvers()
10852

10953

11054
def initialize_hydra() -> DictConfig:

topobench/utils/config_resolvers.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,60 @@
44

55
import omegaconf
66
import torch
7+
from omegaconf import OmegaConf
8+
9+
10+
def register_all_resolvers():
11+
"""Register all custom OmegaConf resolvers.
12+
13+
This centralizes resolver registration to avoid duplication across modules. Should be called
14+
before Hydra initialization in any script that uses configs.
15+
"""
16+
OmegaConf.register_new_resolver(
17+
"define_task_level", define_task_level, replace=True
18+
)
19+
OmegaConf.register_new_resolver(
20+
"get_default_metrics", get_default_metrics, replace=True
21+
)
22+
OmegaConf.register_new_resolver(
23+
"get_default_trainer", get_default_trainer, replace=True
24+
)
25+
OmegaConf.register_new_resolver(
26+
"get_default_transform", get_default_transform, replace=True
27+
)
28+
OmegaConf.register_new_resolver(
29+
"get_flattened_channels",
30+
get_flattened_channels,
31+
replace=True,
32+
)
33+
OmegaConf.register_new_resolver(
34+
"get_required_lifting", get_required_lifting, replace=True
35+
)
36+
OmegaConf.register_new_resolver(
37+
"get_monitor_metric", get_monitor_metric, replace=True
38+
)
39+
OmegaConf.register_new_resolver(
40+
"get_monitor_mode", get_monitor_mode, replace=True
41+
)
42+
OmegaConf.register_new_resolver(
43+
"get_non_relational_out_channels",
44+
get_non_relational_out_channels,
45+
replace=True,
46+
)
47+
OmegaConf.register_new_resolver(
48+
"infer_in_channels", infer_in_channels, replace=True
49+
)
50+
OmegaConf.register_new_resolver(
51+
"infer_num_cell_dimensions", infer_num_cell_dimensions, replace=True
52+
)
53+
OmegaConf.register_new_resolver(
54+
"infer_topotune_num_cell_dimensions",
55+
infer_topotune_num_cell_dimensions,
56+
replace=True,
57+
)
58+
OmegaConf.register_new_resolver(
59+
"parameter_multiplication", lambda x, y: int(int(x) * int(y)), replace=True
60+
)
761

862

963
def define_task_level(dataset_task_level, learning_setting):

0 commit comments

Comments
 (0)