Skip to content

Commit 9bf0b0b

Browse files
Add integration test for filter() function
Add test_filter_func to verify actual Athena query execution with filter() function. Tests three scenarios: - Basic filtering: filter(array, 'x -> x > 1') returns [2] from [1, 2] - All values match: filter(array, 'x -> x > 0') returns [1, 2] - No matches: filter(array, 'x -> x > 10') returns [] 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 4722770 commit 9bf0b0b

1 file changed

Lines changed: 38 additions & 0 deletions

File tree

tests/pyathena/sqlalchemy/test_base.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,44 @@ def test_char_length(self, engine):
234234
).scalar()
235235
assert result == len("a string")
236236

237+
def test_filter_func(self, engine):
238+
engine, conn = engine
239+
one_row_complex = Table("one_row_complex", MetaData(schema=ENV.schema), autoload_with=conn)
240+
241+
# Test filter() function with array column
242+
# The col_array contains [1, 2] based on the test data above
243+
result = conn.execute(
244+
sqlalchemy.select(
245+
sqlalchemy.func.filter(
246+
one_row_complex.c.col_array, sqlalchemy.literal("x -> x > 1")
247+
)
248+
)
249+
).scalar()
250+
# Should return [2] since only 2 is greater than 1
251+
assert result == [2]
252+
253+
# Test filter() function with different condition
254+
result_all = conn.execute(
255+
sqlalchemy.select(
256+
sqlalchemy.func.filter(
257+
one_row_complex.c.col_array, sqlalchemy.literal("x -> x > 0")
258+
)
259+
)
260+
).scalar()
261+
# Should return [1, 2] since both values are greater than 0
262+
assert result_all == [1, 2]
263+
264+
# Test filter() function with no matches
265+
result_empty = conn.execute(
266+
sqlalchemy.select(
267+
sqlalchemy.func.filter(
268+
one_row_complex.c.col_array, sqlalchemy.literal("x -> x > 10")
269+
)
270+
)
271+
).scalar()
272+
# Should return empty array since no values are greater than 10
273+
assert result_empty == []
274+
237275
def test_reflect_select(self, engine):
238276
engine, conn = engine
239277
one_row_complex = Table("one_row_complex", MetaData(schema=ENV.schema), autoload_with=conn)

0 commit comments

Comments
 (0)