diff --git a/src/art/__init__.py b/src/art/__init__.py index 8e494e6c4..6d8d62274 100644 --- a/src/art/__init__.py +++ b/src/art/__init__.py @@ -17,9 +17,13 @@ """ import os +from typing import Any, TYPE_CHECKING from dotenv import load_dotenv +if TYPE_CHECKING: + from .local import LocalBackend + load_dotenv() if os.getenv("SUPPRESS_LITELLM_SERIALIZATION_WARNINGS", "1") == "1": @@ -88,6 +92,17 @@ def __init__(self, **kwargs): from .utils import retry from .yield_trajectory import capture_yielded_trajectory, yield_trajectory + +def __getattr__(name: str) -> Any: + if name == "LocalBackend": + # Keep backend-only dependencies optional until the symbol is requested. + from .local import LocalBackend + + globals()[name] = LocalBackend + return LocalBackend + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ "dev", "auto_trajectory", diff --git a/tests/unit/test_package_exports.py b/tests/unit/test_package_exports.py new file mode 100644 index 000000000..e5a3a2aac --- /dev/null +++ b/tests/unit/test_package_exports.py @@ -0,0 +1,7 @@ +import art + + +def test_art_localbackend_top_level_export(): + from art.local import LocalBackend + + assert art.LocalBackend is LocalBackend