Skip to content

Commit 9430075

Browse files
feat(valkey): shared cache config + ValkeyCache for A2A and file uploads
Extract duplicated Redis URL parsing into a shared cache_config utility. Introduce ValkeyCache as a lightweight async key/value cache using valkey-glide. Wire it into A2A task handling, agent card caching, and file upload caching. Part 1/4 of Valkey storage implementation.

fix: async-safe embeddings and resilient drain_writes
Add bytes→float validators on MemoryRecord and ItemState to handle Valkey returning embeddings as raw bytes. Make embed_texts() safe when called from an async context by using a thread pool. Improve drain_writes() with per-save timeouts and error logging instead of raising on failure. Part 3/4 of Valkey storage implementation.

feat(valkey): ValkeyStorage vector memory backend
Add ValkeyStorage, a distributed StorageBackend implementation using Valkey-GLIDE with Valkey Search for vector similarity. Wire it into Memory as the 'valkey' storage option. Pin scrapegraph-py<2 to fix unrelated upstream breakage. Part 4/4 of Valkey storage implementation.

fix: use datetime.utcnow() for last_accessed consistency
MemoryRecord defaults use utcnow() for created_at and last_accessed. Match that in ValkeyStorage.update_record() to avoid timezone inconsistency in recency scoring.

feat(valkey): shared cache config + ValkeyCache for A2A and file uploads
Extract duplicated Redis URL parsing into a shared cache_config utility. Introduce ValkeyCache as a lightweight async key/value cache using valkey-glide. Wire it into A2A task handling, agent card caching, and file upload caching. Part 1/4 of Valkey storage implementation.

fix: handle non-numeric database path in cache URL parsing
Extract _parse_db_from_path() helper that catches ValueError for paths like /mydb and defaults to 0 with a warning, instead of crashing.

fix: async-safe embeddings and resilient drain_writes
Add bytes→float validators on MemoryRecord and ItemState to handle Valkey returning embeddings as raw bytes. Make embed_texts() safe when called from an async context by using a thread pool. Improve drain_writes() with per-save timeouts and error logging instead of raising on failure. Part 3/4 of Valkey storage implementation.

fix: catch concurrent.futures.TimeoutError for Python 3.10 compat
In Python <3.11, concurrent.futures.TimeoutError is distinct from the builtin TimeoutError. Catch both so the timeout warning path works on all supported Python versions.
1 parent 264da82 commit 9430075

19 files changed

Lines changed: 8780 additions & 163 deletions

lib/crewai-files/src/crewai_files/cache/upload_cache.py

Lines changed: 163 additions & 100 deletions
Large diffs are not rendered by default.

lib/crewai/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,9 @@ file-processing = [
110110
qdrant-edge = [
111111
"qdrant-edge-py>=0.6.0",
112112
]
113+
valkey = [
114+
"valkey-glide>=1.3.0",
115+
]
113116

114117

115118
[tool.uv]

lib/crewai/src/crewai/a2a/utils/agent_card.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,12 @@
1313
from typing import TYPE_CHECKING
1414

1515
from a2a.client.errors import A2AClientHTTPError
16-
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
17-
from aiocache import cached # type: ignore[import-untyped]
16+
from a2a.types import (
17+
AgentCapabilities,
18+
AgentCard,
19+
AgentSkill,
20+
)
21+
from aiocache import cached, caches # type: ignore[import-untyped]
1822
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
1923
import httpx
2024

@@ -32,6 +36,7 @@
3236
A2AAuthenticationFailedEvent,
3337
A2AConnectionErrorEvent,
3438
)
39+
from crewai.utilities.cache_config import get_aiocache_config
3540

3641

3742
if TYPE_CHECKING:
@@ -40,6 +45,18 @@
4045
from crewai.task import Task
4146

4247

48+
_cache_configured = False
49+
50+
51+
def _ensure_cache_configured() -> None:
    """Apply the aiocache configuration the first time this is called.

    The module-level ``_cache_configured`` flag records whether
    ``caches.set_config`` has already run; once it is set, further calls
    return without touching aiocache. The flag is only flipped after
    ``set_config`` returns, so a failed attempt is retried on the next call.
    """
    global _cache_configured
    if not _cache_configured:
        caches.set_config(get_aiocache_config())
        _cache_configured = True
58+
59+
4360
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
4461
"""Get TLS verify parameter from auth scheme.
4562
@@ -191,6 +208,7 @@ async def afetch_agent_card(
191208
else:
192209
auth_hash = _auth_store.compute_key("none", "")
193210
_auth_store.set(auth_hash, auth)
211+
_ensure_cache_configured()
194212
agent_card: AgentCard = await _afetch_agent_card_cached(
195213
endpoint, auth_hash, timeout
196214
)

lib/crewai/src/crewai/a2a/utils/task.py

Lines changed: 83 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,8 @@
99
from functools import wraps
1010
import json
1111
import logging
12-
import os
12+
import threading
1313
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
14-
from urllib.parse import urlparse
1514

1615
from a2a.server.agent_execution import RequestContext
1716
from a2a.server.events import EventQueue
@@ -38,7 +37,6 @@
3837
from a2a.utils.errors import ServerError
3938
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
4039
from pydantic import BaseModel
41-
from typing_extensions import TypedDict
4240

4341
from crewai.a2a.utils.agent_card import _get_server_config
4442
from crewai.a2a.utils.content_type import validate_message_parts
@@ -50,12 +48,18 @@
5048
A2AServerTaskStartedEvent,
5149
)
5250
from crewai.task import Task
51+
from crewai.utilities.cache_config import (
52+
get_aiocache_config,
53+
parse_cache_url,
54+
use_valkey_cache,
55+
)
5356
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
5457

5558

5659
if TYPE_CHECKING:
5760
from crewai.a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
5861
from crewai.agent import Agent
62+
from crewai.memory.storage.valkey_cache import ValkeyCache
5963

6064

6165
logger = logging.getLogger(__name__)
@@ -64,52 +68,49 @@
6468
T = TypeVar("T")
6569

6670

67-
class RedisCacheConfig(TypedDict, total=False):
68-
"""Configuration for aiocache Redis backend."""
71+
# ---------------------------------------------------------------------------
72+
# Lazy cache initialisation
73+
# ---------------------------------------------------------------------------
6974

70-
cache: str
71-
endpoint: str
72-
port: int
73-
db: int
74-
password: str
75+
_task_cache: ValkeyCache | None = None
76+
_cache_initialized = False
77+
_cache_init_lock = threading.Lock()
7578

79+
# Configure aiocache at import time (matches upstream behaviour).
80+
# This is safe — it only touches aiocache, no optional dependencies.
81+
# The Valkey path is deferred to _ensure_task_cache() to avoid importing
82+
# valkey-glide at module level (it may not be installed).
83+
if not use_valkey_cache():
84+
caches.set_config(get_aiocache_config())
7685

77-
def _parse_redis_url(url: str) -> RedisCacheConfig:
78-
"""Parse a Redis URL into aiocache configuration.
7986

80-
Args:
81-
url: Redis connection URL (e.g., redis://localhost:6379/0).
87+
def _ensure_task_cache() -> None:
88+
"""Initialise the Valkey task cache on first use (thread-safe).
8289
83-
Returns:
84-
Configuration dict for aiocache.RedisCache.
90+
For the aiocache path, configuration happens at module level above.
91+
This function only needs to run for the Valkey path.
8592
"""
86-
parsed = urlparse(url)
87-
config: RedisCacheConfig = {
88-
"cache": "aiocache.RedisCache",
89-
"endpoint": parsed.hostname or "localhost",
90-
"port": parsed.port or 6379,
91-
}
92-
if parsed.path and parsed.path != "/":
93-
try:
94-
config["db"] = int(parsed.path.lstrip("/"))
95-
except ValueError:
96-
pass
97-
if parsed.password:
98-
config["password"] = parsed.password
99-
return config
100-
93+
global _task_cache, _cache_initialized
94+
if _cache_initialized:
95+
return
96+
97+
with _cache_init_lock:
98+
if _cache_initialized:
99+
return
100+
101+
if use_valkey_cache():
102+
from crewai.memory.storage.valkey_cache import ValkeyCache
103+
104+
conn = parse_cache_url() or {}
105+
_task_cache = ValkeyCache(
106+
host=conn.get("host", "localhost"),
107+
port=conn.get("port", 6379),
108+
db=conn.get("db", 0),
109+
password=conn.get("password"),
110+
default_ttl=3600,
111+
)
101112

102-
_redis_url = os.environ.get("REDIS_URL")
103-
104-
caches.set_config(
105-
{
106-
"default": _parse_redis_url(_redis_url)
107-
if _redis_url
108-
else {
109-
"cache": "aiocache.SimpleMemoryCache",
110-
}
111-
}
112-
)
113+
_cache_initialized = True
113114

114115

115116
def cancellable(
@@ -130,6 +131,8 @@ def cancellable(
130131
@wraps(fn)
131132
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
132133
"""Wrap function with cancellation monitoring."""
134+
_ensure_task_cache()
135+
133136
context: RequestContext | None = None
134137
for arg in args:
135138
if isinstance(arg, RequestContext):
@@ -142,19 +145,34 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
142145
return await fn(*args, **kwargs)
143146

144147
task_id = context.task_id
145-
cache = caches.get("default")
146148

147-
async def poll_for_cancel() -> bool:
148-
"""Poll cache for cancellation flag."""
149+
async def poll_for_cancel_valkey() -> bool:
150+
"""Poll ValkeyCache for cancellation flag."""
151+
while True:
152+
if _task_cache is not None and await _task_cache.get(
153+
f"cancel:{task_id}"
154+
):
155+
return True
156+
await asyncio.sleep(0.1)
157+
158+
async def poll_for_cancel_aiocache() -> bool:
159+
"""Poll aiocache for cancellation flag."""
160+
cache = caches.get("default")
149161
while True:
150162
if await cache.get(f"cancel:{task_id}"):
151163
return True
152164
await asyncio.sleep(0.1)
153165

154166
async def watch_for_cancel() -> bool:
155167
"""Watch for cancellation events via pub/sub or polling."""
168+
if _task_cache is not None:
169+
# ValkeyCache: use polling (pub/sub not implemented yet)
170+
return await poll_for_cancel_valkey()
171+
172+
# aiocache: use pub/sub if Redis, otherwise poll
173+
cache = caches.get("default")
156174
if isinstance(cache, SimpleMemoryCache):
157-
return await poll_for_cancel()
175+
return await poll_for_cancel_aiocache()
158176

159177
try:
160178
client = cache.client
@@ -168,7 +186,7 @@ async def watch_for_cancel() -> bool:
168186
"Cancel watcher Redis error, falling back to polling",
169187
extra={"task_id": task_id, "error": str(e)},
170188
)
171-
return await poll_for_cancel()
189+
return await poll_for_cancel_aiocache()
172190
return False
173191

174192
execute_task = asyncio.create_task(fn(*args, **kwargs))
@@ -190,7 +208,12 @@ async def watch_for_cancel() -> bool:
190208
cancel_watch.cancel()
191209
return execute_task.result()
192210
finally:
193-
await cache.delete(f"cancel:{task_id}")
211+
# Clean up cancellation flag
212+
if _task_cache is not None:
213+
await _task_cache.delete(f"cancel:{task_id}")
214+
else:
215+
cache = caches.get("default")
216+
await cache.delete(f"cancel:{task_id}")
194217

195218
return wrapper
196219

@@ -475,18 +498,25 @@ async def cancel(
475498
if task_id is None or context_id is None:
476499
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
477500

501+
_ensure_task_cache()
502+
478503
if context.current_task and context.current_task.status.state in (
479504
TaskState.completed,
480505
TaskState.failed,
481506
TaskState.canceled,
482507
):
483508
return context.current_task
484509

485-
cache = caches.get("default")
486-
487-
await cache.set(f"cancel:{task_id}", True, ttl=3600)
488-
if not isinstance(cache, SimpleMemoryCache):
489-
await cache.client.publish(f"cancel:{task_id}", "cancel")
510+
if _task_cache is not None:
511+
# Use ValkeyCache
512+
await _task_cache.set(f"cancel:{task_id}", True, ttl=3600)
513+
# Note: pub/sub not implemented for ValkeyCache yet, relies on polling
514+
else:
515+
# Use aiocache
516+
cache = caches.get("default")
517+
await cache.set(f"cancel:{task_id}", True, ttl=3600)
518+
if not isinstance(cache, SimpleMemoryCache):
519+
await cache.client.publish(f"cancel:{task_id}", "cancel")
490520

491521
await event_queue.enqueue_event(
492522
TaskStatusUpdateEvent(

lib/crewai/src/crewai/memory/encoding_flow.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
from typing import Any
1919
from uuid import uuid4
2020

21-
from pydantic import BaseModel, Field
21+
from pydantic import BaseModel, Field, field_validator
2222

2323
from crewai.flow.flow import Flow, listen, start
2424
from crewai.memory.analyze import (
@@ -68,6 +68,29 @@ class ItemState(BaseModel):
6868
plan: ConsolidationPlan | None = None
6969
result_record: MemoryRecord | None = None
7070

71+
@field_validator("similar_records", "result_record", mode="before")
@classmethod
def ensure_embedding_is_list(cls, v: Any) -> Any:
    """Normalise bytes embeddings on incoming MemoryRecord values.

    Runs before field validation for ``similar_records`` and
    ``result_record``. Any ``MemoryRecord`` found (directly, or as an
    element of a list) whose ``embedding`` is ``bytes`` has that attribute
    replaced in place by the result of ``MemoryRecord.validate_embedding``,
    so the conversion rules live in one place.

    Args:
        v: Raw field input: ``None``, a single record, a list of records,
            or any other value (passed through untouched).

    Returns:
        ``None`` when ``v`` is ``None``; otherwise the same ``v`` object,
        with qualifying records updated in place.
    """
    if v is None:
        return None

    # Treat a lone value as a one-element batch so both field shapes share
    # the same conversion loop.
    candidates = v if isinstance(v, list) else [v]
    for candidate in candidates:
        if not isinstance(candidate, MemoryRecord):
            continue
        raw_embedding = candidate.embedding
        if isinstance(raw_embedding, bytes):
            candidate.embedding = MemoryRecord.validate_embedding(raw_embedding)
    return v
93+
7194

7295
class EncodingState(BaseModel):
7396
"""Batch-level state for the encoding flow."""

0 commit comments

Comments
 (0)