Skip to content

Commit 37609b6

Browse files
committed
Fix join bug caused by DataFrame alias
1 parent c9d4a24 commit 37609b6

4 files changed

Lines changed: 63 additions & 6 deletions

File tree

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,8 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
912912

913913
for c in logical_plan.children: # post-order traversal of the tree
914914
resolved = self.resolve(c)
915-
df_aliased_col_name_to_real_col_name.update(resolved.df_aliased_col_name_to_real_col_name) # type: ignore
915+
for alias, dict_ in resolved.df_aliased_col_name_to_real_col_name.items():
916+
df_aliased_col_name_to_real_col_name[alias].update(dict_)
916917
resolved_children[c] = resolved
917918

918919
if isinstance(logical_plan, Selectable):
@@ -944,9 +945,8 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
944945
res = self.do_resolve_with_resolved_children(
945946
logical_plan, resolved_children, df_aliased_col_name_to_real_col_name
946947
)
947-
res.df_aliased_col_name_to_real_col_name.update(
948-
df_aliased_col_name_to_real_col_name
949-
)
948+
for alias, dict_ in df_aliased_col_name_to_real_col_name.items():
949+
res.df_aliased_col_name_to_real_col_name[alias].update(dict_)
950950
return res
951951

952952
def do_resolve_with_resolved_children(

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ def __init__(
877877
self._projection_in_str = None
878878
self._query_params = None
879879
self.expr_to_alias.update(self.from_.expr_to_alias)
880-
self.df_aliased_col_name_to_real_col_name.update(
880+
self.df_aliased_col_name_to_real_col_name = deepcopy(
881881
self.from_.df_aliased_col_name_to_real_col_name
882882
)
883883
self.api_calls = (

src/snowflake/snowpark/_internal/compiler/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,10 @@ def update_resolvable_node(
252252
# df_aliased_col_name_to_real_col_name is updated at the frontend api
253253
# layer when alias is called, not produced during code generation. Should
254254
# always retain the original value of the map.
255-
node.df_aliased_col_name_to_real_col_name.update(
255+
node.df_aliased_col_name_to_real_col_name = copy.deepcopy(
256256
node.from_.df_aliased_col_name_to_real_col_name
257257
)
258+
258259
# projection_in_str for SelectStatement runs an analyzer.analyze() which
259260
# needs the correct expr_to_alias map setup. This map is setup during
260261
# snowflake plan generation and cached for later use. Calling snowflake_plan

tests/integ/test_cte.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
when_matched,
2727
to_timestamp,
2828
)
29+
from snowflake.snowpark.types import (
30+
StructType,
31+
StructField,
32+
IntegerType,
33+
StringType,
34+
TimestampType,
35+
)
2936
from tests.integ.scala.test_dataframe_reader_suite import get_reader
3037
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
3138
from tests.utils import IS_IN_STORED_PROC, IS_IN_STORED_PROC_LOCALFS, TestFiles, Utils
@@ -273,6 +280,55 @@ def test_join_with_alias_dataframe(session):
273280
assert last_query.count(WITH) == 1
274281

275282

283+
def test_join_with_alias_dataframe_2(session):
284+
# Reproduced from issue SNOW-2257191
# Builds three empty, cached DataFrames from explicit schemas, aliases two of
# them ("d2" and "d3"), left-joins each aliased frame to the same df1, and
# finally left-joins the two join results with each other.
285+
schema1 = StructType(
286+
[
287+
StructField("DST_Year", IntegerType(), True),
288+
StructField("DST_Start", TimestampType(), True),
289+
StructField("DST_End", TimestampType(), True),
290+
]
291+
)
292+
293+
schema2 = StructType(
294+
[
295+
StructField("MATTRANSID", StringType(), True),
296+
StructField("LOADSTARTTIME", TimestampType(), True),
297+
StructField("LOADENDTIME", TimestampType(), True),
298+
StructField("DUMPENDTIME", TimestampType(), True),
299+
StructField("__CURRENT", StringType(), True),
300+
StructField("__DELETED", StringType(), True),
301+
]
302+
)
303+
304+
schema3 = StructType(
305+
[
306+
StructField("MATTRANSID", StringType(), True),
307+
StructField("DUMPENDTIME", TimestampType(), True),
308+
StructField("LOADENDTIME", TimestampType(), True),
309+
StructField("__CURRENT", StringType(), True),
310+
StructField("__DELETED", StringType(), True),
311+
]
312+
)
313+
314+
df1 = session.create_dataframe([], schema=schema1).cache_result()
315+
df2 = session.create_dataframe([], schema=schema2).cache_result()
316+
df3 = session.create_dataframe([], schema=schema3).cache_result()
317+
318+
# The join conditions address the aliased frames through col(<alias>, <name>).
# Column names are written in a different case than in the schemas above
# ("LoadStartTime" vs "LOADSTARTTIME"). Both joins use df1 as the right side.
df4 = df2.alias("d2").join(
319+
df1, col("d2", "LoadStartTime").between(df1.DST_Start, df1.DST_End), "left"
320+
)
321+
322+
df5 = df3.alias("d3").join(
323+
df1, col("d3", "LoadEndTime").between(df1.DST_Start, df1.DST_End), "left"
324+
)
325+
326+
# MATTRANSID exists in both schema2 and schema3, so it is present on both
# sides of this final join.
df6 = df5.join(df4, (df5.MatTransId == df4.MatTransId), "left")
327+
328+
# No explicit assertion: the check is that collect() runs the generated SQL
# without raising.
329+
df6.collect()
330+
331+
276332
def test_join_with_set_operation(session):
277333
df1 = session.create_dataframe([[1, 2, 3], [4, 5, 6]], "a: int, b: int, c: int")
278334
df2 = session.create_dataframe([[1, 1], [4, 5]], "a: int, b: int")

0 commit comments

Comments
 (0)