Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion sql_redis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
"""SQL to Redis command translation utility."""

from sql_redis.executor import AsyncExecutor, Executor, QueryResult
from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
from sql_redis.translator import TranslatedQuery, Translator
from sql_redis.version import __version__

__all__ = ["Translator", "TranslatedQuery", "__version__"]
__all__ = [
"Translator",
"TranslatedQuery",
"SchemaRegistry",
"AsyncSchemaRegistry",
"Executor",
"AsyncExecutor",
"QueryResult",
"__version__",
]
233 changes: 153 additions & 80 deletions sql_redis/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,96 @@

import re
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING, Any

import redis

from sql_redis.schema import SchemaRegistry
from sql_redis.schema import AsyncSchemaRegistry, SchemaRegistry
from sql_redis.translator import Translator

if TYPE_CHECKING:
import redis.asyncio as async_redis


def _substitute_params(sql: str, params: dict[str, Any]) -> str:
"""Substitute parameter placeholders in SQL with actual values.

This is a pure function with no I/O operations, shared by both
sync and async executors.

Uses token-based approach: splits SQL on :param patterns, then rebuilds
with substituted values. This approach solves two critical bugs:

1. PARTIAL MATCHING BUG: Prevents :id from matching inside :product_id
by treating each :identifier as a complete token

2. QUOTE ESCAPING BUG: Properly escapes single quotes in string values
using SQL standard (single quote -> double single quote)

Args:
sql: The SQL string with :param placeholders.
params: Dictionary mapping parameter names to values.

Returns:
SQL string with parameters substituted.

Implementation Details:
- Uses regex to split on parameter patterns: :[a-zA-Z_][a-zA-Z0-9_]*
- Keeps delimiters (the :param tokens) in the split result
- Iterates through tokens, substituting matched parameters
- String values are wrapped in single quotes with proper escaping
- Numeric values are converted to strings
- Bytes values (e.g., vectors) are NOT substituted here

Known Limitations:
- Colons in string literals: SQL like "WHERE x = 'test:value'" would
theoretically match :value as a parameter. However, this is not a
practical issue because:
1. Users pass values via parameters, not hardcoded in SQL
2. The translator has its own handling of string literals
3. No real-world use cases have been identified
- Parameter names are case-sensitive (:id != :ID)
- Only handles int, float, str types; other types keep placeholder
"""
if not params:
return sql

# Split SQL on :param patterns, keeping the delimiters
# Pattern matches : followed by valid identifier:
# [a-zA-Z_] - First char must be letter or underscore
# [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore
# This prevents partial matching: :id and :product_id are separate tokens
tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql)

result = []
for token in tokens:
if token.startswith(":"):
# This is a parameter placeholder
key = token[1:] # Remove leading :
if key in params:
value = params[key]
if isinstance(value, (int, float)):
# Numeric values: convert to string
result.append(str(value))
elif isinstance(value, str):
# String values: wrap in quotes and escape single quotes
# SQL standard: ' -> '' (double single quote)
# This fixes the quote escaping bug
escaped = value.replace("'", "''")
result.append(f"'{escaped}'")
else:
# Other types (bytes, None, bool, list, etc.):
# Keep placeholder as-is (handled elsewhere or unsupported)
Comment on lines +74 to +86

Copilot AI Mar 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_substitute_params treats bool as numeric because bool is a subclass of int, so params like {"flag": True} will be substituted as "True"/"False" even though the docstring says bool should be left as a placeholder/unsupported. Consider checking for bool explicitly before the (int, float) branch (or rejecting unsupported types) so behavior matches the documented limitations and avoids generating invalid SQL.

Copilot uses AI. Check for mistakes.
result.append(token)
else:
# Parameter not provided: keep placeholder as-is
result.append(token)
else:
# Not a parameter: keep as-is
result.append(token)

return "".join(result)


Comment on lines +61 to 97

Copilot AI Mar 2, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_substitute_params uses a regex split that will also match ":param" tokens inside SQL string literals/comments (e.g., WHERE x = 'test:val'), which can lead to unintended substitution and changed query semantics. If this is meant to be safe for arbitrary SQL, consider parsing tokens with sqlglot (or at least skipping over quoted strings) instead of regex-based splitting.

Suggested change
# Split SQL on :param patterns, keeping the delimiters
# Pattern matches : followed by valid identifier:
# [a-zA-Z_] - First char must be letter or underscore
# [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore
# This prevents partial matching: :id and :product_id are separate tokens
tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql)
result = []
for token in tokens:
if token.startswith(":"):
# This is a parameter placeholder
key = token[1:] # Remove leading :
if key in params:
value = params[key]
if isinstance(value, (int, float)):
# Numeric values: convert to string
result.append(str(value))
elif isinstance(value, str):
# String values: wrap in quotes and escape single quotes
# SQL standard: ' -> '' (double single quote)
# This fixes the quote escaping bug
escaped = value.replace("'", "''")
result.append(f"'{escaped}'")
else:
# Other types (bytes, None, bool, list, etc.):
# Keep placeholder as-is (handled elsewhere or unsupported)
result.append(token)
else:
# Parameter not provided: keep placeholder as-is
result.append(token)
else:
# Not a parameter: keep as-is
result.append(token)
return "".join(result)
# Compile parameter pattern once for reuse
param_pattern = re.compile(r"(:[a-zA-Z_][a-zA-Z0-9_]*)")
def substitute_in_segment(segment: str) -> str:
"""Apply parameter substitution to a segment that is known to contain
no string literals or comments.
"""
if not segment:
return segment
# Split segment on :param patterns, keeping the delimiters
# Pattern matches : followed by valid identifier:
# [a-zA-Z_] - First char must be letter or underscore
# [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore
# This prevents partial matching: :id and :product_id are separate tokens
tokens = param_pattern.split(segment)
result_parts: list[str] = []
for token in tokens:
if token.startswith(":"):
# This is a parameter placeholder
key = token[1:] # Remove leading :
if key in params:
value = params[key]
if isinstance(value, (int, float)):
# Numeric values: convert to string
result_parts.append(str(value))
elif isinstance(value, str):
# String values: wrap in quotes and escape single quotes
# SQL standard: ' -> '' (double single quote)
# This fixes the quote escaping bug
escaped = value.replace("'", "''")
result_parts.append(f"'{escaped}'")
else:
# Other types (bytes, None, bool, list, etc.):
# Keep placeholder as-is (handled elsewhere or unsupported)
result_parts.append(token)
else:
# Parameter not provided: keep placeholder as-is
result_parts.append(token)
else:
# Not a parameter: keep as-is
result_parts.append(token)
return "".join(result_parts)
# Regex to match regions where we MUST NOT substitute parameters:
# - Single-quoted string literals (with '' escapes)
# - Double-quoted identifiers (with "" escapes)
# - Single-line comments starting with --
# - Block comments delimited by /* ... */
skip_pattern = re.compile(
r"""
(?:'([^']|'')*') # single-quoted string
|(?:"([^"]|"")*") # double-quoted identifier/string
|(?:--[^\n]*\n?) # single-line comment
|(?:/\*.*?\*/) # block comment
""",
re.DOTALL | re.VERBOSE,
)
final_parts: list[str] = []
last_index = 0
# Walk through the SQL, substituting only in non-literal, non-comment segments
for match in skip_pattern.finditer(sql):
start, end = match.start(), match.end()
# Process code before the literal/comment
if start > last_index:
code_segment = sql[last_index:start]
final_parts.append(substitute_in_segment(code_segment))
# Append the literal/comment unchanged
final_parts.append(match.group(0))
last_index = end
# Process any trailing code after the last literal/comment
if last_index < len(sql):
final_parts.append(substitute_in_segment(sql[last_index:]))
return "".join(final_parts)

Copilot uses AI. Check for mistakes.
@dataclass
class QueryResult:
Expand All @@ -23,94 +106,18 @@ class QueryResult:
class Executor:
"""Executes SQL queries against Redis."""

def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry):
def __init__(self, client: redis.Redis, schema_registry: SchemaRegistry) -> None:
"""Initialize executor with Redis client and schema registry."""
self._client = client
self._schema_registry = schema_registry
self._translator = Translator(schema_registry)

def _substitute_params(self, sql: str, params: dict[str, Any]) -> str:
"""Substitute parameter placeholders in SQL with actual values.

Uses token-based approach: splits SQL on :param patterns, then rebuilds
with substituted values. This approach solves two critical bugs:

1. PARTIAL MATCHING BUG: Prevents :id from matching inside :product_id
by treating each :identifier as a complete token

2. QUOTE ESCAPING BUG: Properly escapes single quotes in string values
using SQL standard (single quote -> double single quote)

Args:
sql: The SQL string with :param placeholders.
params: Dictionary mapping parameter names to values.

Returns:
SQL string with parameters substituted.

Implementation Details:
- Uses regex to split on parameter patterns: :[a-zA-Z_][a-zA-Z0-9_]*
- Keeps delimiters (the :param tokens) in the split result
- Iterates through tokens, substituting matched parameters
- String values are wrapped in single quotes with proper escaping
- Numeric values are converted to strings
- Bytes values (e.g., vectors) are NOT substituted here

Known Limitations:
- Colons in string literals: SQL like "WHERE x = 'test:value'" would
theoretically match :value as a parameter. However, this is not a
practical issue because:
1. Users pass values via parameters, not hardcoded in SQL
2. The translator has its own handling of string literals
3. No real-world use cases have been identified
- Parameter names are case-sensitive (:id != :ID)
- Only handles int, float, str types; other types keep placeholder
"""
if not params:
return sql

# Split SQL on :param patterns, keeping the delimiters
# Pattern matches : followed by valid identifier:
# [a-zA-Z_] - First char must be letter or underscore
# [a-zA-Z0-9_]* - Subsequent chars can be alphanumeric or underscore
# This prevents partial matching: :id and :product_id are separate tokens
tokens = re.split(r"(:[a-zA-Z_][a-zA-Z0-9_]*)", sql)

result = []
for token in tokens:
if token.startswith(":"):
# This is a parameter placeholder
key = token[1:] # Remove leading :
if key in params:
value = params[key]
if isinstance(value, (int, float)):
# Numeric values: convert to string
result.append(str(value))
elif isinstance(value, str):
# String values: wrap in quotes and escape single quotes
# SQL standard: ' -> '' (double single quote)
# This fixes the quote escaping bug
escaped = value.replace("'", "''")
result.append(f"'{escaped}'")
else:
# Other types (bytes, None, bool, list, etc.):
# Keep placeholder as-is (handled elsewhere or unsupported)
result.append(token)
else:
# Parameter not provided: keep placeholder as-is
result.append(token)
else:
# Not a parameter: keep as-is
result.append(token)

return "".join(result)

def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
"""Execute a SQL query and return results."""
params = params or {}

# Substitute non-bytes params in SQL using token-based approach
sql = self._substitute_params(sql, params)
sql = _substitute_params(sql, params)

# Translate SQL to Redis command
translated = self._translator.translate(sql)
Expand Down Expand Up @@ -153,3 +160,69 @@ def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
rows.append(row)

return QueryResult(rows=rows, count=count)


class AsyncExecutor:
    """Async version of Executor for use with redis.asyncio clients."""

    def __init__(
        self,
        client: "async_redis.Redis",
        schema_registry: AsyncSchemaRegistry,
    ) -> None:
        """Initialize async executor with Redis client and schema registry.

        Args:
            client: An async Redis client (redis.asyncio.Redis).
            schema_registry: An AsyncSchemaRegistry instance.
        """
        self._client = client
        self._schema_registry = schema_registry
        self._translator = Translator(schema_registry)

    async def execute(self, sql: str, *, params: dict | None = None) -> QueryResult:
        """Execute a SQL query asynchronously and return results.

        Args:
            sql: SQL text, optionally containing ``:name`` placeholders.
            params: Placeholder values. The first ``bytes`` value found is
                used as the vector argument and replaces every ``"$vector"``
                entry of the translated command; the remaining values are
                handled by ``_substitute_params``.

        Returns:
            A QueryResult whose ``count`` is the first element of the Redis
            reply and whose ``rows`` are dicts built from the flat
            ``[name, value, name, value, ...]`` field lists of the reply.
            An empty reply yields ``QueryResult(rows=[], count=0)``.
        """
        params = params or {}

        # Substitute non-bytes params in SQL
        sql = _substitute_params(sql, params)

        # Translate SQL to Redis command (sync - no Redis calls)
        translated = self._translator.translate(sql)
        cmd: list[str | bytes] = list(translated.to_command_list())

        # Only the first bytes param is used as the vector.
        vector_param = next(
            (value for value in params.values() if isinstance(value, bytes)),
            None,
        )
        # Compare against None rather than truthiness: an explicitly
        # supplied empty bytes value is still a supplied value and must not
        # leave the literal "$vector" placeholder in the command.
        if vector_param is not None:
            cmd = [vector_param if arg == "$vector" else arg for arg in cmd]

        # Execute command asynchronously
        raw_result = await self._client.execute_command(*cmd)
        if not raw_result:
            return QueryResult(rows=[], count=0)

        if translated.command == "FT.SEARCH":
            # FT.SEARCH format: [count, key1, [fields1], key2, [fields2], ...]
            field_lists = raw_result[2::2]
        else:
            # FT.AGGREGATE format: [count, [fields1], [fields2], ...]
            field_lists = raw_result[1:]

        # Each field list is flat: pair even-index names with odd-index values.
        rows = [dict(zip(fields[::2], fields[1::2])) for fields in field_lists]
        return QueryResult(rows=rows, count=raw_result[0])
Loading