diff --git a/sql_redis/__init__.py b/sql_redis/__init__.py index 1eacaf4..7d1bda9 100644 --- a/sql_redis/__init__.py +++ b/sql_redis/__init__.py @@ -1,6 +1,17 @@ """SQL to Redis command translation utility.""" +from sql_redis.executor import AsyncExecutor, Executor, QueryResult +from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry from sql_redis.translator import TranslatedQuery, Translator from sql_redis.version import __version__ -__all__ = ["Translator", "TranslatedQuery", "__version__"] +__all__ = [ + "Translator", + "TranslatedQuery", + "SchemaRegistry", + "AsyncSchemaRegistry", + "Executor", + "AsyncExecutor", + "QueryResult", + "__version__", +] diff --git a/sql_redis/executor.py b/sql_redis/executor.py index a9a0952..5d7655a 100644 --- a/sql_redis/executor.py +++ b/sql_redis/executor.py @@ -4,13 +4,96 @@ import re from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any import redis -from sql_redis.schema import SchemaRegistry +from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry from sql_redis.translator import Translator +if TYPE_CHECKING: + import redis.asyncio as async_redis + + +def _substitute_params(sql: str, params: dict[str, Any]) -> str: + """Substitute parameter placeholders in SQL with actual values. + + This is a pure function with no I/O operations, shared by both + sync and async executors. + + Uses token-based approach: splits SQL on :param patterns, then rebuilds + with substituted values. This approach solves two critical bugs: + + 1. PARTIAL MATCHING BUG: Prevents :id from matching inside :product_id + by treating each :identifier as a complete token + + 2. QUOTE ESCAPING BUG: Properly escapes single quotes in string values + using SQL standard (single quote -> double single quote) + + Args: + sql: The SQL string with :param placeholders. + params: Dictionary mapping parameter names to values. + + Returns: + SQL string with parameters substituted. + + Implementation Details: + - Uses regex to split on parameter patterns: :[a-zA-Z_][a-zA-Z0-9_]* + - Keeps delimiters (the :param tokens) in the split result + - Iterates through tokens, substituting matched parameters + - String values are wrapped in single quotes with proper escaping + - Numeric values are converted to strings + - Bytes values (e.g., vectors) are NOT substituted here + + Known Limitations: + - Colons in string literals: SQL like "WHERE x = 'test:value'" would + theoretically match :value as a parameter. However, this is not a + practical issue because: + 1. Users pass values via parameters, not hardcoded in SQL + 2. The translator has its own handling of string literals + 3. No real-world use cases have been identified + - Parameter names are case-sensitive (:id != :ID) + - Only handles int, float, str types; other types keep placeholder + """ + if not params: + return sql + + # Split SQL on :param patterns, keeping the delimiters + # Pattern matches : followed by valid identifier: + # [a-zA-Z_] - First char must be letter or underscore + # [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore + # This prevents partial matching: :id and :product_id are separate tokens + tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql) + + result = [] + for token in tokens: + if token.startswith(":"): + # This is a parameter placeholder + key = token[1:] # Remove leading : + if key in params: + value = params[key] + if isinstance(value, (int, float)): + # Numeric values: convert to string + result.append(str(value)) + elif isinstance(value, str): + # String values: wrap in quotes and escape single quotes + # SQL standard: ' -> '' (double single quote) + # This fixes the quote escaping bug + escaped = value.replace("'", "''") + result.append(f"'{escaped}'") + else: + # Other types (bytes, None, bool, list, etc.): + # Keep placeholder as-is (handled elsewhere or unsupported) + result.append(token) + else: + # Parameter not provided: keep placeholder as-is + result.append(token) + else: + # Not a parameter: keep as-is + result.append(token) + + return "".join(result) + @dataclass class QueryResult: @@ -23,94 +106,18 @@ class QueryResult: class Executor: """Executes SQL queries against Redis.""" - def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry): + def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry) -> None: """Initialize executor with Redis client and schema registry.""" self._client = client self._schema_registry = schema_registry self._translator = Translator(schema_registry) - def _substitute_params(self, sql: str, params: dict[str, Any]) -> str: - """Substitute parameter placeholders in SQL with actual values. - - Uses token-based approach: splits SQL on :param patterns, then rebuilds - with substituted values. This approach solves two critical bugs: - - 1. PARTIAL MATCHING BUG: Prevents :id from matching inside :product_id - by treating each :identifier as a complete token - - 2. QUOTE ESCAPING BUG: Properly escapes single quotes in string values - using SQL standard (single quote -> double single quote) - - Args: - sql: The SQL string with :param placeholders. - params: Dictionary mapping parameter names to values. - - Returns: - SQL string with parameters substituted. - - Implementation Details: - - Uses regex to split on parameter patterns: :[a-zA-Z_][a-zA-Z0-9_]* - - Keeps delimiters (the :param tokens) in the split result - - Iterates through tokens, substituting matched parameters - - String values are wrapped in single quotes with proper escaping - - Numeric values are converted to strings - - Bytes values (e.g., vectors) are NOT substituted here - - Known Limitations: - - Colons in string literals: SQL like "WHERE x = 'test:value'" would - theoretically match :value as a parameter. However, this is not a - practical issue because: - 1. Users pass values via parameters, not hardcoded in SQL - 2. The translator has its own handling of string literals - 3. No real-world use cases have been identified - - Parameter names are case-sensitive (:id != :ID) - - Only handles int, float, str types; other types keep placeholder - """ - if not params: - return sql - - # Split SQL on :param patterns, keeping the delimiters - # Pattern matches : followed by valid identifier: - # [a-zA-Z_] - First char must be letter or underscore - # [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore - # This prevents partial matching: :id and :product_id are separate tokens - tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql) - - result = [] - for token in tokens: - if token.startswith(":"): - # This is a parameter placeholder - key = token[1:] # Remove leading : - if key in params: - value = params[key] - if isinstance(value, (int, float)): - # Numeric values: convert to string - result.append(str(value)) - elif isinstance(value, str): - # String values: wrap in quotes and escape single quotes - # SQL standard: ' -> '' (double single quote) - # This fixes the quote escaping bug - escaped = value.replace("'", "''") - result.append(f"'{escaped}'") - else: - # Other types (bytes, None, bool, list, etc.): - # Keep placeholder as-is (handled elsewhere or unsupported) - result.append(token) - else: - # Parameter not provided: keep placeholder as-is - result.append(token) - else: - # Not a parameter: keep as-is - result.append(token) - - return "".join(result) - def execute(self, sql: str, *, params: dict | None = None) -> QueryResult: """Execute a SQL query and return results.""" params = params or {} # Substitute non-bytes params in SQL using token-based approach - sql = self._substitute_params(sql, params) + sql = _substitute_params(sql, params) # Translate SQL to Redis command translated = self._translator.translate(sql) @@ -153,3 +160,69 @@ def execute(self, sql: str, *, params: dict | None = None) -> QueryResult: rows.append(row) return QueryResult(rows=rows, count=count) + + +class AsyncExecutor: + """Async version of Executor for use with redis.asyncio clients.""" + + def __init__( + self, + client: "async_redis.Redis", + schema_registry: AsyncSchemaRegistry, + ) -> None: + """Initialize async executor with Redis client and schema registry. + + Args: + client: An async Redis client (redis.asyncio.Redis). + schema_registry: An AsyncSchemaRegistry instance. + """ + self._client = client + self._schema_registry = schema_registry + self._translator = Translator(schema_registry) + + async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult: + """Execute a SQL query asynchronously and return results.""" + params = params or {} + + # Substitute non-bytes params in SQL + sql = _substitute_params(sql, params) + + # Translate SQL to Redis command (sync - no Redis calls) + translated = self._translator.translate(sql) + + # Build command list and substitute vector params + cmd: list[str | bytes] = list(translated.to_command_list()) + + # Find any bytes params (vectors) to substitute + vector_param: bytes | None = None + for value in params.values(): + if isinstance(value, bytes): + vector_param = value + break + + # Replace $vector placeholder with actual bytes + if vector_param: + for i, arg in enumerate(cmd): + if arg == "$vector": + cmd[i] = vector_param + + # Execute command asynchronously + raw_result = await self._client.execute_command(*cmd) + + # Parse result based on command type + count = raw_result[0] if raw_result else 0 + rows = [] + + if translated.command == "FT.SEARCH": + # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...] + for i in range(2, len(raw_result), 2): + row_data = raw_result[i] + row = dict(zip(row_data[::2], row_data[1::2])) + rows.append(row) + else: + # FT.AGGREGATE format: [count, [fields1], [fields2], ...] + for row_data in raw_result[1:]: + row = dict(zip(row_data[::2], row_data[1::2])) + rows.append(row) + + return QueryResult(rows=rows, count=count) diff --git a/sql_redis/schema.py b/sql_redis/schema.py index f58555f..c5e57fc 100644 --- a/sql_redis/schema.py +++ b/sql_redis/schema.py @@ -2,10 +2,51 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING, Callable import redis +if TYPE_CHECKING: + import redis.asyncio as async_redis + + +def _parse_schema_from_info(info: list) -> dict[str, str]: + """Parse field types from FT.INFO response. + + This is a pure function with no I/O operations, shared by both + sync and async schema registries. + + Args: + info: The raw response from FT.INFO command. + + Returns: + Dictionary mapping field names to their types (e.g., {"title": "TEXT"}). + """ + schema = {} + # Find the 'attributes' section in the info response + for i, item in enumerate(info): + # Handle bytes or string comparison + item_str = item.decode("utf-8") if isinstance(item, bytes) else item + if item_str == "attributes": + attributes = info[i + 1] + for attr in attributes: + field_name = None + field_type = None + # Each attribute is a list like: + # [b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...] + for j, val in enumerate(attr): + val_str = val.decode("utf-8") if isinstance(val, bytes) else val + if val_str == "attribute" and j + 1 < len(attr): + fn = attr[j + 1] + field_name = fn.decode("utf-8") if isinstance(fn, bytes) else fn + if val_str == "type" and j + 1 < len(attr): + ft = attr[j + 1] + field_type = ft.decode("utf-8") if isinstance(ft, bytes) else ft + if field_name and field_type: + schema[field_name] = field_type + break + return schema + class SchemaRegistry: """Loads and caches index schemas from Redis. @@ -33,43 +74,12 @@ def _load_index_schema(self, index_name: str) -> None: """Load schema for a single index.""" try: info = self._client.execute_command("FT.INFO", index_name) - schema = self._parse_schema_from_info(info) + schema = _parse_schema_from_info(info) self._schemas[index_name] = schema except redis.ResponseError: # Index doesn't exist or was deleted self._schemas.pop(index_name, None) - def _parse_schema_from_info(self, info: list) -> dict[str, str]: - """Parse field types from FT.INFO response.""" - schema = {} - # Find the 'attributes' section in the info response - for i, item in enumerate(info): - # Handle bytes or string comparison - item_str = item.decode("utf-8") if isinstance(item, bytes) else item - if item_str == "attributes": - attributes = info[i + 1] - for attr in attributes: - field_name = None - field_type = None - # Each attribute is a list like: - # [b'identifier', b'title', b'attribute', b'title', b'type', b'TEXT', ...] - for j, val in enumerate(attr): - val_str = val.decode("utf-8") if isinstance(val, bytes) else val - if val_str == "attribute" and j + 1 < len(attr): - fn = attr[j + 1] - field_name = ( - fn.decode("utf-8") if isinstance(fn, bytes) else fn - ) - if val_str == "type" and j + 1 < len(attr): - ft = attr[j + 1] - field_type = ( - ft.decode("utf-8") if isinstance(ft, bytes) else ft - ) - if field_name and field_type: - schema[field_name] = field_type - break - return schema - def get_field_type(self, index: str, field: str) -> str | None: """Get field type for a given index and field. @@ -140,3 +150,66 @@ def process_pending_events(self) -> None: self._schemas.pop(idx, None) if self._on_change: self._on_change("dropped", idx) + + +class AsyncSchemaRegistry: + """Async version of SchemaRegistry for use with redis.asyncio clients. + + Loads and caches index schemas from Redis asynchronously. + """ + + def __init__(self, redis_client: "async_redis.Redis") -> None: + """Initialize with an async Redis client. + + Args: + redis_client: An async Redis client (redis.asyncio.Redis). + """ + self._client = redis_client + self._schemas: dict[str, dict[str, str]] = {} + + async def load_all(self) -> None: + """Load schemas for all indexes on the server. + + Uses asyncio.gather() to load all index schemas concurrently. + """ + import asyncio + + self._schemas.clear() + indexes = await self._client.execute_command("FT._LIST") + # Decode bytes to strings + decoded_indexes = [ + idx.decode("utf-8") if isinstance(idx, bytes) else idx for idx in indexes + ] + # Load all schemas concurrently + await asyncio.gather( + *[self._load_index_schema(name) for name in decoded_indexes] + ) + + async def _load_index_schema(self, index_name: str) -> None: + """Load schema for a single index.""" + try: + info = await self._client.execute_command("FT.INFO", index_name) + schema = _parse_schema_from_info(info) + self._schemas[index_name] = schema + except redis.ResponseError: + # Index doesn't exist or was deleted + self._schemas.pop(index_name, None) + + def get_field_type(self, index: str, field: str) -> str | None: + """Get field type for a given index and field. + + Returns None if index or field is unknown. + """ + schema = self._schemas.get(index, {}) + return schema.get(field) + + def get_schema(self, index: str) -> dict[str, str]: + """Get full schema for an index. + + Returns empty dict if index is unknown. + """ + return self._schemas.get(index, {}) + + async def refresh(self, index_name: str) -> None: + """Refresh schema for a single index.""" + await self._load_index_schema(index_name) diff --git a/sql_redis/translator.py b/sql_redis/translator.py index 0fec87c..e174413 100644 --- a/sql_redis/translator.py +++ b/sql_redis/translator.py @@ -7,7 +7,7 @@ from sql_redis.analyzer import AnalyzedQuery, Analyzer from sql_redis.parser import Condition, ParsedQuery, SQLParser from sql_redis.query_builder import QueryBuilder -from sql_redis.schema import SchemaRegistry +from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry @dataclass @@ -34,11 +34,13 @@ def to_command_string(self) -> str: class Translator: """Translates SQL queries to Redis FT.SEARCH/FT.AGGREGATE commands.""" - def __init__(self, schema_registry: SchemaRegistry): + def __init__(self, schema_registry: SchemaRegistry | AsyncSchemaRegistry) -> None: """Initialize translator with schema registry. Args: - schema_registry: Registry containing index schemas. + schema_registry: Registry containing index schemas. Can be either + sync (SchemaRegistry) or async (AsyncSchemaRegistry) - only + the sync get_schema() method is used. """ self._schema_registry = schema_registry self._parser = SQLParser() diff --git a/sql_redis/version.py b/sql_redis/version.py index 8a79f95..989a2be 100644 --- a/sql_redis/version.py +++ b/sql_redis/version.py @@ -2,7 +2,7 @@ from importlib.metadata import PackageNotFoundError, version except ImportError: # Python < 3.8 fallback - from importlib_metadata import PackageNotFoundError, version # type: ignore + from importlib_metadata import PackageNotFoundError, version # type: ignore # isort: skip try: __version__ = version("sql-redis") diff --git a/tests/test_async_executor.py b/tests/test_async_executor.py new file mode 100644 index 0000000..f126be7 --- /dev/null +++ b/tests/test_async_executor.py @@ -0,0 +1,344 @@ +"""Integration tests for async SQL executor. + +TDD: These tests define the expected behavior for AsyncSchemaRegistry and AsyncExecutor. +""" + +import struct + +import pytest +import redis.asyncio as async_redis +from testcontainers.redis import RedisContainer + +from sql_redis.executor import AsyncExecutor, QueryResult +from sql_redis.schema import AsyncSchemaRegistry + + +@pytest.fixture(scope="module") +def redis_container(): + """Start a Redis container for testing.""" + with RedisContainer("redis:8.0.2") as container: + yield container + + +@pytest.fixture +async def async_client(redis_container) -> async_redis.Redis: + """Create an async Redis client connected to the test container.""" + client = async_redis.Redis( + host=redis_container.get_container_host_ip(), + port=int(redis_container.get_exposed_port(6379)), + decode_responses=True, + ) + yield client + await client.aclose() + + +@pytest.fixture +async def products_index(async_client: async_redis.Redis) -> str: + """Create a products index with test data.""" + index_name = "async_products" + try: + await async_client.execute_command("FT.DROPINDEX", index_name, "DD") + except Exception: + pass + + await async_client.execute_command( + "FT.CREATE", + index_name, + "ON", + "HASH", + "PREFIX", + "1", + "async_product:", + "SCHEMA", + "title", + "TEXT", + "category", + "TAG", + "price", + "NUMERIC", + "stock", + "NUMERIC", + ) + + # Add test data + await async_client.hset( + "async_product:1", + mapping={ + "title": "Laptop Pro", + "category": "electronics", + "price": "999.99", + "stock": "10", + }, + ) + await async_client.hset( + "async_product:2", + mapping={ + "title": "Wireless Mouse", + "category": "electronics", + "price": "29.99", + "stock": "50", + }, + ) + await async_client.hset( + "async_product:3", + mapping={ + "title": "Python Book", + "category": "books", + "price": "49.99", + "stock": "25", + }, + ) + await async_client.hset( + "async_product:4", + mapping={ + "title": "Redis Guide", + "category": "books", + "price": "39.99", + "stock": "15", + }, + ) + + yield index_name + + # Cleanup + try: + await async_client.execute_command("FT.DROPINDEX", index_name, "DD") + except Exception: + pass + + +@pytest.fixture +async def async_executor( + async_client: async_redis.Redis, products_index: str +) -> AsyncExecutor: + """Create an async executor with the products index loaded.""" + registry = AsyncSchemaRegistry(async_client) + await registry.load_all() + return AsyncExecutor(async_client, registry) + + +class TestAsyncSchemaRegistry: + """Tests for AsyncSchemaRegistry.""" + + async def test_load_all_loads_indexes( + self, async_client: async_redis.Redis, products_index: str + ): + """load_all() should load index schemas from Redis.""" + registry = AsyncSchemaRegistry(async_client) + await registry.load_all() + + schema = registry.get_schema(products_index) + assert schema is not None + assert "title" in schema + assert schema["title"] == "TEXT" + assert "category" in schema + assert schema["category"] == "TAG" + assert "price" in schema + assert schema["price"] == "NUMERIC" + + async def test_get_schema_returns_empty_for_unknown( + self, async_client: async_redis.Redis + ): + """get_schema() returns empty dict for unknown index.""" + registry = AsyncSchemaRegistry(async_client) + await registry.load_all() + + schema = registry.get_schema("nonexistent_index") + assert schema == {} + + +class TestAsyncExecutorBasic: + """Tests for basic async query execution.""" + + async def test_select_all(self, async_executor: AsyncExecutor, products_index: str): + """SELECT * returns all documents.""" + result = await async_executor.execute(f"SELECT * FROM {products_index}") + assert result.count == 4 + assert len(result.rows) == 4 + + async def test_result_is_query_result( + self, async_executor: AsyncExecutor, products_index: str + ): + """Result should be a QueryResult instance.""" + result = await async_executor.execute(f"SELECT * FROM {products_index}") + assert isinstance(result, QueryResult) + assert hasattr(result, "rows") + assert hasattr(result, "count") + + async def test_select_with_tag_filter( + self, async_executor: AsyncExecutor, products_index: str + ): + """SELECT with tag filter.""" + result = await async_executor.execute( + f"SELECT * FROM {products_index} WHERE category = 'books'" + ) + assert result.count == 2 + for row in result.rows: + assert row["category"] == "books" + + async def test_select_with_numeric_filter( + self, async_executor: AsyncExecutor, products_index: str + ): + """SELECT with numeric comparison.""" + result = await async_executor.execute( + f"SELECT * FROM {products_index} WHERE price < 50" + ) + assert result.count >= 2 + for row in result.rows: + assert float(row["price"]) < 50 + + async def test_select_with_limit( + self, async_executor: AsyncExecutor, products_index: str + ): + """SELECT with LIMIT.""" + result = await async_executor.execute(f"SELECT * FROM {products_index} LIMIT 2") + assert len(result.rows) == 2 + + async def test_select_with_order_by( + self, async_executor: AsyncExecutor, products_index: str + ): + """SELECT with ORDER BY.""" + result = await async_executor.execute( + f"SELECT * FROM {products_index} ORDER BY price DESC" + ) + prices = [float(row["price"]) for row in result.rows] + assert prices == sorted(prices, reverse=True) + + +class TestAsyncExecutorAggregation: + """Tests for async aggregate query execution.""" + + async def test_count_all(self, async_executor: AsyncExecutor, products_index: str): + """SELECT COUNT(*) returns count.""" + result = await async_executor.execute(f"SELECT COUNT(*) FROM {products_index}") + assert len(result.rows) == 1 + row = result.rows[0] + count_value = row.get("COUNT(*)", row.get("count", None)) + assert count_value is not None + + async def test_group_by_with_count( + self, async_executor: AsyncExecutor, products_index: str + ): + """SELECT with GROUP BY and COUNT.""" + result = await async_executor.execute( + f"SELECT category, COUNT(*) as cnt FROM {products_index} GROUP BY category" + ) + assert len(result.rows) == 2 # electronics and books + categories = {row["category"] for row in result.rows} + assert categories == {"electronics", "books"} + + +class TestAsyncExecutorParams: + """Tests for parameterized async execution.""" + + async def test_numeric_param( + self, async_executor: AsyncExecutor, products_index: str + ): + """Execute with numeric parameter.""" + result = await async_executor.execute( + f"SELECT * FROM {products_index} WHERE price > :min_price", + params={"min_price": 40}, + ) + for row in result.rows: + assert float(row["price"]) > 40 + + async def test_string_param( + self, async_executor: AsyncExecutor, products_index: str + ): + """Execute with string parameter.""" + result = await async_executor.execute( + f"SELECT * FROM {products_index} WHERE category = :cat", + params={"cat": "books"}, + ) + assert len(result.rows) == 2 + for row in result.rows: + assert row["category"] == "books" + + +class TestAsyncVectorSearch: + """Tests for async vector search execution.""" + + @pytest.fixture + async def vector_index(self, async_client: async_redis.Redis) -> str: + """Create a vector index with test data.""" + index_name = "async_vectors" + try: + await async_client.execute_command("FT.DROPINDEX", index_name, "DD") + except Exception: + pass + + await async_client.execute_command( + "FT.CREATE", + index_name, + "ON", + "HASH", + "PREFIX", + "1", + "async_vec:", + "SCHEMA", + "title", + "TEXT", + "embedding", + "VECTOR", + "HNSW", + "6", + "TYPE", + "FLOAT32", + "DIM", + "4", + "DISTANCE_METRIC", + "COSINE", + ) + + def to_bytes(v): + return struct.pack(f"{len(v)}f", *v) + + # Use a separate non-decode client for binary data + raw_client = async_redis.Redis( + host=async_client.connection_pool.connection_kwargs["host"], + port=async_client.connection_pool.connection_kwargs["port"], + decode_responses=False, + ) + await raw_client.hset( + "async_vec:1", + mapping={"title": "First", "embedding": to_bytes([0.1, 0.2, 0.3, 0.4])}, + ) + await raw_client.hset( + "async_vec:2", + mapping={"title": "Second", "embedding": to_bytes([0.5, 0.6, 0.7, 0.8])}, + ) + await raw_client.hset( + "async_vec:3", + mapping={"title": "Third", "embedding": to_bytes([0.9, 0.8, 0.7, 0.6])}, + ) + await raw_client.aclose() + + yield index_name + + # Cleanup + try: + await async_client.execute_command("FT.DROPINDEX", index_name, "DD") + except Exception: + pass + + async def test_vector_search_with_param( + self, async_client: async_redis.Redis, vector_index: str + ): + """Vector search with vector parameter.""" + registry = AsyncSchemaRegistry(async_client) + await registry.load_all() + executor = AsyncExecutor(async_client, registry) + + query_vector = struct.pack("4f", 0.1, 0.2, 0.3, 0.4) + result = await executor.execute( + f"SELECT title, vector_distance(embedding, :vec) AS score " + f"FROM {vector_index} LIMIT 3", + params={"vec": query_vector}, + ) + assert len(result.rows) <= 3 + # First result should be closest to query vector + assert result.rows[0]["title"] == "First" + # Verify vector distance score is returned + assert "score" in result.rows[0] + score = float(result.rows[0]["score"]) + assert score >= 0 # Distance should be non-negative diff --git a/tests/test_schema_registry.py b/tests/test_schema_registry.py index 498785e..141e3c0 100644 --- a/tests/test_schema_registry.py +++ b/tests/test_schema_registry.py @@ -3,7 +3,7 @@ import pytest import redis -from sql_redis.schema import SchemaRegistry +from sql_redis.schema import SchemaRegistry, _parse_schema_from_info def _create_test_indexes(redis_client: redis.Redis) -> list[str]: @@ -222,20 +222,16 @@ def test_load_all_handles_no_indexes(self, redis_client: redis.Redis): class TestSchemaRegistryParsing: """Tests for schema parsing edge cases.""" - def test_parse_schema_no_attributes_section(self, redis_client: redis.Redis): + def test_parse_schema_no_attributes_section(self): """_parse_schema_from_info handles response without attributes.""" - registry = SchemaRegistry(redis_client) - # FT.INFO response without 'attributes' key fake_info = ["index_name", "test", "other_key", "value"] - schema = registry._parse_schema_from_info(fake_info) + schema = _parse_schema_from_info(fake_info) assert schema == {} - def test_parse_schema_incomplete_attribute(self, redis_client: redis.Redis): + def test_parse_schema_incomplete_attribute(self): """_parse_schema_from_info handles attribute without type.""" - registry = SchemaRegistry(redis_client) - # FT.INFO response with attribute but missing type fake_info = [ "attributes", @@ -244,7 +240,7 @@ def test_parse_schema_incomplete_attribute(self, redis_client: redis.Redis): ["identifier", "field2", "attribute", "field2", "type", "TEXT"], ], ] - schema = registry._parse_schema_from_info(fake_info) + schema = _parse_schema_from_info(fake_info) # Only field2 should be captured (field1 has no type) assert schema == {"field2": "TEXT"}