Skip to content

Commit 9973d18

Browse files
committed
Move from hydra to lerna
Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com>
1 parent 459ac14 commit 9973d18

5 files changed

Lines changed: 60 additions & 15 deletions

File tree

ccflow/base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,10 @@ def create_config_from_path(
588588
Returns:
589589
The instance of the model registry, with the configs loaded.
590590
"""
591-
import hydra # Heavy import, only import if used.
591+
try:
592+
import lerna as hydra
593+
except ImportError:
594+
import hydra # Heavy import, only import if used.
592595

593596
overrides = overrides or []
594597
path = pathlib.Path(path).absolute() # Hydra requires absolute paths
@@ -675,8 +678,12 @@ def load_config(self, cfg: DictConfig, registry: ModelRegistry, skip_exceptions:
675678
# This also allows for nested attributes on the model itself to
676679
# be constructed, even if they are not themselves of BaseModel type,
677680
# or if they are of a specific subclass of the parent.
678-
from hydra.errors import InstantiationException
679-
from hydra.utils import instantiate
681+
try:
682+
from lerna.errors import InstantiationException
683+
from lerna.utils import instantiate
684+
except ImportError:
685+
from hydra.errors import InstantiationException
686+
from hydra.utils import instantiate
680687

681688
models_to_register = self._make_subregistries(cfg, [registry])
682689
while True:

ccflow/tests/test_base_registry.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
from unittest import TestCase
77

88
import pytest
9-
from hydra.errors import InstantiationException
9+
10+
try:
11+
from lerna.errors import InstantiationException
12+
except ImportError:
13+
from hydra.errors import InstantiationException
1014
from omegaconf import OmegaConf
1115
from omegaconf.errors import InterpolationKeyError
1216
from pydantic import ConfigDict

ccflow/tests/utils/test_hydra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def test_debug(basepath):
338338
assert "hydra/job_logging" in result.group_options
339339
assert len(result.group_options["hydra/job_logging"]) > 1
340340
assert "config_user" in result.group_options
341-
assert result.group_options["config_user"] == ["sample"]
341+
assert "sample" in result.group_options["config_user"]
342342
# Arguable whether these should be here
343343
assert "conf_out_of_order" in result.group_options[""]
344344

@@ -347,7 +347,7 @@ def test_debug(basepath):
347347
assert merged
348348
assert "foo" in merged
349349
assert "config_user" in merged
350-
assert merged["config_user"]["__options__"] == ["sample"]
350+
assert "sample" in merged["config_user"]["__options__"]
351351
assert merged["config_user"]["__parent__"] == "conf" # Maybe this should be a path to a file
352352
assert merged["config_user"]["__selected__"] == "sample"
353353
assert "user_foo" in merged["config_user"]

ccflow/utils/hydra.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import argparse
22
import inspect
33
import os
4+
import sys
45
from dataclasses import dataclass
56
from logging import getLogger
67
from pathlib import Path
78
from pprint import pprint
89
from textwrap import dedent
910
from typing import Any, Callable, Dict, List, Optional
1011

11-
from hydra._internal.defaults_list import DefaultsList
12+
try:
13+
from lerna._internal.defaults_list import DefaultsList
14+
except ImportError:
15+
from hydra._internal.defaults_list import DefaultsList
1216
from omegaconf import DictConfig, ListConfig, OmegaConf
1317

1418
try:
@@ -115,7 +119,10 @@ def _find_group_options(config_loader, path, config_name, overrides, results):
115119
Note that it will pick up config files that are not intended to be used as config group options,
116120
but that exist to provide common config options to other files in the group (i.e. to default)
117121
"""
118-
from hydra.core.object_type import ObjectType
122+
try:
123+
from lerna.core.object_type import ObjectType
124+
except ImportError:
125+
from hydra.core.object_type import ObjectType
119126

120127
groups = config_loader.get_group_options(path, ObjectType.GROUP, config_name, overrides)
121128
options = config_loader.get_group_options(path, ObjectType.CONFIG, config_name, overrides)
@@ -184,9 +191,10 @@ def load_config(
184191
debug: (Experimental) Whether to enable debug mode. This will return more information about the configs on ConfigLoadResult.
185192
"""
186193
# Heavy import, only import if used
187-
import os
188-
189-
from hydra import compose, initialize_config_dir
194+
try:
195+
from lerna import compose, initialize_config_dir
196+
except ImportError:
197+
from hydra import compose, initialize_config_dir
190198

191199
if return_hydra_config and debug:
192200
raise ValueError("Cannot return hydra config and debug=True at the same time. Please set return_hydra_config=False.")
@@ -209,22 +217,47 @@ def load_config(
209217
result = ConfigLoadResult(root_config_dir=root_config_dir, root_config_name=root_config_name, cfg=cfg)
210218
if debug:
211219
import yaml
212-
from hydra.core.global_hydra import GlobalHydra
213-
from hydra.types import RunMode
220+
221+
try:
222+
from lerna.core.global_hydra import GlobalHydra
223+
from lerna.types import RunMode
224+
except ImportError:
225+
from hydra.core.global_hydra import GlobalHydra
226+
from hydra.types import RunMode
214227

215228
# To track the source file for each config value, we need to monkey patch the yaml loader
216229
original_yaml_load = yaml.load
230+
# Lerna may use a Rust YAML parser that bypasses yaml.load entirely.
231+
# Temporarily disable it so our monkey patch can intercept all YAML loading.
232+
_rust_patches = {}
233+
try:
234+
for _mod_name in (
235+
"lerna._internal.core_plugins.file_config_source",
236+
"lerna._internal.core_plugins.importlib_resources_config_source",
237+
):
238+
_mod = sys.modules.get(_mod_name)
239+
if _mod and getattr(_mod, "_RUST_AVAILABLE", False):
240+
_rust_patches[_mod] = True
241+
_mod._RUST_AVAILABLE = False
242+
except Exception:
243+
pass
244+
217245
try:
218246

219247
def yaml_load(*args, **kwargs):
220248
res = original_yaml_load(*args, **kwargs)
221-
return _dict_add_source(res, args[0].name)
249+
# hydra passes file objects (with .name) to yaml.load;
250+
# lerna passes strings. Use "unknown" as fallback source.
251+
source = getattr(args[0], "name", "unknown") if args else "unknown"
252+
return _dict_add_source(res, source)
222253

223254
yaml.load = yaml_load
224255
# We can't load the hydra config after monkey patching yaml loading, so skip that step
225256
result.cfg_sources = compose(config_name=root_config_name, overrides=overrides, return_hydra_config=False)
226257
finally:
227258
yaml.load = original_yaml_load
259+
for _mod, _val in _rust_patches.items():
260+
_mod._RUST_AVAILABLE = _val
228261

229262
config_loader = GlobalHydra.instance().config_loader()
230263
# Load defaults list using the standard hydra function

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ dependencies = [
4141
"cloudpickle",
4242
"dask",
4343
"deprecated",
44-
"hydra-core",
4544
"IPython",
4645
"jinja2",
46+
"lerna",
4747
"narwhals",
4848
"numpy<3",
4949
"orjson",
@@ -87,6 +87,7 @@ develop = [
8787
"cexprtk",
8888
"csp>=0.8.0,<1",
8989
"duckdb",
90+
"hydra-core",
9091
"pandas",
9192
"panel",
9293
"panel_material_ui",

0 commit comments

Comments
 (0)