|
13 | 13 | from lightning.pytorch.callbacks import ModelCheckpoint |
14 | 14 | from lightning.pytorch.loggers import Logger |
15 | 15 | from lightning.pytorch.loggers.wandb import WandbLogger |
16 | | -from omegaconf import DictConfig, OmegaConf |
| 16 | +from omegaconf import DictConfig |
17 | 17 |
|
18 | 18 | from topobench.data.preprocessor import PreProcessor |
19 | 19 | from topobench.dataloader import TBDataloader |
|
26 | 26 | log_hyperparameters, |
27 | 27 | task_wrapper, |
28 | 28 | ) |
29 | | -from topobench.utils.config_resolvers import ( |
30 | | - define_task_level, |
31 | | - get_default_metrics, |
32 | | - get_default_trainer, |
33 | | - get_default_transform, |
34 | | - get_flattened_channels, |
35 | | - get_monitor_metric, |
36 | | - get_monitor_mode, |
37 | | - get_non_relational_out_channels, |
38 | | - get_required_lifting, |
39 | | - infer_in_channels, |
40 | | - infer_num_cell_dimensions, |
41 | | - infer_topotune_num_cell_dimensions, |
42 | | -) |
| 29 | +from topobench.utils.config_resolvers import register_all_resolvers |
43 | 30 |
|
44 | 31 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) |
45 | 32 | # ------------------------------------------------------------------------------------ # |
|
60 | 47 | # ------------------------------------------------------------------------------------ # |
61 | 48 |
|
62 | 49 |
|
63 | | -OmegaConf.register_new_resolver( |
64 | | - "define_task_level", define_task_level, replace=True |
65 | | -) |
66 | | -OmegaConf.register_new_resolver( |
67 | | - "get_default_metrics", get_default_metrics, replace=True |
68 | | -) |
69 | | -OmegaConf.register_new_resolver( |
70 | | - "get_default_trainer", get_default_trainer, replace=True |
71 | | -) |
72 | | -OmegaConf.register_new_resolver( |
73 | | - "get_default_transform", get_default_transform, replace=True |
74 | | -) |
75 | | -OmegaConf.register_new_resolver( |
76 | | - "get_flattened_channels", |
77 | | - get_flattened_channels, |
78 | | - replace=True, |
79 | | -) |
80 | | -OmegaConf.register_new_resolver( |
81 | | - "get_required_lifting", get_required_lifting, replace=True |
82 | | -) |
83 | | -OmegaConf.register_new_resolver( |
84 | | - "get_monitor_metric", get_monitor_metric, replace=True |
85 | | -) |
86 | | -OmegaConf.register_new_resolver( |
87 | | - "get_monitor_mode", get_monitor_mode, replace=True |
88 | | -) |
89 | | -OmegaConf.register_new_resolver( |
90 | | - "get_non_relational_out_channels", |
91 | | - get_non_relational_out_channels, |
92 | | - replace=True, |
93 | | -) |
94 | | -OmegaConf.register_new_resolver( |
95 | | - "infer_in_channels", infer_in_channels, replace=True |
96 | | -) |
97 | | -OmegaConf.register_new_resolver( |
98 | | - "infer_num_cell_dimensions", infer_num_cell_dimensions, replace=True |
99 | | -) |
100 | | -OmegaConf.register_new_resolver( |
101 | | - "infer_topotune_num_cell_dimensions", |
102 | | - infer_topotune_num_cell_dimensions, |
103 | | - replace=True, |
104 | | -) |
105 | | -OmegaConf.register_new_resolver( |
106 | | - "parameter_multiplication", lambda x, y: int(int(x) * int(y)), replace=True |
107 | | -) |
| 50 | +# Register custom resolvers before Hydra initialization |
| 51 | +register_all_resolvers() |
108 | 52 |
|
109 | 53 |
|
110 | 54 | def initialize_hydra() -> DictConfig: |
|
0 commit comments