Skip to content

Commit 2d1cbf1

Browse files
committed
more defensive code
1 parent b36f135 commit 2d1cbf1

2 files changed

Lines changed: 96 additions & 45 deletions

File tree

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 76 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@
5151

5252
from snowflake.snowpark._internal.analyzer import analyzer_utils
5353
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
54+
quote_name_without_upper_casing,
5455
result_scan_statement,
5556
schema_value_statement,
57+
unquote_if_quoted,
5658
)
5759
from snowflake.snowpark._internal.analyzer.binary_expression import And
5860
from snowflake.snowpark._internal.analyzer.expression import (
@@ -85,8 +87,10 @@
8587
has_invalid_projection_merge_functions,
8688
)
8789
from snowflake.snowpark._internal.utils import (
88-
is_sql_select_statement,
90+
ALREADY_QUOTED,
8991
ExprAliasUpdateDict,
92+
is_sql_select_statement,
93+
quote_name,
9094
)
9195
import snowflake.snowpark.context as context
9296

@@ -1591,51 +1595,71 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
15911595

15921596
# When describe reduction is on and the inner select already has resolved
15931597
# attributes, infer new.attributes for this outer select by reusing datatype and
1594-
# nullable from the subquery: (1) index attributes by name, (2) walk
1595-
# new.projection, (3) only handle plain columns or Alias(column) — anything
1596-
# else aborts without setting partial attributes, (4) map each case to an
1597-
# Attribute named for the projected column, (5) assign only if every output
1598-
# column was inferred (length matches projection).
1598+
# nullable from the subquery: (0) skip if parent column names collide, (1) index
1599+
# attributes by normalized name, (2) walk new.projection, (3) only handle plain
1600+
# columns or Alias(column), (4) resolve source via quoted-identifier-aware lookup,
1601+
# (5) assign only if every output column was inferred (length matches projection).
15991602
if self._session.reduce_describe_query_enabled and self.attributes is not None:
1600-
# subquery lookup by name
1601-
attributes_by_name = {attr.name: attr for attr in self.attributes}
1602-
inferred_attributes: List[Attribute] = []
1603-
assert new.projection is not None
1604-
# infer from each projected expression
1605-
for expr in new.projection:
1606-
source_column_name = None
1607-
projected_column_name = None
1608-
if isinstance(expr, (Attribute, UnresolvedAttribute)):
1609-
# identity projection: output name equals input column
1610-
source_column_name = expr.name
1611-
projected_column_name = expr.name
1612-
elif isinstance(expr, Alias) and isinstance(
1613-
expr.child, (Attribute, UnresolvedAttribute)
1614-
):
1615-
# rename: source column from child, output name from alias
1616-
source_column_name = expr.child.name
1617-
projected_column_name = expr.name
1618-
else:
1619-
# non-simple expression: cannot infer types safely
1620-
inferred_attributes = []
1621-
break
1622-
1623-
source_attr = attributes_by_name.get(source_column_name)
1624-
if source_attr is None or projected_column_name is None:
1625-
# missing subquery column for this projection — abort
1626-
inferred_attributes = []
1627-
break
1628-
1629-
# projected name with subquery type and nullability
1630-
inferred_attributes.append(
1631-
Attribute(
1632-
projected_column_name,
1633-
source_attr.datatype,
1634-
source_attr.nullable,
1635-
)
1636-
)
1637-
if len(inferred_attributes) == len(new.projection):
1638-
# only commit when every column was inferred
1603+
parent_attributes = self.attributes
1604+
projection = new.projection
1605+
inferred_attributes: Optional[List[Attribute]] = None
1606+
# Skip: no projection to walk (do not assert; leave new.attributes unchanged).
1607+
if projection is not None:
1608+
# Skip: duplicate output names on the parent — dict/lookup would be ambiguous.
1609+
if len(parent_attributes) == len({a.name for a in parent_attributes}):
1610+
attributes_by_normalized: Dict[str, Attribute] = {}
1611+
collision = False
1612+
for attr in parent_attributes:
1613+
key = _normalized_snowflake_identifier_key(attr.name)
1614+
existing = attributes_by_normalized.get(key)
1615+
# Skip: two parent columns normalize to the same key.
1616+
if existing is not None and existing is not attr:
1617+
collision = True
1618+
break
1619+
attributes_by_normalized[key] = attr
1620+
if not collision:
1621+
inferred_attributes = []
1622+
for expr in projection:
1623+
source_column_name: Optional[str] = None
1624+
projected_column_name: Optional[str] = None
1625+
if isinstance(expr, (Attribute, UnresolvedAttribute)):
1626+
source_column_name = expr.name
1627+
projected_column_name = expr.name
1628+
elif isinstance(expr, Alias) and isinstance(
1629+
expr.child, (Attribute, UnresolvedAttribute)
1630+
):
1631+
source_column_name = expr.child.name
1632+
projected_column_name = expr.name
1633+
else:
1634+
# Skip: not a plain column or Alias(Attribute|UnresolvedAttribute).
1635+
inferred_attributes = []
1636+
break
1637+
1638+
if (
1639+
source_column_name is None
1640+
or projected_column_name is None
1641+
):
1642+
# Skip: missing projected output name.
1643+
inferred_attributes = []
1644+
break
1645+
source_attr = attributes_by_normalized.get(
1646+
_normalized_snowflake_identifier_key(source_column_name)
1647+
)
1648+
# Skip: no parent column for this source name.
1649+
if source_attr is None:
1650+
inferred_attributes = []
1651+
break
1652+
inferred_attributes.append(
1653+
Attribute(
1654+
projected_column_name,
1655+
source_attr.datatype,
1656+
source_attr.nullable,
1657+
)
1658+
)
1659+
if len(inferred_attributes) != len(projection):
1660+
# Skip: incomplete inference (includes defensive mismatch).
1661+
inferred_attributes = None
1662+
if inferred_attributes is not None:
16391663
new.attributes = inferred_attributes
16401664

16411665
new.flatten_disabled = disable_next_level_flatten
@@ -2136,6 +2160,13 @@ class DeriveColumnDependencyError(Exception):
21362160
"""When deriving column dependencies from the subquery."""
21372161

21382162

2163+
def _normalized_snowflake_identifier_key(name: str) -> str:
    """Return the canonical quoted key for a column identifier.

    Delimited (already quoted) identifiers keep their case; unquoted
    identifiers follow Snowflake uppercasing through ``quote_name``.
    """
    if not ALREADY_QUOTED.match(name):
        return quote_name(name)
    # Already delimited: drop the delimiters, then re-quote without upper-casing.
    bare_name = unquote_if_quoted(name)
    return quote_name_without_upper_casing(bare_name)
2168+
2169+
21392170
def parse_column_name(
21402171
column: Expression,
21412172
analyzer: "Analyzer",

tests/integ/test_reduce_describe_query.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,26 @@ def test_chained_simple_renames_infer_from_previous_metadata(session):
563563
_ = df2._plan.attributes
564564

565565

566+
def test_quoted_case_sensitive_sql_column_metadata_inference(session):
    """Chained select on a delimited column from session.sql infers metadata without DESCRIBE."""
    source_df = session.sql('SELECT 1 AS "MixedCase"')
    # Resolve the source schema first; presumably this is what gives the
    # chained select parent attributes to infer from — TODO confirm.
    with SqlCounter(query_count=0, describe_count=1, strict=False):
        _ = source_df.schema

    projected_df = source_df.select(col('"MixedCase"'))
    if session.reduce_describe_query_enabled:
        # With describe reduction on, the metadata must already be populated
        # and no further DESCRIBE is expected below.
        assert projected_df._plan._metadata.attributes is not None
        assert len(projected_df._plan._metadata.attributes) == 1
        assert projected_df._plan._metadata.attributes[0].name == '"MixedCase"'
        describe_budget = 0
    else:
        describe_budget = 1

    with SqlCounter(query_count=0, describe_count=describe_budget):
        resolved_attrs = projected_df._plan.attributes
        assert resolved_attrs is not None
        assert len(resolved_attrs) == 1
        # The delimited name must keep its original mixed case.
        assert resolved_attrs[0].name == '"MixedCase"'
584+
585+
566586
def test_non_simple_projection_skips_metadata_inference(session):
567587
"""Expressions other than plain column or simple alias(column) do not infer attributes."""
568588
df = session.create_dataframe([[1, 2]], schema=["a", "b"])

0 commit comments

Comments
 (0)