|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -from typing import Callable |
| 5 | +from typing import TYPE_CHECKING, Callable |
6 | 6 |
|
7 | 7 | import redis |
8 | 8 |
|
| 9 | +if TYPE_CHECKING: |
| 10 | + import redis.asyncio as async_redis |
| 11 | + |
| 12 | + |
| 13 | +def _parse_schema_from_info(info: list) -> dict[str, str]: |
| 14 | + """Parse field types from FT.INFO response. |
| 15 | +
|
| 16 | + This is a pure function with no I/O operations, shared by both |
| 17 | + sync and async schema registries. |
| 18 | +
|
| 19 | + Args: |
| 20 | + info: The raw response from FT.INFO command. |
| 21 | +
|
| 22 | + Returns: |
| 23 | + Dictionary mapping field names to their types (e.g., {"title": "TEXT"}). |
| 24 | + """ |
| 25 | + schema = {} |
| 26 | + # Find the 'attributes' section in the info response |
| 27 | + for i, item in enumerate(info): |
| 28 | + # Handle bytes or string comparison |
| 29 | + item_str = item.decode("utf-8") if isinstance(item, bytes) else item |
| 30 | + if item_str == "attributes": |
| 31 | + attributes = info[i + 1] |
| 32 | + for attr in attributes: |
| 33 | + field_name = None |
| 34 | + field_type = None |
| 35 | + # Each attribute is a list like: |
| 36 | + # [b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...] |
| 37 | + for j, val in enumerate(attr): |
| 38 | + val_str = val.decode("utf-8") if isinstance(val, bytes) else val |
| 39 | + if val_str == "attribute" and j + 1 < len(attr): |
| 40 | + fn = attr[j + 1] |
| 41 | + field_name = fn.decode("utf-8") if isinstance(fn, bytes) else fn |
| 42 | + if val_str == "type" and j + 1 < len(attr): |
| 43 | + ft = attr[j + 1] |
| 44 | + field_type = ft.decode("utf-8") if isinstance(ft, bytes) else ft |
| 45 | + if field_name and field_type: |
| 46 | + schema[field_name] = field_type |
| 47 | + break |
| 48 | + return schema |
| 49 | + |
9 | 50 |
|
10 | 51 | class SchemaRegistry: |
11 | 52 | """Loads and caches index schemas from Redis. |
@@ -33,43 +74,12 @@ def _load_index_schema(self, index_name: str) -> None: |
33 | 74 | """Load schema for a single index.""" |
34 | 75 | try: |
35 | 76 | info = self._client.execute_command("FT.INFO", index_name) |
36 | | - schema = self._parse_schema_from_info(info) |
| 77 | + schema = _parse_schema_from_info(info) |
37 | 78 | self._schemas[index_name] = schema |
38 | 79 | except redis.ResponseError: |
39 | 80 | # Index doesn't exist or was deleted |
40 | 81 | self._schemas.pop(index_name, None) |
41 | 82 |
|
42 | | - def _parse_schema_from_info(self, info: list) -> dict[str, str]: |
43 | | - """Parse field types from FT.INFO response.""" |
44 | | - schema = {} |
45 | | - # Find the 'attributes' section in the info response |
46 | | - for i, item in enumerate(info): |
47 | | - # Handle bytes or string comparison |
48 | | - item_str = item.decode("utf-8") if isinstance(item, bytes) else item |
49 | | - if item_str == "attributes": |
50 | | - attributes = info[i + 1] |
51 | | - for attr in attributes: |
52 | | - field_name = None |
53 | | - field_type = None |
54 | | - # Each attribute is a list like: |
55 | | - # [b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...] |
56 | | - for j, val in enumerate(attr): |
57 | | - val_str = val.decode("utf-8") if isinstance(val, bytes) else val |
58 | | - if val_str == "attribute" and j + 1 < len(attr): |
59 | | - fn = attr[j + 1] |
60 | | - field_name = ( |
61 | | - fn.decode("utf-8") if isinstance(fn, bytes) else fn |
62 | | - ) |
63 | | - if val_str == "type" and j + 1 < len(attr): |
64 | | - ft = attr[j + 1] |
65 | | - field_type = ( |
66 | | - ft.decode("utf-8") if isinstance(ft, bytes) else ft |
67 | | - ) |
68 | | - if field_name and field_type: |
69 | | - schema[field_name] = field_type |
70 | | - break |
71 | | - return schema |
72 | | - |
73 | 83 | def get_field_type(self, index: str, field: str) -> str | None: |
74 | 84 | """Get field type for a given index and field. |
75 | 85 |
|
@@ -140,3 +150,66 @@ def process_pending_events(self) -> None: |
140 | 150 | self._schemas.pop(idx, None) |
141 | 151 | if self._on_change: |
142 | 152 | self._on_change("dropped", idx) |
| 153 | + |
| 154 | + |
| 155 | +class AsyncSchemaRegistry: |
| 156 | + """Async version of SchemaRegistry for use with redis.asyncio clients. |
| 157 | +
|
| 158 | + Loads and caches index schemas from Redis asynchronously. |
| 159 | + """ |
| 160 | + |
| 161 | + def __init__(self, redis_client: "async_redis.Redis") -> None: |
| 162 | + """Initialize with an async Redis client. |
| 163 | +
|
| 164 | + Args: |
| 165 | + redis_client: An async Redis client (redis.asyncio.Redis). |
| 166 | + """ |
| 167 | + self._client = redis_client |
| 168 | + self._schemas: dict[str, dict[str, str]] = {} |
| 169 | + |
| 170 | + async def load_all(self) -> None: |
| 171 | + """Load schemas for all indexes on the server. |
| 172 | +
|
| 173 | + Uses asyncio.gather() to load all index schemas concurrently. |
| 174 | + """ |
| 175 | + import asyncio |
| 176 | + |
| 177 | + self._schemas.clear() |
| 178 | + indexes = await self._client.execute_command("FT._LIST") |
| 179 | + # Decode bytes to strings |
| 180 | + decoded_indexes = [ |
| 181 | + idx.decode("utf-8") if isinstance(idx, bytes) else idx for idx in indexes |
| 182 | + ] |
| 183 | + # Load all schemas concurrently |
| 184 | + await asyncio.gather( |
| 185 | + *[self._load_index_schema(name) for name in decoded_indexes] |
| 186 | + ) |
| 187 | + |
| 188 | + async def _load_index_schema(self, index_name: str) -> None: |
| 189 | + """Load schema for a single index.""" |
| 190 | + try: |
| 191 | + info = await self._client.execute_command("FT.INFO", index_name) |
| 192 | + schema = _parse_schema_from_info(info) |
| 193 | + self._schemas[index_name] = schema |
| 194 | + except redis.ResponseError: |
| 195 | + # Index doesn't exist or was deleted |
| 196 | + self._schemas.pop(index_name, None) |
| 197 | + |
| 198 | + def get_field_type(self, index: str, field: str) -> str | None: |
| 199 | + """Get field type for a given index and field. |
| 200 | +
|
| 201 | + Returns None if index or field is unknown. |
| 202 | + """ |
| 203 | + schema = self._schemas.get(index, {}) |
| 204 | + return schema.get(field) |
| 205 | + |
| 206 | + def get_schema(self, index: str) -> dict[str, str]: |
| 207 | + """Get full schema for an index. |
| 208 | +
|
| 209 | + Returns empty dict if index is unknown. |
| 210 | + """ |
| 211 | + return self._schemas.get(index, {}) |
| 212 | + |
| 213 | + async def refresh(self, index_name: str) -> None: |
| 214 | + """Refresh schema for a single index.""" |
| 215 | + await self._load_index_schema(index_name) |
0 commit comments