Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### Bug Fixes

- Fixed a bug where `DataFrame.limit()` fails if there is parameter binding in the executed SQL.
- Added an experimental fix for a bug in schema query generation that could cause invalid SQL to be generated when using nested structured types.

#### New Features

Expand Down
23 changes: 17 additions & 6 deletions src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from decimal import Decimal
from typing import Any

import snowflake.snowpark.context as context
import snowflake.snowpark._internal.analyzer.analyzer_utils as analyzer_utils
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark._internal.utils import (
Expand Down Expand Up @@ -518,18 +519,26 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
if isinstance(data_type, ArrayType):
if data_type.structured:
assert data_type.element_type is not None
element = schema_expression(data_type.element_type, data_type.contains_null)
if context._enable_fix_2360274:
element = "NULL"
else:
element = schema_expression(
data_type.element_type, data_type.contains_null
)
return f"to_array({element}) :: {convert_sp_to_sf_type(data_type)}"
return "to_array(0)"
if isinstance(data_type, MapType):
if data_type.structured:
assert data_type.key_type is not None and data_type.value_type is not None
# Key values can never be null
key = schema_expression(data_type.key_type, False)
# Value nullability is variable. Defaults to True
value = schema_expression(
data_type.value_type, data_type.value_contains_null
)
if context._enable_fix_2360274:
value = "NULL"
else:
# Value nullability is variable. Defaults to True
value = schema_expression(
data_type.value_type, data_type.value_contains_null
)
return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}"
return "to_object(parse_json('0'))"
if isinstance(data_type, StructType):
Expand All @@ -539,7 +548,9 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
# Even if nulls are allowed the cast will fail due to schema mismatch when passed a null field.
schema_strings += [
f"'{field.name}'",
schema_expression(field.datatype, is_nullable=False),
"NULL"
if context._enable_fix_2360274
else schema_expression(field.datatype, is_nullable=False),
]
return f"object_construct_keep_null({', '.join(schema_strings)}) :: {convert_sp_to_sf_type(data_type)}"
return "to_object(parse_json('{}'))"
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
# This is an internal-only global flag, used to determine whether to enable query line tracking for tracing sql compilation errors.
_enable_trace_sql_errors_to_dataframe = False

# SNOW-2362050: Enable this fix by default.
# Global flag for fix 2360274. When enabled, schema queries use NULL as a
# placeholder for the values nested inside structured types (array elements,
# map values, and object fields) instead of recursively generated expressions.
_enable_fix_2360274 = False
Comment thread
sfc-gh-jrose marked this conversation as resolved.


def configure_development_features(
*,
Expand Down
108 changes: 108 additions & 0 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import logging
import pytest
from unittest import mock

import snowflake.snowpark.context as context
from snowflake.connector.options import installed_pandas
Expand Down Expand Up @@ -1763,3 +1764,110 @@ def test_lob_collect_max_size(session, server_side_max_string, type_string, data
)
assert df.schema == StructType([StructField("DATA", datatype, nullable=False)])
assert len(df.collect()[0][0]) >= server_side_max_string - 16


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="Structured types are not supported in Local Testing",
)
@pytest.mark.parametrize("fix_enabled", [True, False])
def test_snow_2360274_repro(
    structured_type_session, structured_type_support, fix_enabled
):
    """Reproduce SNOW-2360274 with the fix flag both on and off.

    A table with structured ARRAY, MAP and OBJECT columns is created, joined
    against a small reference query, and the joined schema is compared with
    the expected one. With ``context._enable_fix_2360274`` patched to True the
    scenario must run cleanly; with it patched to False the scenario is
    expected to raise ``SnowparkSQLException`` mentioning
    ``Unsupported data type 'STRUCTURED_OBJECT'``.
    """
    if not structured_type_support:
        pytest.skip("Test requires structured type support.")

    table_name = f"snowpark_2360274_repro_agg_{uuid.uuid4().hex[:5]}".upper()

    # The nested field identifier is quoted unless structured type semantics
    # are in use.
    if context._should_use_structured_type_semantics():
        field_name = "value"
    else:
        field_name = '"value"'

    def nested_object():
        # Fresh instance per use so the two columns never share a StructType.
        return StructType([StructField(field_name, StringType(10), nullable=True)])

    schema_fields = [
        StructField("ID", LongType(), nullable=False),
        StructField("VALS_ARR", ArrayType(nested_object()), nullable=True),
        StructField(
            "VALS_MAP", MapType(StringType(10), StringType(10)), nullable=True
        ),
        StructField("VALS_OBJ", nested_object(), nullable=True),
        StructField("TAG", StringType(2), nullable=False),
    ]
    expected_schema = StructType(schema_fields)

    def run_repro():
        create_table_sql = f"""
CREATE
OR REPLACE TABLE {table_name} (
ID INT NOT NULL,
VALS_ARR ARRAY(OBJECT({field_name} STRING(10))) NOT NULL,
VALS_MAP MAP(STRING(10), STRING(10)) NOT NULL,
VALS_OBJ OBJECT({field_name} STRING(10)) NOT NULL
) AS WITH SRC(ID, VALUE) AS (
SELECT
$1,
$2
FROM
VALUES
(1, 'A'),
(1, 'B'),
(2, 'A')
)
SELECT
ID,
CAST(
ARRAY_AGG(OBJECT_CONSTRUCT('value', VALUE)) AS ARRAY(OBJECT({field_name} STRING))
) AS VALS_ARR,
CAST(
OBJECT_CONSTRUCT('value', VALUE) AS MAP(STRING, STRING)
) AS VALS_MAP,
CAST(
OBJECT_CONSTRUCT('value', VALUE) AS OBJECT({field_name} STRING)
) AS VALS_OBJ,
FROM
SRC
GROUP BY
ID, VALS_MAP, VALS_OBJ"""
        structured_type_session.sql(create_table_sql).collect()

        agg_df = structured_type_session.table(table_name)

        ref_df = structured_type_session.sql(
            """
SELECT $1 AS ID, $2 AS TAG FROM VALUES (1, 'AB'), (2, 'B')
"""
        )

        join_condition = agg_df.id == ref_df.id
        joined_df = agg_df.join(ref_df, on=join_condition, how="inner")
        projected_df = joined_df.select(
            agg_df.id.alias("ID"), "VALS_ARR", "VALS_MAP", "VALS_OBJ", "TAG"
        )
        Utils.is_schema_same(
            projected_df.schema, expected_schema, case_sensitive=False
        )

    try:
        with mock.patch.object(context, "_enable_fix_2360274", fix_enabled):
            if not fix_enabled:
                # Without the fix the schema query is expected to be rejected.
                with pytest.raises(
                    SnowparkSQLException,
                    match="Unsupported data type 'STRUCTURED_OBJECT'",
                ):
                    run_repro()
            else:
                run_repro()
    finally:
        # Always clean up the table, whichever branch ran.
        Utils.drop_table(structured_type_session, table_name)
1 change: 1 addition & 0 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2291,6 +2291,7 @@ def artifact_repo_test(_):
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+"
)
@pytest.mark.skip("SNOW-2362946: Skip until root cause is found.")
def test_sproc_artifact_repository_from_file(session, tmpdir):
source = dedent(
"""
Expand Down