diff --git a/CHANGELOG.md b/CHANGELOG.md index b0740db7fd..1d0e156f7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ #### Bug Fixes - Fixed a bug that `DataFrame.limit()` fail if there is parameter binding in the executed SQL. +- Added an experimental fix for a bug in schema query generation that could cause invalid SQL to be generated when using nested structured types. #### New Features diff --git a/src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py b/src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py index cb7f2a7944..b4604df0bb 100644 --- a/src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py +++ b/src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py @@ -12,6 +12,7 @@ from decimal import Decimal from typing import Any +import snowflake.snowpark.context as context import snowflake.snowpark._internal.analyzer.analyzer_utils as analyzer_utils from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type from snowflake.snowpark._internal.utils import ( @@ -518,7 +519,12 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str: if isinstance(data_type, ArrayType): if data_type.structured: assert data_type.element_type is not None - element = schema_expression(data_type.element_type, data_type.contains_null) + if context._enable_fix_2360274: + element = "NULL" + else: + element = schema_expression( + data_type.element_type, data_type.contains_null + ) return f"to_array({element}) :: {convert_sp_to_sf_type(data_type)}" return "to_array(0)" if isinstance(data_type, MapType): @@ -526,10 +532,13 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str: assert data_type.key_type is not None and data_type.value_type is not None # Key values can never be null key = schema_expression(data_type.key_type, False) - # Value nullability is variable.
Defaults to True - value = schema_expression( - data_type.value_type, data_type.value_contains_null - ) + if context._enable_fix_2360274: + value = "NULL" + else: + # Value nullability is variable. Defaults to True + value = schema_expression( + data_type.value_type, data_type.value_contains_null + ) return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}" return "to_object(parse_json('0'))" if isinstance(data_type, StructType): @@ -539,7 +548,9 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str: # Even if nulls are allowed the cast will fail due to schema mismatch when passed a null field. schema_strings += [ f"'{field.name}'", - schema_expression(field.datatype, is_nullable=False), + "NULL" + if context._enable_fix_2360274 + else schema_expression(field.datatype, is_nullable=False), ] return f"object_construct_keep_null({', '.join(schema_strings)}) :: {convert_sp_to_sf_type(data_type)}" return "to_object(parse_json('{}'))" diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index d3ee14c97e..86e92b6aa4 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -39,6 +39,10 @@ # This is an internal-only global flag, used to determine whether to enable query line tracking for tracing sql compilation errors. _enable_trace_sql_errors_to_dataframe = False +# SNOW-2362050: Enable this fix by default. +# Global flag for fix 2360274. 
When enabled schema queries will use NULL as a place holder for any values inside structured objects +_enable_fix_2360274 = False + def configure_development_features( *, diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index ccb036f85f..a16fc888a0 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -9,6 +9,7 @@ import logging import pytest +from unittest import mock import snowflake.snowpark.context as context from snowflake.connector.options import installed_pandas @@ -1763,3 +1764,110 @@ def test_lob_collect_max_size(session, server_side_max_string, type_string, data ) assert df.schema == StructType([StructField("DATA", datatype, nullable=False)]) assert len(df.collect()[0][0]) >= server_side_max_string - 16 + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="Structured types are not supported in Local Testing", +) +@pytest.mark.parametrize("fix_enabled", [True, False]) +def test_snow_2360274_repro( + structured_type_session, structured_type_support, fix_enabled +): + if not structured_type_support: + pytest.skip("Test requires structured type support.") + + agg_table_name = f"snowpark_2360274_repro_agg_{uuid.uuid4().hex[:5]}".upper() + + nested_field_name = ( + "value" if context._should_use_structured_type_semantics() else '"value"' + ) + expected_schema = StructType( + [ + StructField("ID", LongType(), nullable=False), + StructField( + "VALS_ARR", + ArrayType( + StructType( + [StructField(nested_field_name, StringType(10), nullable=True)] + ) + ), + nullable=True, + ), + StructField( + "VALS_MAP", + MapType(StringType(10), StringType(10)), + nullable=True, + ), + StructField( + "VALS_OBJ", + StructType( + [StructField(nested_field_name, StringType(10), nullable=True)] + ), + nullable=True, + ), + StructField("TAG", StringType(2), nullable=False), + ] + ) + + def inner(): + structured_type_session.sql( + f""" + CREATE + 
OR REPLACE TABLE {agg_table_name} ( + ID INT NOT NULL, + VALS_ARR ARRAY(OBJECT({nested_field_name} STRING(10))) NOT NULL, + VALS_MAP MAP(STRING(10), STRING(10)) NOT NULL, + VALS_OBJ OBJECT({nested_field_name} STRING(10)) NOT NULL + ) AS WITH SRC(ID, VALUE) AS ( + SELECT + $1, + $2 + FROM + VALUES + (1, 'A'), + (1, 'B'), + (2, 'A') + ) + SELECT + ID, + CAST( + ARRAY_AGG(OBJECT_CONSTRUCT('value', VALUE)) AS ARRAY(OBJECT({nested_field_name} STRING)) + ) AS VALS_ARR, + CAST( + OBJECT_CONSTRUCT('value', VALUE) AS MAP(STRING, STRING) + ) AS VALS_MAP, + CAST( + OBJECT_CONSTRUCT('value', VALUE) AS OBJECT({nested_field_name} STRING) + ) AS VALS_OBJ, + FROM + SRC + GROUP BY + ID, VALS_MAP, VALS_OBJ""" + ).collect() + + agged = structured_type_session.table(agg_table_name) + + reference = structured_type_session.sql( + """ + SELECT $1 AS ID, $2 AS TAG FROM VALUES (1, 'AB'), (2, 'B') + """ + ) + + joined = agged.join(reference, on=agged.id == reference.id, how="inner").select( + agged.id.alias("ID"), "VALS_ARR", "VALS_MAP", "VALS_OBJ", "TAG" + ) + Utils.is_schema_same(joined.schema, expected_schema, case_sensitive=False) + + try: + with mock.patch.object(context, "_enable_fix_2360274", fix_enabled): + if fix_enabled: + inner() + else: + with pytest.raises( + SnowparkSQLException, + match="Unsupported data type 'STRUCTURED_OBJECT'", + ): + inner() + finally: + Utils.drop_table(structured_type_session, agg_table_name) diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index 082e6cf45d..3ea7b0b458 100644 --- a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -2291,6 +2291,7 @@ def artifact_repo_test(_): @pytest.mark.skipif( sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+" ) +@pytest.mark.skip("SNOW-2362946: Skip until root cause is found.") def test_sproc_artifact_repository_from_file(session, tmpdir): source = dedent( """