Skip to content

Commit ec016cd

Browse files
authored
fix: unify memory optional dependency import errors (#3389)
1 parent 5645845 commit ec016cd

6 files changed

Lines changed: 232 additions & 104 deletions

File tree

src/agents/extensions/memory/__init__.py

Lines changed: 35 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88

99
from __future__ import annotations
1010

11+
from importlib import import_module
1112
from typing import TYPE_CHECKING, Any
1213

14+
from ._optional_imports import raise_optional_dependency_error
15+
1316
if TYPE_CHECKING:
1417
from .advanced_sqlite_session import AdvancedSQLiteSession
1518
from .async_sqlite_session import AsyncSQLiteSession
@@ -35,99 +38,37 @@
3538
"SQLAlchemySession",
3639
]
3740

41+
_LAZY_EXPORTS: dict[str, tuple[str, tuple[str, str] | None]] = {
42+
"EncryptedSession": (".encrypt_session", ("cryptography", "encrypt")),
43+
"RedisSession": (".redis_session", ("redis", "redis")),
44+
"SQLAlchemySession": (".sqlalchemy_session", ("sqlalchemy", "sqlalchemy")),
45+
"AdvancedSQLiteSession": (".advanced_sqlite_session", None),
46+
"AsyncSQLiteSession": (".async_sqlite_session", None),
47+
"DaprSession": (".dapr_session", ("dapr", "dapr")),
48+
"DAPR_CONSISTENCY_EVENTUAL": (".dapr_session", ("dapr", "dapr")),
49+
"DAPR_CONSISTENCY_STRONG": (".dapr_session", ("dapr", "dapr")),
50+
"MongoDBSession": (".mongodb_session", ("mongodb", "mongodb")),
51+
}
3852

39-
def __getattr__(name: str) -> Any:
40-
if name == "EncryptedSession":
41-
try:
42-
from .encrypt_session import EncryptedSession # noqa: F401
43-
44-
return EncryptedSession
45-
except ModuleNotFoundError as e:
46-
raise ImportError(
47-
"EncryptedSession requires the 'cryptography' extra. "
48-
"Install it with: pip install openai-agents[encrypt]"
49-
) from e
50-
51-
if name == "RedisSession":
52-
try:
53-
from .redis_session import RedisSession # noqa: F401
54-
55-
return RedisSession
56-
except ModuleNotFoundError as e:
57-
raise ImportError(
58-
"RedisSession requires the 'redis' extra. "
59-
"Install it with: pip install openai-agents[redis]"
60-
) from e
61-
62-
if name == "SQLAlchemySession":
63-
try:
64-
from .sqlalchemy_session import SQLAlchemySession # noqa: F401
65-
66-
return SQLAlchemySession
67-
except ModuleNotFoundError as e:
68-
raise ImportError(
69-
"SQLAlchemySession requires the 'sqlalchemy' extra. "
70-
"Install it with: pip install openai-agents[sqlalchemy]"
71-
) from e
72-
73-
if name == "AdvancedSQLiteSession":
74-
try:
75-
from .advanced_sqlite_session import AdvancedSQLiteSession # noqa: F401
76-
77-
return AdvancedSQLiteSession
78-
except ModuleNotFoundError as e:
79-
raise ImportError(f"Failed to import AdvancedSQLiteSession: {e}") from e
80-
81-
if name == "AsyncSQLiteSession":
82-
try:
83-
from .async_sqlite_session import AsyncSQLiteSession # noqa: F401
8453

85-
return AsyncSQLiteSession
86-
except ModuleNotFoundError as e:
87-
raise ImportError(f"Failed to import AsyncSQLiteSession: {e}") from e
88-
89-
if name == "DaprSession":
90-
try:
91-
from .dapr_session import DaprSession # noqa: F401
92-
93-
return DaprSession
94-
except ModuleNotFoundError as e:
95-
raise ImportError(
96-
"DaprSession requires the 'dapr' extra. "
97-
"Install it with: pip install openai-agents[dapr]"
98-
) from e
99-
100-
if name == "DAPR_CONSISTENCY_EVENTUAL":
101-
try:
102-
from .dapr_session import DAPR_CONSISTENCY_EVENTUAL # noqa: F401
103-
104-
return DAPR_CONSISTENCY_EVENTUAL
105-
except ModuleNotFoundError as e:
106-
raise ImportError(
107-
"DAPR_CONSISTENCY_EVENTUAL requires the 'dapr' extra. "
108-
"Install it with: pip install openai-agents[dapr]"
109-
) from e
110-
111-
if name == "DAPR_CONSISTENCY_STRONG":
112-
try:
113-
from .dapr_session import DAPR_CONSISTENCY_STRONG # noqa: F401
114-
115-
return DAPR_CONSISTENCY_STRONG
116-
except ModuleNotFoundError as e:
117-
raise ImportError(
118-
"DAPR_CONSISTENCY_STRONG requires the 'dapr' extra. "
119-
"Install it with: pip install openai-agents[dapr]"
120-
) from e
121-
122-
if name == "MongoDBSession":
123-
try:
124-
from .mongodb_session import MongoDBSession # noqa: F401
125-
126-
return MongoDBSession
127-
except ModuleNotFoundError as e:
128-
raise ImportError(
129-
"MongoDBSession requires the 'mongodb' extra. "
130-
"Install it with: pip install openai-agents[mongodb]"
131-
) from e
132-
133-
raise AttributeError(f"module {__name__} has no attribute {name}")
54+
def __getattr__(name: str) -> Any:
55+
if name not in _LAZY_EXPORTS:
56+
raise AttributeError(f"module {__name__} has no attribute {name}")
57+
58+
module_name, optional_dependency = _LAZY_EXPORTS[name]
59+
try:
60+
module = import_module(module_name, __name__)
61+
except ModuleNotFoundError as e:
62+
if optional_dependency is None:
63+
raise ImportError(f"Failed to import {name}: {e}") from e
64+
dependency_name, extra_name = optional_dependency
65+
raise_optional_dependency_error(
66+
name,
67+
dependency_name=dependency_name,
68+
extra_name=extra_name,
69+
cause=e,
70+
)
71+
72+
value = getattr(module, name)
73+
globals()[name] = value
74+
return value
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from __future__ import annotations
2+
3+
from typing import NoReturn
4+
5+
6+
def raise_optional_dependency_error(
7+
export_name: str,
8+
*,
9+
dependency_name: str,
10+
extra_name: str,
11+
cause: ImportError | None = None,
12+
) -> NoReturn:
13+
error = ImportError(
14+
f"{export_name} requires the '{dependency_name}' extra. "
15+
f"Install it with: pip install openai-agents[{extra_name}]"
16+
)
17+
if cause is None:
18+
raise error
19+
raise error from cause

src/agents/extensions/memory/dapr_session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,18 @@
2929
import time
3030
from typing import Any, Final, Literal
3131

32+
from ._optional_imports import raise_optional_dependency_error
33+
3234
try:
3335
from dapr.aio.clients import DaprClient
3436
from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions
3537
except ImportError as e:
36-
raise ImportError(
37-
"DaprSession requires the 'dapr' package. Install it with: pip install dapr"
38-
) from e
38+
raise_optional_dependency_error(
39+
"DaprSession",
40+
dependency_name="dapr",
41+
extra_name="dapr",
42+
cause=e,
43+
)
3944

4045
from ...items import TResponseInputItem
4146
from ...logger import logger

src/agents/extensions/memory/mongodb_session.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
from datetime import datetime, timezone
3838
from typing import Any
3939

40+
from ._optional_imports import raise_optional_dependency_error
41+
4042
try:
4143
from importlib.metadata import version as _get_version
4244

@@ -49,10 +51,12 @@
4951
from pymongo.asynchronous.mongo_client import AsyncMongoClient
5052
from pymongo.driver_info import DriverInfo
5153
except ImportError as e:
52-
raise ImportError(
53-
"MongoDBSession requires the 'pymongo' package (>=4.14). "
54-
"Install it with: pip install openai-agents[mongodb]"
55-
) from e
54+
raise_optional_dependency_error(
55+
"MongoDBSession",
56+
dependency_name="mongodb",
57+
extra_name="mongodb",
58+
cause=e,
59+
)
5660

5761
from ...items import TResponseInputItem
5862
from ...memory.session import SessionABC

src/agents/extensions/memory/redis_session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,18 @@
2626
import time
2727
from typing import Any
2828

29+
from ._optional_imports import raise_optional_dependency_error
30+
2931
try:
3032
import redis.asyncio as redis
3133
from redis.asyncio import Redis
3234
except ImportError as e:
33-
raise ImportError(
34-
"RedisSession requires the 'redis' package. Install it with: pip install redis"
35-
) from e
35+
raise_optional_dependency_error(
36+
"RedisSession",
37+
dependency_name="redis",
38+
extra_name="redis",
39+
cause=e,
40+
)
3641

3742
from ...items import TResponseInputItem
3843
from ...memory.session import SessionABC
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from __future__ import annotations
2+
3+
import importlib.abc
4+
import sys
5+
from types import ModuleType
6+
7+
import pytest
8+
9+
_PACKAGE_EXPORTS: tuple[tuple[str, str, str, str, str], ...] = (
10+
(
11+
"EncryptedSession",
12+
"agents.extensions.memory.encrypt_session",
13+
"agents.extensions.memory.encrypt_session",
14+
"cryptography",
15+
"encrypt",
16+
),
17+
("RedisSession", "agents.extensions.memory.redis_session", "redis.asyncio", "redis", "redis"),
18+
(
19+
"SQLAlchemySession",
20+
"agents.extensions.memory.sqlalchemy_session",
21+
"agents.extensions.memory.sqlalchemy_session",
22+
"sqlalchemy",
23+
"sqlalchemy",
24+
),
25+
("DaprSession", "agents.extensions.memory.dapr_session", "dapr.aio.clients", "dapr", "dapr"),
26+
(
27+
"DAPR_CONSISTENCY_EVENTUAL",
28+
"agents.extensions.memory.dapr_session",
29+
"dapr.aio.clients",
30+
"dapr",
31+
"dapr",
32+
),
33+
(
34+
"DAPR_CONSISTENCY_STRONG",
35+
"agents.extensions.memory.dapr_session",
36+
"dapr.aio.clients",
37+
"dapr",
38+
"dapr",
39+
),
40+
(
41+
"MongoDBSession",
42+
"agents.extensions.memory.mongodb_session",
43+
"pymongo.asynchronous.collection",
44+
"mongodb",
45+
"mongodb",
46+
),
47+
)
48+
49+
_DIRECT_MODULE_IMPORTS: tuple[tuple[str, str, str, str], ...] = (
50+
("agents.extensions.memory.redis_session", "redis.asyncio", "redis", "redis"),
51+
("agents.extensions.memory.dapr_session", "dapr.aio.clients", "dapr", "dapr"),
52+
(
53+
"agents.extensions.memory.mongodb_session",
54+
"pymongo.asynchronous.collection",
55+
"mongodb",
56+
"mongodb",
57+
),
58+
)
59+
60+
61+
class _BrokenImportFinder(importlib.abc.MetaPathFinder):
62+
def __init__(self, broken_module: str, error_cls: type[ImportError]) -> None:
63+
self._broken_module = broken_module
64+
self._error_cls = error_cls
65+
66+
def find_spec(
67+
self,
68+
fullname: str,
69+
path: object | None,
70+
target: ModuleType | None = None,
71+
) -> None:
72+
if fullname == self._broken_module:
73+
raise self._error_cls("simulated dependency import failure")
74+
return None
75+
76+
77+
def _reset_package_imports(
78+
monkeypatch: pytest.MonkeyPatch,
79+
memory_module: ModuleType,
80+
symbol: str,
81+
module_name: str,
82+
broken_module: str,
83+
) -> None:
84+
monkeypatch.delitem(memory_module.__dict__, symbol, raising=False)
85+
_reset_loaded_module(monkeypatch, module_name)
86+
_reset_loaded_module(monkeypatch, broken_module)
87+
88+
89+
def _reset_loaded_module(monkeypatch: pytest.MonkeyPatch, module_name: str) -> None:
90+
monkeypatch.delitem(sys.modules, module_name, raising=False)
91+
parent_name, short_name = module_name.rsplit(".", 1)
92+
parent_module = sys.modules.get(parent_name)
93+
if parent_module is not None:
94+
monkeypatch.delitem(parent_module.__dict__, short_name, raising=False)
95+
96+
97+
def _reset_module_imports(
98+
monkeypatch: pytest.MonkeyPatch,
99+
module_name: str,
100+
broken_module: str,
101+
) -> None:
102+
_reset_loaded_module(monkeypatch, module_name)
103+
_reset_loaded_module(monkeypatch, broken_module)
104+
105+
106+
@pytest.mark.parametrize(
107+
("symbol", "module_name", "broken_module", "dependency_name", "extra_name"),
108+
_PACKAGE_EXPORTS,
109+
)
110+
def test_memory_package_imports_point_to_optional_extra(
111+
monkeypatch: pytest.MonkeyPatch,
112+
symbol: str,
113+
module_name: str,
114+
broken_module: str,
115+
dependency_name: str,
116+
extra_name: str,
117+
) -> None:
118+
import agents.extensions.memory as memory_module
119+
120+
_reset_package_imports(monkeypatch, memory_module, symbol, module_name, broken_module)
121+
finder = _BrokenImportFinder(broken_module, ModuleNotFoundError)
122+
monkeypatch.setattr(sys, "meta_path", [finder, *sys.meta_path])
123+
124+
with pytest.raises(ImportError) as exc_info:
125+
getattr(memory_module, symbol)
126+
127+
assert f"requires the '{dependency_name}' extra" in str(exc_info.value)
128+
assert f"openai-agents[{extra_name}]" in str(exc_info.value)
129+
assert isinstance(exc_info.value.__cause__, ImportError)
130+
131+
132+
@pytest.mark.parametrize(
133+
("module_name", "broken_module", "dependency_name", "extra_name"),
134+
_DIRECT_MODULE_IMPORTS,
135+
)
136+
@pytest.mark.parametrize("error_cls", [ImportError, ModuleNotFoundError])
137+
def test_memory_direct_module_imports_point_to_optional_extra(
138+
monkeypatch: pytest.MonkeyPatch,
139+
module_name: str,
140+
broken_module: str,
141+
dependency_name: str,
142+
extra_name: str,
143+
error_cls: type[ImportError],
144+
) -> None:
145+
_reset_module_imports(monkeypatch, module_name, broken_module)
146+
finder = _BrokenImportFinder(broken_module, error_cls)
147+
monkeypatch.setattr(sys, "meta_path", [finder, *sys.meta_path])
148+
149+
with pytest.raises(ImportError) as exc_info:
150+
__import__(module_name)
151+
152+
assert f"requires the '{dependency_name}' extra" in str(exc_info.value)
153+
assert f"openai-agents[{extra_name}]" in str(exc_info.value)
154+
assert isinstance(exc_info.value.__cause__, ImportError)

0 commit comments

Comments
 (0)