Skip to content

Commit 6037cd7

Browse files
authored
SNOW-2370108: snowpark fails when group by agg called twice on the same df (#3819)
1 parent bcb81a9 commit 6037cd7

4 files changed

Lines changed: 46 additions & 17 deletions

File tree

.github/workflows/daily_precommit.yml

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ jobs:
655655
.tox/coverage.xml
656656
657657
test-enable-fix-join-alias:
658-
name: Test Fixing Join Alias py-${{ matrix.os }}-${{ matrix.python-version }}
658+
name: Test Fixing Join Alias py-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
659659
needs: build
660660
runs-on: ${{ matrix.os }}
661661
strategy:
@@ -707,7 +707,7 @@ jobs:
707707
- name: Install tox
708708
run: uv pip install tox --system
709709
# we only run doctest on macos
710-
- if: ${{ matrix.os == 'macos-latest' && matrix.python-version != '3.12'}}
710+
- if: ${{ matrix.os == 'macos-latest'}}
711711
name: Run doctests
712712
run: python -m tox -e "py${PYTHON_VERSION}-doctest-notudf-ci"
713713
env:
@@ -719,7 +719,7 @@ jobs:
719719
# For example, see https://github.com/snowflakedb/snowpark-python/pull/681
720720
shell: bash
721721
# do not run other tests for macos
722-
- if: ${{ matrix.os != 'macos-latest' && matrix.python-version != '3.12' }}
722+
- if: ${{ matrix.os != 'macos-latest'}}
723723
name: Run tests (excluding doctests)
724724
run: python -m tox -e "py${PYTHON_VERSION/\./}-notdoctest-ci"
725725
env:
@@ -730,18 +730,6 @@ jobs:
730730
SNOWPARK_PYTHON_API_TEST_BUCKET_PATH: ${{ secrets.SNOWPARK_PYTHON_API_TEST_BUCKET_PATH }}
731731
SNOWPARK_PYTHON_API_S3_STORAGE_INTEGRATION: ${{ vars.SNOWPARK_PYTHON_API_S3_STORAGE_INTEGRATION }}
732732
shell: bash
733-
- if: ${{ matrix.python-version == '3.12' }}
734-
name: Run tests (excluding doctests and udf tests)
735-
run: python -m tox -e "py${PYTHON_VERSION/\./}-notudfdoctest-ci"
736-
env:
737-
PYTHON_VERSION: ${{ matrix.python-version }}
738-
cloud_provider: ${{ matrix.cloud-provider }}
739-
PYTEST_ADDOPTS: --color=yes --tb=short --join_alias_fix
740-
TOX_PARALLEL_NO_SPINNER: 1
741-
SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1
742-
SNOWPARK_PYTHON_API_TEST_BUCKET_PATH: ${{ secrets.SNOWPARK_PYTHON_API_TEST_BUCKET_PATH }}
743-
SNOWPARK_PYTHON_API_S3_STORAGE_INTEGRATION: ${{ vars.SNOWPARK_PYTHON_API_S3_STORAGE_INTEGRATION }}
744-
shell: bash
745733
- name: Combine coverages
746734
run: python -m tox -e coverage --skip-missing-interpreters false
747735
shell: bash

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006
808808

809809
def add_aliases(self, to_add: Dict) -> None:
810810
if self.session._join_alias_fix:
811+
self.expr_to_alias = self.expr_to_alias.copy()
811812
self.expr_to_alias.update(to_add)
812813
else:
813814
self.expr_to_alias = {**self.expr_to_alias, **to_add}

tests/integ/scala/test_dataframe_join_suite.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
SnowparkSQLException,
1818
SnowparkSQLInvalidIdException,
1919
)
20-
from snowflake.snowpark.functions import coalesce, col, count, is_null, lit
20+
from snowflake.snowpark.functions import (
21+
coalesce,
22+
col,
23+
count,
24+
is_null,
25+
lit,
26+
sum as sp_sum,
27+
)
2128
from snowflake.snowpark.types import (
2229
IntegerType,
2330
StringType,
@@ -1626,3 +1633,37 @@ def test_dataframe_join_and_select_same_column_name_from_one_df(session):
16261633
assert df1.join(df2,).select(
16271634
df2.col("a")
16281635
).collect() == [Row(2)]
1636+
1637+
1638+
@pytest.mark.skipif(
1639+
"config.getoption('local_testing_mode', default=False)",
1640+
reason="SNOW-1373887: The join alias fix is not supported in Local Testing",
1641+
)
1642+
def test_dataframe_alias_map_unmodified(session):
1643+
origin = session._join_alias_fix
1644+
try:
1645+
session._join_alias_fix = True
1646+
df = session.create_dataframe([None], ["__DUMMY"])
1647+
1648+
cols = [lit("James"), lit(3000)]
1649+
df = (
1650+
df.with_columns(["name", "salary"], cols)
1651+
.select(*cols)
1652+
.toDF(*["name", "salary"])
1653+
)
1654+
1655+
def aggregate(input):
1656+
source_expr_to_alias = input._plan.expr_to_alias
1657+
ret_df = input.group_by(input.col("name").alias("new_name")).agg(
1658+
sp_sum(input.col("salary"))
1659+
)
1660+
assert (
1661+
source_expr_to_alias == input._plan.expr_to_alias
1662+
) # ensure the original df alias map is not changed
1663+
return ret_df
1664+
1665+
Utils.check_answer(aggregate(df), [Row("James", 3000)])
1666+
# execute twice to make sure no side effect
1667+
Utils.check_answer(aggregate(df), [Row("James", 3000)])
1668+
finally:
1669+
session._join_alias_fix = origin

tests/integ/test_cte.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,6 @@ def test_sql_simplifier(session):
809809
describe_count=0,
810810
union_count=0,
811811
join_count=2,
812-
describe_count_for_optimized=1 if session._join_alias_fix else None,
813812
)
814813
with SqlCounter(query_count=0, describe_count=0):
815814
# When adding a lsuffix, expr alias map will be updated, so df2 and df3 are considered

0 commit comments

Comments
 (0)