99from functools import wraps
1010import json
1111import logging
12- import os
12+ import threading
1313from typing import TYPE_CHECKING , Any , ParamSpec , TypeVar , cast
14- from urllib .parse import urlparse
1514
1615from a2a .server .agent_execution import RequestContext
1716from a2a .server .events import EventQueue
3837from a2a .utils .errors import ServerError
3938from aiocache import SimpleMemoryCache , caches # type: ignore[import-untyped]
4039from pydantic import BaseModel
41- from typing_extensions import TypedDict
4240
4341from crewai .a2a .utils .agent_card import _get_server_config
4442from crewai .a2a .utils .content_type import validate_message_parts
5048 A2AServerTaskStartedEvent ,
5149)
5250from crewai .task import Task
51+ from crewai .utilities .cache_config import (
52+ get_aiocache_config ,
53+ parse_cache_url ,
54+ use_valkey_cache ,
55+ )
5356from crewai .utilities .pydantic_schema_utils import create_model_from_schema
5457
5558
5659if TYPE_CHECKING :
5760 from crewai .a2a .extensions .server import ExtensionContext , ServerExtensionRegistry
5861 from crewai .agent import Agent
62+ from crewai .memory .storage .valkey_cache import ValkeyCache
5963
6064
6165logger = logging .getLogger (__name__ )
6468T = TypeVar ("T" )
6569
6670
67- class RedisCacheConfig (TypedDict , total = False ):
68- """Configuration for aiocache Redis backend."""
71+ # ---------------------------------------------------------------------------
72+ # Lazy cache initialisation
73+ # ---------------------------------------------------------------------------
6974
70- cache : str
71- endpoint : str
72- port : int
73- db : int
74- password : str
75+ _task_cache : ValkeyCache | None = None
76+ _cache_initialized = False
77+ _cache_init_lock = threading .Lock ()
7578
79+ # Configure aiocache at import time (matches upstream behaviour).
80+ # This is safe — it only touches aiocache, no optional dependencies.
81+ # The Valkey path is deferred to _ensure_task_cache() to avoid importing
82+ # valkey-glide at module level (it may not be installed).
83+ if not use_valkey_cache ():
84+ caches .set_config (get_aiocache_config ())
7685
77- def _parse_redis_url (url : str ) -> RedisCacheConfig :
78- """Parse a Redis URL into aiocache configuration.
7986
80- Args :
81- url: Redis connection URL (e.g., redis://localhost:6379/0 ).
87+ def _ensure_task_cache () -> None :
88+ """Initialise the Valkey task cache on first use (thread-safe ).
8289
83- Returns:
84- Configuration dict for aiocache.RedisCache .
90+ For the aiocache path, configuration happens at module level above.
91+ This function only needs to run for the Valkey path .
8592 """
86- parsed = urlparse (url )
87- config : RedisCacheConfig = {
88- "cache" : "aiocache.RedisCache" ,
89- "endpoint" : parsed .hostname or "localhost" ,
90- "port" : parsed .port or 6379 ,
91- }
92- if parsed .path and parsed .path != "/" :
93- try :
94- config ["db" ] = int (parsed .path .lstrip ("/" ))
95- except ValueError :
96- pass
97- if parsed .password :
98- config ["password" ] = parsed .password
99- return config
100-
93+ global _task_cache , _cache_initialized
94+ if _cache_initialized :
95+ return
96+
97+ with _cache_init_lock :
98+ if _cache_initialized :
99+ return
100+
101+ if use_valkey_cache ():
102+ from crewai .memory .storage .valkey_cache import ValkeyCache
103+
104+ conn = parse_cache_url () or {}
105+ _task_cache = ValkeyCache (
106+ host = conn .get ("host" , "localhost" ),
107+ port = conn .get ("port" , 6379 ),
108+ db = conn .get ("db" , 0 ),
109+ password = conn .get ("password" ),
110+ default_ttl = 3600 ,
111+ )
101112
102- _redis_url = os .environ .get ("REDIS_URL" )
103-
104- caches .set_config (
105- {
106- "default" : _parse_redis_url (_redis_url )
107- if _redis_url
108- else {
109- "cache" : "aiocache.SimpleMemoryCache" ,
110- }
111- }
112- )
113+ _cache_initialized = True
113114
114115
115116def cancellable (
@@ -130,6 +131,8 @@ def cancellable(
130131 @wraps (fn )
131132 async def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T :
132133 """Wrap function with cancellation monitoring."""
134+ _ensure_task_cache ()
135+
133136 context : RequestContext | None = None
134137 for arg in args :
135138 if isinstance (arg , RequestContext ):
@@ -142,19 +145,34 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
142145 return await fn (* args , ** kwargs )
143146
144147 task_id = context .task_id
145- cache = caches .get ("default" )
146148
147- async def poll_for_cancel () -> bool :
148- """Poll cache for cancellation flag."""
149+ async def poll_for_cancel_valkey () -> bool :
150+ """Poll ValkeyCache for cancellation flag."""
151+ while True :
152+ if _task_cache is not None and await _task_cache .get (
153+ f"cancel:{ task_id } "
154+ ):
155+ return True
156+ await asyncio .sleep (0.1 )
157+
158+ async def poll_for_cancel_aiocache () -> bool :
159+ """Poll aiocache for cancellation flag."""
160+ cache = caches .get ("default" )
149161 while True :
150162 if await cache .get (f"cancel:{ task_id } " ):
151163 return True
152164 await asyncio .sleep (0.1 )
153165
154166 async def watch_for_cancel () -> bool :
155167 """Watch for cancellation events via pub/sub or polling."""
168+ if _task_cache is not None :
169+ # ValkeyCache: use polling (pub/sub not implemented yet)
170+ return await poll_for_cancel_valkey ()
171+
172+ # aiocache: use pub/sub if Redis, otherwise poll
173+ cache = caches .get ("default" )
156174 if isinstance (cache , SimpleMemoryCache ):
157- return await poll_for_cancel ()
175+ return await poll_for_cancel_aiocache ()
158176
159177 try :
160178 client = cache .client
@@ -168,7 +186,7 @@ async def watch_for_cancel() -> bool:
168186 "Cancel watcher Redis error, falling back to polling" ,
169187 extra = {"task_id" : task_id , "error" : str (e )},
170188 )
171- return await poll_for_cancel ()
189+ return await poll_for_cancel_aiocache ()
172190 return False
173191
174192 execute_task = asyncio .create_task (fn (* args , ** kwargs ))
@@ -190,7 +208,12 @@ async def watch_for_cancel() -> bool:
190208 cancel_watch .cancel ()
191209 return execute_task .result ()
192210 finally :
193- await cache .delete (f"cancel:{ task_id } " )
211+ # Clean up cancellation flag
212+ if _task_cache is not None :
213+ await _task_cache .delete (f"cancel:{ task_id } " )
214+ else :
215+ cache = caches .get ("default" )
216+ await cache .delete (f"cancel:{ task_id } " )
194217
195218 return wrapper
196219
@@ -475,18 +498,25 @@ async def cancel(
475498 if task_id is None or context_id is None :
476499 raise ServerError (InvalidParamsError (message = "task_id and context_id required" ))
477500
501+ _ensure_task_cache ()
502+
478503 if context .current_task and context .current_task .status .state in (
479504 TaskState .completed ,
480505 TaskState .failed ,
481506 TaskState .canceled ,
482507 ):
483508 return context .current_task
484509
485- cache = caches .get ("default" )
486-
487- await cache .set (f"cancel:{ task_id } " , True , ttl = 3600 )
488- if not isinstance (cache , SimpleMemoryCache ):
489- await cache .client .publish (f"cancel:{ task_id } " , "cancel" )
510+ if _task_cache is not None :
511+ # Use ValkeyCache
512+ await _task_cache .set (f"cancel:{ task_id } " , True , ttl = 3600 )
513+ # Note: pub/sub not implemented for ValkeyCache yet, relies on polling
514+ else :
515+ # Use aiocache
516+ cache = caches .get ("default" )
517+ await cache .set (f"cancel:{ task_id } " , True , ttl = 3600 )
518+ if not isinstance (cache , SimpleMemoryCache ):
519+ await cache .client .publish (f"cancel:{ task_id } " , "cancel" )
490520
491521 await event_queue .enqueue_event (
492522 TaskStatusUpdateEvent (
0 commit comments