Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_create_temp_file_format,
build_location_helper,
)
import snowflake.snowpark.context as context
from snowflake.snowpark._internal.analyzer.binary_plan_node import (
AsOf,
Except,
Expand Down Expand Up @@ -689,16 +690,18 @@ def aggregate_statement(
) -> str:
# add limit 1 because aggregate may be on non-aggregate function in a scalar aggregation
# for example, df.agg(lit(1))
Comment on lines 691 to 692

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this assumption legit in the context of SCOS? What cases will break if we always remove limit_expression(1) for SCOS?

Given the complexity of the implementation (extra tree walk, code maintainability, etc.) and the minimal perf improvement it brings, I'm wondering if it's worth the effort.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified locally that the scenario mentioned in the comment does not break SCOS with the current change in Snowpark.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any metrics on the improvements from this change? I agree with Adam that this change seems risky for little benefit.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious why the non-aggregate function case in SCOS avoids the issue after we remove limit 1.

However, if the underlying logic is overly complex or confusing, let's skip it. Without a clear understanding, we won't have high confidence against regressions.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I re-tested a few times, and it turns out that this change is also not safe for SCOS, so we would have to add a lot of logic to make it safe for SCOS, like the first version of this change.
At this point, I don't think it is worth having this change just to remove limit 1.

return project_statement(aggregate_exprs, child, child_uuid=child_uuid) + (
limit_expression(1)
if not grouping_exprs
else (
NEW_LINE
+ GROUP_BY
+ NEW_LINE
+ TAB
+ (COMMA + NEW_LINE + TAB).join(grouping_exprs)
if not grouping_exprs:
return project_statement(aggregate_exprs, child, child_uuid=child_uuid) + (
EMPTY_STRING
if context._is_snowpark_connect_compatible_mode
else limit_expression(1)
)
return project_statement(aggregate_exprs, child, child_uuid=child_uuid) + (
NEW_LINE
+ GROUP_BY
+ NEW_LINE
+ TAB
+ (COMMA + NEW_LINE + TAB).join(grouping_exprs)
)


Expand Down
26 changes: 26 additions & 0 deletions tests/integ/test_df_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
upper,
grouping,
grouping_id,
lit,
)
from snowflake.snowpark.mock._snowflake_data_type import ColumnEmulator, ColumnType
from snowflake.snowpark.types import DoubleType, IntegerType, StructType, StructField
Expand Down Expand Up @@ -726,6 +727,31 @@ def test_agg_no_grouping_exprs_limit_snowpark_connect_compatible(session):
Utils.check_answer(result, [Row(10)])


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="local testing query does not have limit 1",
)
def test_global_aggregate_limit_compat_mode_snowpark_connect_compatible(session):
    """Check how the compatibility flag affects LIMIT 1 on global aggregates.

    With ``_is_snowpark_connect_compatible_mode`` patched to ``False`` the last
    generated query of an aggregation without grouping must contain
    ``LIMIT 1``. With the flag patched to ``True`` the last generated query
    must not contain it, for aggregate expressions (``sum_``, ``count``) and
    for a non-aggregate expression (``lit(1)``).
    """
    compat_flag = "snowflake.snowpark.context._is_snowpark_connect_compatible_mode"
    limit_clause = "LIMIT 1"

    def last_query_upper(frame):
        # Upper-case the final generated query so the substring check does not
        # depend on the casing of the generated SQL.
        return frame.queries["queries"][-1].upper()

    df = session.create_dataframe([[1, 2], [3, 4], [1, 4]], schema=["A", "B"])

    with mock.patch(compat_flag, False):
        non_compat_query = last_query_upper(df.agg(sum_(col("a"))))
        assert limit_clause in non_compat_query

    with mock.patch(compat_flag, True):
        # Queries are built while the patch is active; the string checks below
        # operate on the already-materialized SQL text.
        sum_query = last_query_upper(df.agg(sum_(col("a"))))
        count_query = last_query_upper(df.agg(count(col("a"))))
        non_agg_expr_query = last_query_upper(df.agg(lit(1)))

        for compat_query in (sum_query, count_query, non_agg_expr_query):
            assert limit_clause not in compat_query


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="HAVING and ORDER BY append are not supported in local testing mode",
Expand Down
15 changes: 12 additions & 3 deletions tests/unit/test_analyzer_util_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,18 @@ def test_sample_by_statement_formatting(mock_random_name):


def test_aggregate_statement_formatting():
assert aggregate_statement([], ["COUNT(*) as cnt"], "my_table") == (
" SELECT \n" " COUNT(*) as cnt\n" " FROM (\n" "my_table\n" ") LIMIT 1"
)
with mock.patch(
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", False
):
assert aggregate_statement([], ["COUNT(*) as cnt"], "my_table") == (
" SELECT \n" " COUNT(*) as cnt\n" " FROM (\n" "my_table\n" ") LIMIT 1"
)
with mock.patch(
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
):
assert aggregate_statement([], ["COUNT(*) as cnt"], "my_table") == (
" SELECT \n" " COUNT(*) as cnt\n" " FROM (\n" "my_table\n" ")"
)

assert aggregate_statement(["dept", "title"], ["COUNT(*) as cnt"], "my_table") == (
" SELECT \n"
Expand Down
Loading