Skip to content

Commit c3d3dec

Browse files
authored
fix: BI-6846 Patch for TrinoSQLCompiler window functions (#1421)
* patch compile_ignore_nulls * add test * refactor test * fix trino version * move patch to package root
1 parent e6ebf45 commit c3d3dec

5 files changed

Lines changed: 32 additions & 2 deletions

File tree

lib/dl_api_lib_testing/dl_api_lib_testing/connector/complex_queries.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,7 @@ def test_window_functions(
551551
"Date Sales": "SUM([Group Sales] WITHIN [order_date])",
552552
"City Sales": "SUM([Group Sales] AMONG [order_date])",
553553
"Total RSUM": 'RSUM([Group Sales], "asc" TOTAL)',
554+
"First Round Count": "FIRST(ROUND(COUNT(1)))",
554555
},
555556
)
556557

@@ -567,6 +568,7 @@ def test_window_functions(
567568
ds.find_field(title="Date Sales"),
568569
ds.find_field(title="City Sales"),
569570
ds.find_field(title="Total RSUM"),
571+
ds.find_field(title="First Round Count"),
570572
],
571573
order_by=[
572574
ds.find_field(title="order_date"),
@@ -585,7 +587,7 @@ def test_window_functions(
585587
assert {row[3] for row in data_rows}.issubset({str(i) for i in range(1, cnt + 1)})
586588

587589
# There are as many [Unique Rank of Sales] values as there are rows
588-
assert {row[4] for row in data_rows} == ({str(i) for i in range(1, cnt + 1)})
590+
assert {row[4] for row in data_rows} == {str(i) for i in range(1, cnt + 1)}
589591

590592
# [Rank of City Sales for Date] values are not greater than the number of [City] values
591593
assert len({row[5] for row in data_rows}) <= len({row[1] for row in data_rows})
@@ -603,6 +605,8 @@ def test_window_functions(
603605
# RSUM = previous RSUM value + value of current arg
604606
assert pytest.approx(float(data_rows[i][9])) == float(data_rows[i - 1][9]) + float(data_rows[i][2])
605607

608+
assert all(float(row[10]) == 1 for row in data_rows)
609+
606610

607611
class DefaultBasicNativeFunctionTestSuite(
608612
RegulatedTestCase, DataApiTestBase, DatasetTestBase, DbServiceFixtureTextClass
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from dl_connector_trino import vendor_patches # noqa: F401
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Any
2+
3+
from sqlalchemy.ext.compiler import compiles
4+
from sqlalchemy.sql.compiler import SQLCompiler
5+
from trino.sqlalchemy.compiler import TrinoSQLCompiler
6+
7+
8+
# This is a temporary patch to fix https://github.com/trinodb/trino-python-client/pull/586
9+
# BI-6846
10+
@compiles(TrinoSQLCompiler.FirstValue)
11+
@compiles(TrinoSQLCompiler.LastValue)
12+
@compiles(TrinoSQLCompiler.NthValue)
13+
@compiles(TrinoSQLCompiler.Lead)
14+
@compiles(TrinoSQLCompiler.Lag)
15+
def compile_ignore_nulls(
16+
element: TrinoSQLCompiler.GenericIgnoreNulls,
17+
compiler: SQLCompiler,
18+
**kwargs: Any,
19+
) -> str:
20+
compiled = f"{element.name}({compiler.process(element.clauses, **kwargs)})"
21+
if element.ignore_nulls:
22+
compiled += " IGNORE NULLS"
23+
return compiled

metapkg/poetry.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

metapkg/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ sqlalchemy = "==1.4.46, <2.0"
8383
sqlalchemy-bigquery = "==1.9.0"
8484
tabulate = "==0.9.0"
8585
tornado = "==6.4.2"
86+
trino = {extras = ["sqlalchemy"], version = "==0.331.0"}
8687
typeguard = "==4.1.5"
8788
typing-extensions = "==4.15.0"
8889
ujson = "==1.35"

0 commit comments

Comments
 (0)