Skip to content

Commit 6dbdcde

Browse files
committed
fix: address typecheck issues in wrapper compatibility changes
1 parent e6472db commit 6dbdcde

3 files changed

Lines changed: 28 additions & 8 deletions

File tree

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ async def add_items(
140140
async def get_items(
141141
self,
142142
limit: int | None = None,
143+
wrapper: RunContextWrapper[Any] | str | None = None,
144+
*,
143145
branch_id: str | None = None,
144-
wrapper: RunContextWrapper[Any] | None = None,
145146
) -> list[TResponseInputItem]:
146147
"""Get items from current or specified branch.
147148
@@ -152,6 +153,10 @@ async def get_items(
152153
Returns:
153154
List of conversation items from the specified branch.
154155
"""
156+
if isinstance(wrapper, str) and branch_id is None:
157+
branch_id = wrapper
158+
wrapper = None
159+
155160
session_limit = resolve_session_limit(limit, self.session_settings)
156161

157162
if branch_id is None:

tests/extensions/memory/test_encrypt_session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
pytest.importorskip("cryptography") # Skip tests if cryptography is not installed
99

10+
from typing import cast
11+
1012
from cryptography.fernet import Fernet
1113

1214
from agents import Agent, Runner, SQLiteSession, TResponseInputItem
@@ -116,7 +118,7 @@ async def test_encrypted_session_preserves_legacy_underlying_signatures(
116118
agent: Agent,
117119
encryption_key: str,
118120
):
119-
class LegacyUnderlyingSession(SessionABC):
121+
class LegacyUnderlyingSession:
120122
def __init__(self) -> None:
121123
self.session_id = "test_session"
122124
self.items: list[TResponseInputItem] = []
@@ -140,7 +142,7 @@ async def clear_session(self) -> None:
140142
legacy_underlying = LegacyUnderlyingSession()
141143
session = EncryptedSession(
142144
session_id="test_session",
143-
underlying_session=legacy_underlying,
145+
underlying_session=cast(SessionABC, legacy_underlying),
144146
encryption_key=encryption_key,
145147
)
146148

tests/test_session.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem
1111
from agents.memory.session import SessionABC
12+
from agents.memory.session_settings import SessionSettings
1213
from agents.run_context import RunContextWrapper
1314
from agents.run_internal.session_persistence import prepare_input_with_session
1415

@@ -136,16 +137,26 @@ async def clear_session(self) -> None:
136137

137138
class DefaultLimitedSession:
138139
session_id = "default-limited"
140+
session_settings: SessionSettings | None = None
139141

140142
def __init__(self) -> None:
141143
self.items: list[TResponseInputItem] = []
142144
self.get_call_count = 0
143145

144-
async def get_items(self, limit: int = 1) -> list[TResponseInputItem]:
146+
async def get_items(
147+
self,
148+
limit: int | None = None,
149+
wrapper: RunContextWrapper[Any] | None = None,
150+
) -> list[TResponseInputItem]:
145151
self.get_call_count += 1
146-
return list(self.items[-limit:]) if limit > 0 else []
152+
effective_limit = 1 if limit is None else limit
153+
return list(self.items[-effective_limit:]) if effective_limit > 0 else []
147154

148-
async def add_items(self, items: list[TResponseInputItem]) -> None:
155+
async def add_items(
156+
self,
157+
items: list[TResponseInputItem],
158+
wrapper: RunContextWrapper[Any] | None = None,
159+
) -> None:
149160
self.items.extend(items)
150161

151162
async def pop_item(self) -> TResponseInputItem | None:
@@ -333,8 +344,10 @@ async def test_get_items_preserves_default_limit_when_none_is_unset() -> None:
333344
)
334345

335346
assert isinstance(prepared_input, list)
336-
assert prepared_input[0]["content"] == "two"
337-
assert prepared_input[1]["content"] == "new"
347+
first = prepared_input[0]
348+
second = prepared_input[1]
349+
assert isinstance(first, dict) and first.get("content") == "two"
350+
assert isinstance(second, dict) and second.get("content") == "new"
338351
assert session.get_call_count == 1
339352

340353

0 commit comments

Comments
 (0)