diff --git a/langgraph/checkpoint/redis/aio.py b/langgraph/checkpoint/redis/aio.py index 92c8359..bcd6689 100644 --- a/langgraph/checkpoint/redis/aio.py +++ b/langgraph/checkpoint/redis/aio.py @@ -101,7 +101,9 @@ def __init__( checkpoint_prefix=checkpoint_prefix, checkpoint_write_prefix=checkpoint_write_prefix, ) - self.loop = asyncio.get_running_loop() + # Deferred: the event loop is captured in asetup() so that the saver can + # be constructed outside an async context (Issue #179). + self.loop: Optional[asyncio.AbstractEventLoop] = None # Instance-level cache for frequently used keys (limited size to prevent memory issues) self._key_cache: Dict[str, str] = {} @@ -243,6 +245,13 @@ async def __aexit__( async def asetup(self) -> None: """Set up the checkpoint saver.""" + # Capture the running event loop here so that sync wrapper methods + # (get_tuple, put, put_writes, …) can dispatch coroutines to it via + # asyncio.run_coroutine_threadsafe. Deferring this to asetup() instead + # of __init__ lets callers construct the saver outside an async context + # (Issue #179). + self.loop = asyncio.get_running_loop() + self.create_indexes() await self.checkpoints_index.create(overwrite=False) await self.checkpoint_writes_index.create(overwrite=False) @@ -1307,6 +1316,20 @@ def put_writes( task_id (str): Identifier for the task creating the writes. task_path (str): Path of the task creating the writes. """ + if self.loop is None: + raise RuntimeError( + "AsyncRedisSaver must be set up before calling synchronous methods. " + "Call `await saver.asetup()` or use `async with saver:` first." + ) + try: + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to AsyncRedisSaver are only allowed from a " + "different thread. From the main thread, use the async interface. " + "For example, use `await checkpointer.aput_writes(...)`." + ) + except RuntimeError: + pass return asyncio.run_coroutine_threadsafe( self.aput_writes(config, writes, task_id), self.loop ).result() @@ -1315,12 +1338,17 @@ def get_channel_values( self, thread_id: str, checkpoint_ns: str = "", checkpoint_id: str = "" ) -> Dict[str, Any]: """Retrieve channel_values using efficient FT.SEARCH with checkpoint_id (sync wrapper).""" + if self.loop is None: + raise RuntimeError( + "AsyncRedisSaver must be set up before calling synchronous methods. " + "Call `await saver.asetup()` or use `async with saver:` first." + ) try: if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( "Synchronous calls to AsyncRedisSaver are only allowed from a " - "different thread. From the main thread, use the async interface." - "For example, use `await checkpointer.get_channel_values(...)`." + "different thread. From the main thread, use the async interface. " + "For example, use `await checkpointer.aget_channel_values(...)`." ) except RuntimeError: pass @@ -1345,6 +1373,11 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: Raises: asyncio.InvalidStateError: If called from the wrong thread/event loop """ + if self.loop is None: + raise RuntimeError( + "AsyncRedisSaver must be set up before calling synchronous methods. " + "Call `await saver.asetup()` or use `async with saver:` first." + ) try: # check if we are in the main thread, only bg threads can block if asyncio.get_running_loop() is self.loop: @@ -1381,6 +1414,11 @@ def put( Raises: asyncio.InvalidStateError: If called from the wrong thread/event loop """ + if self.loop is None: + raise RuntimeError( + "AsyncRedisSaver must be set up before calling synchronous methods. " + "Call `await saver.asetup()` or use `async with saver:` first." + ) try: # check if we are in the main thread, only bg threads can block if asyncio.get_running_loop() is self.loop: diff --git a/langgraph/checkpoint/redis/ashallow.py b/langgraph/checkpoint/redis/ashallow.py index 4550782..db09fd8 100644 --- a/langgraph/checkpoint/redis/ashallow.py +++ b/langgraph/checkpoint/redis/ashallow.py @@ -77,7 +77,9 @@ def __init__( checkpoint_prefix=checkpoint_prefix, checkpoint_write_prefix=checkpoint_write_prefix, ) - self.loop = asyncio.get_running_loop() + # Deferred: the event loop is captured in asetup() so that the saver can + # be constructed outside an async context (Issue #179). + self.loop: Optional[asyncio.AbstractEventLoop] = None # Instance-level cache for frequently used keys (limited size to prevent memory issues) self._key_cache: Dict[str, str] = {} @@ -139,6 +141,13 @@ async def from_conn_string( async def asetup(self) -> None: """Initialize Redis indexes asynchronously.""" + # Capture the running event loop here so that sync wrapper methods + # (get_tuple, put, put_writes, …) can dispatch coroutines to it via + # asyncio.run_coroutine_threadsafe. Deferring this to asetup() instead + # of __init__ lets callers construct the saver outside an async context + # (Issue #179). + self.loop = asyncio.get_running_loop() + await self.checkpoints_index.create(overwrite=False) await self.checkpoint_writes_index.create(overwrite=False) @@ -725,6 +734,11 @@ def create_indexes(self) -> None: def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Retrieve a checkpoint tuple from Redis synchronously.""" + if self.loop is None: + raise RuntimeError( + "AsyncShallowRedisSaver must be set up before calling synchronous methods. " + "Call `await saver.asetup()` or use `async with saver:` first." + ) try: if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( @@ -747,6 +761,20 @@ def put( new_versions: ChannelVersions, ) -> RunnableConfig: """Store only the latest checkpoint synchronously.""" + if self.loop is None: + raise RuntimeError( + "AsyncShallowRedisSaver must be set up before calling synchronous methods. " + "Call `await saver.asetup()` or use `async with saver:` first." + ) + try: + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to AsyncShallowRedisSaver are only allowed from a " + "different thread. From the main thread, use the async interface. " + "For example, use `await checkpointer.aput(...)`." + ) + except RuntimeError: + pass return asyncio.run_coroutine_threadsafe( self.aput(config, checkpoint, metadata, new_versions), self.loop ).result() @@ -759,6 +787,20 @@ def put_writes( task_path: str = "", ) -> None: """Store intermediate writes synchronously.""" + if self.loop is None: + raise RuntimeError( + "AsyncShallowRedisSaver must be set up before calling synchronous methods. " + "Call `await saver.asetup()` or use `async with saver:` first." + ) + try: + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to AsyncShallowRedisSaver are only allowed from a " + "different thread. From the main thread, use the async interface. " + "For example, use `await checkpointer.aput_writes(...)`." + ) + except RuntimeError: + pass return asyncio.run_coroutine_threadsafe( self.aput_writes(config, writes, task_id), self.loop ).result() @@ -771,6 +813,11 @@ def get_channel_values( channel_versions: Optional[Dict[str, Any]] = None, ) -> dict[str, Any]: """Retrieve channel_values dictionary with properly constructed message objects (sync wrapper).""" + if self.loop is None: + raise RuntimeError( + "AsyncShallowRedisSaver must be set up before calling synchronous methods. " + "Call `await saver.asetup()` or use `async with saver:` first." + ) try: if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( diff --git a/tests/test_async.py b/tests/test_async.py index 62163c0..03952b2 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -848,3 +848,78 @@ async def test_root_graph_checkpoint( checkpoints = [c async for c in checkpointer.alist(config)] assert len(checkpoints) > 0 assert checkpoints[-1].checkpoint["id"] == latest["id"] + + +# --- Issue #179: AsyncRedisSaver construction outside async context --- + + +def test_async_redis_saver_construction_outside_event_loop(redis_url: str) -> None: + """AsyncRedisSaver should be constructable outside an async context (Issue #179). + + Previously, AsyncRedisSaver.__init__ called asyncio.get_running_loop() which + raised RuntimeError when no event loop was running. + """ + # This must not raise RuntimeError even when there is no running event loop + saver = AsyncRedisSaver(redis_url) + assert saver is not None + # Loop should be None until asetup() is called + assert saver.loop is None + + +def test_async_redis_saver_construction_with_client_outside_event_loop( + redis_url: str, +) -> None: + """AsyncRedisSaver should accept a pre-built client without a running loop (Issue #179). + + The typical use-case from the issue: constructing the saver synchronously, + then setting up (and using it) later inside an async lifespan handler. + """ + from redis.asyncio import Redis as AsyncRedis + + client = AsyncRedis.from_url(redis_url) + try: + saver = AsyncRedisSaver(redis_client=client) + assert saver is not None + assert saver.loop is None + finally: + asyncio.run(client.aclose()) + + +@pytest.mark.asyncio +async def test_async_redis_saver_loop_captured_in_asetup(redis_url: str) -> None: + """asetup() must capture the running event loop so sync wrappers work (Issue #179).""" + saver = AsyncRedisSaver(redis_url) + assert saver.loop is None # not yet set + + await saver.asetup() + + # After asetup the loop attribute must point to the current running loop + assert saver.loop is not None + assert saver.loop is asyncio.get_running_loop() + + await saver._redis.aclose() + + +@pytest.mark.asyncio +async def test_async_redis_saver_context_manager_after_sync_construction( + redis_url: str, +) -> None: + """Saver built before entering the async context manager must still work.""" + # Construct before entering `async with`; in this async test a loop is already + # running, but this still verifies the saver is usable end-to-end once setup + # happens on context-manager entry. + saver = AsyncRedisSaver(redis_url) + + async with saver: + # After entering the context the loop must be set + assert saver.loop is asyncio.get_running_loop() + + # Basic functional smoke test + config: RunnableConfig = { + "configurable": {"thread_id": "issue-179-test", "checkpoint_ns": ""} + } + chk: Checkpoint = empty_checkpoint() + meta: CheckpointMetadata = {"source": "input", "step": 0, "writes": {}} + await saver.aput(config, chk, meta, {}) + result = await saver.aget_tuple(config) + assert result is not None diff --git a/tests/test_shallow_async.py b/tests/test_shallow_async.py index 1ad53cb..7f48982 100644 --- a/tests/test_shallow_async.py +++ b/tests/test_shallow_async.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, AsyncGenerator, Dict import pytest @@ -494,3 +495,72 @@ async def test_shallow_redis_saver_inline_storage(redis_url: str) -> None: # Clean up test data await redis_client.flushdb() await redis_client.aclose() + + +# --- Issue #179: AsyncShallowRedisSaver construction outside async context --- + + +def test_async_shallow_redis_saver_construction_outside_event_loop( + redis_url: str, +) -> None: + """AsyncShallowRedisSaver should be constructable outside an async context (Issue #179). + + Previously, AsyncShallowRedisSaver.__init__ called asyncio.get_running_loop() which + raised RuntimeError when no event loop was running. + """ + # This must not raise RuntimeError even when there is no running event loop + saver = AsyncShallowRedisSaver(redis_url) + assert saver is not None + # Loop should be None until asetup() is called + assert saver.loop is None + + +def test_async_shallow_redis_saver_construction_with_client_outside_event_loop( + redis_url: str, +) -> None: + """AsyncShallowRedisSaver accepts a pre-built client without a running loop (Issue #179).""" + from redis.asyncio import Redis as AsyncRedis + + client = AsyncRedis.from_url(redis_url) + try: + saver = AsyncShallowRedisSaver(redis_client=client) + assert saver is not None + assert saver.loop is None + finally: + asyncio.run(client.aclose()) + + +@pytest.mark.asyncio +async def test_async_shallow_redis_saver_loop_captured_in_asetup( + redis_url: str, +) -> None: + """asetup() must capture the running event loop so sync wrappers work (Issue #179).""" + saver = AsyncShallowRedisSaver(redis_url) + assert saver.loop is None # not yet set + + await saver.asetup() + + assert saver.loop is not None + assert saver.loop is asyncio.get_running_loop() + + await saver._redis.aclose() + + +@pytest.mark.asyncio +async def test_async_shallow_redis_saver_context_manager_after_sync_construction( + redis_url: str, +) -> None: + """Saver constructed before entering the async context manager must still work.""" + saver = AsyncShallowRedisSaver(redis_url) + + async with saver: + assert saver.loop is asyncio.get_running_loop() + + config: RunnableConfig = { + "configurable": {"thread_id": "issue-179-shallow-test", "checkpoint_ns": ""} + } + chk: Checkpoint = empty_checkpoint() + meta: CheckpointMetadata = {"source": "input", "step": 0, "writes": {}} + await saver.aput(config, chk, meta, {}) + result = await saver.aget_tuple(config) + assert result is not None