Skip to content

Commit 4afc47d

Browse files
authored
SNOW-3590670: fix that agg stats are not copied in df.copy (#4253)
1 parent 8a140ef commit 4afc47d

3 files changed

Lines changed: 84 additions & 2 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#### Bug Fixes
1212

13+
- Fixed a bug where copying a `DataFrame` via `copy.copy()` lost post-aggregate state, causing subsequent `.limit()` or `.sort()` to generate incorrect SQL.
1314
- Fixed a bug where calling `DataFrame.alias()` twice on the same DataFrame (e.g. for a self-join) caused both aliases to share the same internal column-mapping dictionary. This made `col("R", "col")` resolve to the same column as `col("L", "col")`, producing incorrect join conditions and filter expressions.
1415
- Fixed a bug where `cloudpickle` could not be resolved when registering a Python stored procedure or UDF with `runtime_version='3.13'`.
1516

src/snowflake/snowpark/dataframe.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,9 +1047,21 @@ def _copy_plan(self) -> LogicalPlan:
10471047
else:
10481048
return copy.copy(self._plan)
10491049

1050+
def _copy_agg_state(self, target: "DataFrame") -> None:
1051+
"""Copy post-aggregate state to target so operations like .limit() and
1052+
.sort() on the copy go through _build_post_agg_df and generate correct
1053+
SQL (ORDER BY inside the aggregate subquery, not on the outer query)."""
1054+
target._ops_after_agg = self._ops_after_agg
1055+
target._agg_base_plan = self._agg_base_plan
1056+
target._agg_base_select_statement = self._agg_base_select_statement
1057+
target._pending_havings = self._pending_havings
1058+
target._pending_order_bys = self._pending_order_bys
1059+
10501060
def _copy_without_ast(self) -> "DataFrame":
10511061
"""Returns a shallow copy of the DataFrame without AST generation."""
1052-
return DataFrame(self._session, self._copy_plan(), _emit_ast=False)
1062+
result = DataFrame(self._session, self._copy_plan(), _emit_ast=False)
1063+
self._copy_agg_state(result)
1064+
return result
10531065

10541066
def __copy__(self) -> "DataFrame":
10551067
"""Implements shallow copy protocol for copy.copy(...)."""
@@ -1058,12 +1070,14 @@ def __copy__(self) -> "DataFrame":
10581070
stmt = self._session._ast_batch.bind()
10591071
with_src_position(stmt.expr.dataframe_ref, stmt)
10601072
self._set_ast_ref(stmt.expr)
1061-
return DataFrame(
1073+
result = DataFrame(
10621074
self._session,
10631075
self._copy_plan(),
10641076
_ast_stmt=stmt,
10651077
_emit_ast=self._session.ast_enabled,
10661078
)
1079+
self._copy_agg_state(result)
1080+
return result
10671081

10681082
if installed_pandas:
10691083
import pandas # pragma: no cover

tests/integ/test_df_aggregate.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
44
#
5+
import copy
56
import decimal
67
import math
78
from unittest import mock
@@ -32,6 +33,7 @@
3233
)
3334
from snowflake.snowpark.mock._snowflake_data_type import ColumnEmulator, ColumnType
3435
from snowflake.snowpark.types import DoubleType, IntegerType, StructType, StructField
36+
from snowflake.snowpark._internal.utils import is_ast_enabled
3537
from tests.utils import Utils
3638

3739

@@ -1370,3 +1372,68 @@ def test_group_by_exclude_grouping_columns(session):
13701372
)
13711373
assert len(result_builtin_exclude[0]) == 1 # only sum
13721374
Utils.check_answer(result_builtin_exclude, [Row(6), Row(15)])
1375+
1376+
1377+
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="ORDER BY append and limit append are not supported in local testing mode",
)
def test_copy_preserves_agg_state(session):
    """Copies of a post-aggregate DataFrame must keep its post-aggregate state.

    Covers both ``copy.copy()`` and ``_copy_without_ast()``: ``.limit()`` and
    ``.sort()`` on the copy have to go through ``_build_post_agg_df`` so the
    ORDER BY ends up inside the aggregate subquery instead of being lost on
    the outer wrapper.
    """
    # AST mode is skipped entirely: copy.copy(df).limit() triggers
    # debug_check_missing_ast because the copy carries the source's API usage
    # with no corresponding AST entries, and _copy_without_ast() has no AST id.
    if is_ast_enabled():
        pytest.skip(
            "_copy_without_ast() leaves _ast_id=None; calling limit() on the copy "
            "crashes in AST mode because publicapi injects _emit_ast=True via the "
            "global is_ast_enabled() which bypasses the Session.ast_enabled mock."
        )

    # (attribute name, compare by identity?) in the order they are verified.
    state_checks = (
        ("_ops_after_agg", False),
        ("_agg_base_plan", False),
        ("_agg_base_select_statement", True),
        ("_pending_order_bys", False),
        ("_pending_havings", False),
    )

    # Force snowpark-connect compatible mode for the whole scenario.
    # NOTE(review): presumably the post-aggregate handling under test is only
    # active in this mode -- confirm against DataFrame.sort()/limit().
    with mock.patch(
        "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
    ):
        source_df = session.create_dataframe(
            [("a", 3), ("b", 1), ("a", 1), ("b", 2), ("c", 10)],
            ["k", "v"],
        )
        grouped = source_df.group_by("k")
        totals = grouped.agg(sum_("v").alias("total"))
        having_applied = totals.filter(col("total") > 1)
        agg_sorted = having_applied.sort(col("total").desc())

        # Both copies are created up front, before any of them is inspected.
        clones = (copy.copy(agg_sorted), agg_sorted._copy_without_ast())
        for clone in clones:
            # Internal state must be carried over (and be non-empty) so that
            # _build_post_agg_df fires correctly on the copy.
            for attr, by_identity in state_checks:
                carried = getattr(clone, attr)
                assert carried
                reference = getattr(agg_sorted, attr)
                if by_identity:
                    assert carried is reference
                else:
                    assert carried == reference

            # Observable result: ORDER BY must be respected under LIMIT, i.e.
            # the two largest totals (c=10, a=3+1=4) are the rows returned.
            top_two = clone.limit(2)
            Utils.check_answer(top_two, [Row("c", 10), Row("a", 4)])

0 commit comments

Comments
 (0)