Skip to content

Commit 4722770

Browse files
Add Athena filter() function support to SQLAlchemy dialect
Implements support for Amazon Athena's filter() function with lambda expressions in PyAthena's SQLAlchemy dialect, addressing issue #480.

Features:
- Basic filter() function compilation with lambda expressions
- Support for complex lambda conditions and nested field access
- Comprehensive error handling for invalid argument counts
- Type-safe implementation using isinstance checks
- Full test coverage with 7 test cases

Examples:
- filter(array_col, 'x -> x > 0')
- filter(data_col, 'x -> x["field"] > value')
- count(filter(action_col, 'x -> x["timestamp"] BETWEEN dates'))

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 29cb19e commit 4722770

2 files changed

Lines changed: 137 additions & 1 deletion

File tree

pyathena/sqlalchemy/compiler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
IdentifierPreparer,
1111
SQLCompiler,
1212
)
13+
from sqlalchemy.sql.elements import BindParameter
1314
from sqlalchemy.sql.schema import Column
1415

1516
from pyathena.model import (
@@ -178,6 +179,33 @@ class AthenaStatementCompiler(SQLCompiler):
178179
def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):
    """Render a ``char_length`` function call under the name ``length``.

    The parenthesised argument text is produced by ``self.function_argspec``
    and appended unchanged after the ``length`` name.
    """
    argspec = self.function_argspec(fn, **kw)
    return f"length{argspec}"
180181

182+
def visit_filter_func(self, fn: "FunctionElement[Any]", **kw) -> str:
    """Compile Athena filter() function with lambda expressions.

    Supports syntax: filter(array_expr, lambda_expr)
    Example: filter(ARRAY[1, 2, 3], x -> x > 1)

    When the second argument is a bound parameter holding a ``str``, that
    string is written into the statement verbatim as the lambda instead of
    being rendered as a bind placeholder. Any other element is compiled
    through ``self.process`` like a regular SQL expression.

    NOTE(review): the lambda text is inlined without quoting or escaping, so
    callers must never build it from untrusted input.

    Args:
        fn: The ``filter`` function element; must carry exactly two clauses.
        **kw: Compiler keyword arguments, forwarded to ``self.process``.

    Returns:
        The rendered ``filter(<array>, <lambda>)`` SQL fragment.

    Raises:
        exc.CompileError: If the argument count is not 2, or if the lambda is
            given as an empty or whitespace-only string.
    """
    args = fn.clauses.clauses
    if len(args) != 2:
        raise exc.CompileError(
            f"filter() function expects exactly 2 arguments, got {len(args)}"
        )

    array_expr, lambda_expr = args

    # Process the array expression normally
    array_sql = self.process(array_expr, **kw)

    if isinstance(lambda_expr, BindParameter) and isinstance(lambda_expr.value, str):
        # String literals such as 'x -> x > 0' are the lambda itself, so they
        # are emitted as-is rather than as a quoted/bound string value.
        lambda_sql = lambda_expr.value
        if not lambda_sql.strip():
            # A blank lambda would render "filter(<array>, )", which is not
            # valid SQL; fail at compile time with a clear message instead.
            raise exc.CompileError(
                "filter() function expects a non-empty lambda expression"
            )
    else:
        # Process as regular SQL expression
        lambda_sql = self.process(lambda_expr, **kw)

    return f"filter({array_sql}, {lambda_sql})"
208+
181209
def visit_cast(self, cast: "Cast[Any]", **kwargs):
182210
if (isinstance(cast.type, types.VARCHAR) and cast.type.length is None) or isinstance(
183211
cast.type, types.String

tests/pyathena/sqlalchemy/test_compiler.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33
from unittest.mock import Mock
44

5-
from sqlalchemy import Integer, String
5+
import pytest
6+
from sqlalchemy import Column, Integer, MetaData, String, Table, exc, func, select
7+
from sqlalchemy.sql import literal
68

9+
from pyathena.sqlalchemy.base import AthenaDialect
710
from pyathena.sqlalchemy.compiler import AthenaTypeCompiler
811
from pyathena.sqlalchemy.types import ARRAY, MAP, STRUCT, AthenaArray, AthenaMap, AthenaStruct
12+
from tests import ENV
913

1014

1115
class TestAthenaTypeCompiler:
@@ -109,3 +113,107 @@ def test_visit_array_no_attributes(self):
109113
array_type = type("MockArray", (), {})()
110114
result = compiler.visit_array(array_type)
111115
assert result == "ARRAY<STRING>"
116+
117+
118+
class TestAthenaStatementCompiler:
    """Test cases for Athena statement compiler functionality.

    Every test only compiles a statement against ``AthenaDialect`` and
    inspects the rendered SQL text; no statement is executed.
    """

    def setup_method(self):
        """Set up test fixtures."""
        self.dialect = AthenaDialect()
        # NOTE(review): ENV comes from the ``tests`` package; presumably it
        # needs the test environment to be configured - confirm whether a
        # fixed schema name would do for these compile-only tests.
        self.metadata = MetaData(schema=ENV.schema)
        self.test_table = Table(
            "test_athena_statement_compiler",
            self.metadata,
            Column("id", Integer),
            Column("data", ARRAY(String)),
            Column("numbers", ARRAY(Integer)),
        )

    def _compile_sql(self, stmt) -> str:
        """Compile ``stmt`` with the Athena dialect and return the SQL text."""
        return str(stmt.compile(dialect=self.dialect))

    def test_visit_filter_func_basic(self):
        """Test basic filter() function compilation."""
        stmt = select(func.filter(self.test_table.c.numbers, literal("x -> x > 0")))

        sql_str = self._compile_sql(stmt)

        assert "filter(" in sql_str
        # The lambda must be the verbatim, unquoted final argument.
        assert ", x -> x > 0)" in sql_str

    def test_visit_filter_func_array_literal(self):
        """Test filter() function with array literal."""
        array_literal = func.array(literal(1), literal(2), literal(3), literal(-1))
        stmt = select(func.filter(array_literal, literal("x -> x > 0")))

        sql_str = self._compile_sql(stmt)

        assert "filter(" in sql_str
        assert ", x -> x > 0)" in sql_str

    def test_visit_filter_func_complex_lambda(self):
        """Test filter() function with complex lambda expression."""
        complex_lambda = literal("x -> x IS NOT NULL AND x > 5")
        stmt = select(func.filter(self.test_table.c.numbers, complex_lambda))

        sql_str = self._compile_sql(stmt)

        assert "filter(" in sql_str
        assert ", x -> x IS NOT NULL AND x > 5)" in sql_str

    def test_visit_filter_func_nested_access(self):
        """Test filter() function with nested field access."""
        nested_lambda = literal("x -> x['timestamp'] > '2023-01-01'")
        stmt = select(func.filter(self.test_table.c.data, nested_lambda))

        sql_str = self._compile_sql(stmt)

        assert "filter(" in sql_str
        assert ", x -> x['timestamp'] > '2023-01-01')" in sql_str

    def test_visit_filter_func_wrong_argument_count(self):
        """Test filter() function with wrong number of arguments."""
        expected_error = r"filter\(\) function expects exactly 2 arguments"

        # Build the statements outside the ``raises`` blocks so that only the
        # compile step is allowed to satisfy the expected error.
        too_few = select(func.filter(self.test_table.c.numbers))
        with pytest.raises(exc.CompileError, match=expected_error):
            too_few.compile(dialect=self.dialect)

        too_many = select(
            func.filter(self.test_table.c.numbers, literal("x -> x > 0"), literal("extra_arg"))
        )
        with pytest.raises(exc.CompileError, match=expected_error):
            too_many.compile(dialect=self.dialect)

    def test_visit_filter_func_integration_example(self):
        """Test filter() function with the original issue example."""
        lambda_expr = literal(
            "x -> x['timestamp'] <= '2023-10-10' AND x['timestamp'] >= '2023-10-01' "
            "AND x['action_count'] >= 2"
        )
        stmt = select(func.count(func.filter(self.test_table.c.data, lambda_expr)))

        sql_str = self._compile_sql(stmt)

        assert "count(" in sql_str
        assert "filter(" in sql_str
        assert ", x -> x['timestamp'] <= '2023-10-10'" in sql_str
        # The lambda closes filter(), which in turn sits inside count().
        assert "x['action_count'] >= 2)" in sql_str

    def test_visit_char_length_func_existing(self):
        """Test existing char_length function still works."""
        stmt = select(func.char_length(self.test_table.c.data))

        sql_str = self._compile_sql(stmt)

        assert "length(" in sql_str

0 commit comments

Comments
 (0)