Skip to content

Commit 65b291d

Browse files
committed
make config_root so it is logcially independent of recipe
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent ae9e245 commit 65b291d

2 files changed

Lines changed: 18 additions & 10 deletions

File tree

modelopt/recipe/_config_loader.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,18 @@
1616
"""Re-export config loading utilities from ``modelopt.torch.opt.config_loader``."""
1717

1818
from modelopt.torch.opt.config_loader import (
19-
BUILTIN_RECIPES_LIB,
19+
BUILTIN_CONFIG_ROOT,
2020
_load_raw_config,
2121
_resolve_imports,
2222
load_config,
2323
)
2424

25-
__all__ = ["BUILTIN_RECIPES_LIB", "_load_raw_config", "_resolve_imports", "load_config"]
25+
BUILTIN_RECIPES_LIB = BUILTIN_CONFIG_ROOT
26+
27+
__all__ = [
28+
"BUILTIN_CONFIG_ROOT",
29+
"BUILTIN_RECIPES_LIB",
30+
"_load_raw_config",
31+
"_resolve_imports",
32+
"load_config",
33+
]

modelopt/torch/opt/config_loader.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333

3434
import yaml
3535

36-
# Root to all built-in recipes. Users can create own recipes.
37-
BUILTIN_RECIPES_LIB = files("modelopt_recipes")
36+
# Root to all built-in configs and recipes.
37+
BUILTIN_CONFIG_ROOT = files("modelopt_recipes")
3838

3939
_EXMY_RE = re.compile(r"^[Ee](\d+)[Mm](\d+)$")
4040
_EXMY_KEYS = frozenset({"num_bits", "scale_bits"})
@@ -73,22 +73,22 @@ def _load_raw_config(config_file: str | Path | Traversable) -> dict[str, Any] |
7373
if not config_file.endswith(".yml") and not config_file.endswith(".yaml"):
7474
paths_to_check.append(Path(f"{config_file}.yml"))
7575
paths_to_check.append(Path(f"{config_file}.yaml"))
76-
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yml"))
77-
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yaml"))
76+
paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(f"{config_file}.yml"))
77+
paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(f"{config_file}.yaml"))
7878
else:
7979
paths_to_check.append(Path(config_file))
80-
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(config_file))
80+
paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(config_file))
8181
elif isinstance(config_file, Path):
8282
if config_file.suffix in (".yml", ".yaml"):
8383
paths_to_check.append(config_file)
8484
if not config_file.is_absolute():
85-
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(str(config_file)))
85+
paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(str(config_file)))
8686
else:
8787
paths_to_check.append(Path(f"{config_file}.yml"))
8888
paths_to_check.append(Path(f"{config_file}.yaml"))
8989
if not config_file.is_absolute():
90-
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yml"))
91-
paths_to_check.append(BUILTIN_RECIPES_LIB.joinpath(f"{config_file}.yaml"))
90+
paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(f"{config_file}.yml"))
91+
paths_to_check.append(BUILTIN_CONFIG_ROOT.joinpath(f"{config_file}.yaml"))
9292
elif isinstance(config_file, Traversable):
9393
paths_to_check.append(config_file)
9494
else:

0 commit comments

Comments
 (0)