|
2 | 2 | # |
3 | 3 | # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. |
4 | 4 | # |
| 5 | +import copy |
5 | 6 | import decimal |
6 | 7 | import math |
7 | 8 | from unittest import mock |
|
32 | 33 | ) |
33 | 34 | from snowflake.snowpark.mock._snowflake_data_type import ColumnEmulator, ColumnType |
34 | 35 | from snowflake.snowpark.types import DoubleType, IntegerType, StructType, StructField |
| 36 | +from snowflake.snowpark._internal.utils import is_ast_enabled |
35 | 37 | from tests.utils import Utils |
36 | 38 |
|
37 | 39 |
|
@@ -1370,3 +1372,68 @@ def test_group_by_exclude_grouping_columns(session): |
1370 | 1372 | ) |
1371 | 1373 | assert len(result_builtin_exclude[0]) == 1 # only sum |
1372 | 1374 | Utils.check_answer(result_builtin_exclude, [Row(6), Row(15)]) |
| 1375 | + |
| 1376 | + |
| 1377 | +@pytest.mark.skipif( |
| 1378 | + "config.getoption('local_testing_mode', default=False)", |
| 1379 | + reason="ORDER BY append and limit append are not supported in local testing mode", |
| 1380 | +) |
| 1381 | +def test_copy_preserves_agg_state(session): |
| 1382 | + """copy.copy() and _copy_without_ast() must preserve post-aggregate state so |
| 1383 | + that .limit() and .sort() on the copy go through _build_post_agg_df and |
| 1384 | + generate correct SQL (ORDER BY inside the aggregate subquery, not lost on |
| 1385 | + the outer wrapper).""" |
| 1386 | + if is_ast_enabled(): |
| 1387 | + pytest.skip( |
| 1388 | + "_copy_without_ast() leaves _ast_id=None; calling limit() on the copy " |
| 1389 | + "crashes in AST mode because publicapi injects _emit_ast=True via the " |
| 1390 | + "global is_ast_enabled() which bypasses the Session.ast_enabled mock." |
| 1391 | + ) |
| 1392 | + # Disable AST: copy.copy(df).limit() triggers debug_check_missing_ast because |
| 1393 | + # the copy carries the source's API usage with no corresponding AST entries. |
| 1394 | + with mock.patch( |
| 1395 | + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True |
| 1396 | + ): |
| 1397 | + df = session.create_dataframe( |
| 1398 | + [ |
| 1399 | + ("a", 3), |
| 1400 | + ("b", 1), |
| 1401 | + ("a", 1), |
| 1402 | + ("b", 2), |
| 1403 | + ("c", 10), |
| 1404 | + ], |
| 1405 | + ["k", "v"], |
| 1406 | + ) |
| 1407 | + agg_sorted = ( |
| 1408 | + df.group_by("k") |
| 1409 | + .agg(sum_("v").alias("total")) |
| 1410 | + .filter(col("total") > 1) |
| 1411 | + .sort(col("total").desc()) |
| 1412 | + ) |
| 1413 | + |
| 1414 | + for copied in (copy.copy(agg_sorted), agg_sorted._copy_without_ast()): |
| 1415 | + # Internal state must be carried over so _build_post_agg_df fires correctly |
| 1416 | + assert ( |
| 1417 | + copied._ops_after_agg |
| 1418 | + and copied._ops_after_agg == agg_sorted._ops_after_agg |
| 1419 | + ) |
| 1420 | + assert ( |
| 1421 | + copied._agg_base_plan |
| 1422 | + and copied._agg_base_plan == agg_sorted._agg_base_plan |
| 1423 | + ) |
| 1424 | + assert ( |
| 1425 | + copied._agg_base_select_statement |
| 1426 | + and copied._agg_base_select_statement |
| 1427 | + is agg_sorted._agg_base_select_statement |
| 1428 | + ) |
| 1429 | + assert ( |
| 1430 | + copied._pending_order_bys |
| 1431 | + and copied._pending_order_bys == agg_sorted._pending_order_bys |
| 1432 | + ) |
| 1433 | + assert ( |
| 1434 | + copied._pending_havings |
| 1435 | + and copied._pending_havings == agg_sorted._pending_havings |
| 1436 | + ) |
| 1437 | + |
| 1438 | + # Observable result: ORDER BY must be respected under LIMIT |
| 1439 | + Utils.check_answer(copied.limit(2), [Row("c", 10), Row("a", 4)]) |
0 commit comments