Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@
- Fixed SQL Server query input failure due to incorrect select query generation.
- Fixed UDTF ingestion not preserving column nullability in the output schema.
- Fixed an issue that caused the program to hang during multithreaded Parquet based ingestion when a data fetching error occurred.
- Fixed a bug in schema parsing when custom schema strings used upper-cased data type names (NUMERIC, NUMBER, DECIMAL, VARCHAR, STRING, TEXT).
- Fixed a bug in `Session.create_dataframe` where schema string parsing failed when using upper-cased data type names (e.g., NUMERIC, NUMBER, DECIMAL, VARCHAR, STRING, TEXT).

#### Improvements

Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,11 +1089,11 @@ def get_data_type_string_object_mappings(


DECIMAL_RE = re.compile(
r"^\s*(numeric|number|decimal)\s*\(\s*(\s*)(\d*)\s*,\s*(\d*)\s*\)\s*$"
r"(?i)^\s*(numeric|number|decimal)\s*\(\s*(\s*)(\d*)\s*,\s*(\d*)\s*\)\s*$"
)
# support type string format like " decimal ( 2 , 1 ) "

STRING_RE = re.compile(r"^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$")
STRING_RE = re.compile(r"(?i)^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$")
# support type string format like " string ( 23 ) "

ARRAY_RE = re.compile(r"(?i)^\s*array\s*<")
Expand Down
20 changes: 15 additions & 5 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6844,23 +6844,28 @@ def test_create_dataframe_implicit_struct_not_null_single(session):
assert result == expected_rows


def test_create_dataframe_implicit_struct_not_null_multiple(session):
@pytest.mark.parametrize("upper_case", [False, True])
def test_create_dataframe_implicit_struct_not_null_multiple(session, upper_case):
"""
Test a schema with multiple fields, one of which is NOT NULL.
"""
data = [
[10, "foo"],
[20, "bar"],
]
schema_str = "col1: int not null, col2: string"
# Only uppercase the types, not field names
if upper_case:
schema_str = "col1: INT NOT NULL, col2: STRING(100)"
else:
schema_str = "col1: int not null, col2: string(100)"

df = session.create_dataframe(data, schema=schema_str)
# Verify schema
assert len(df.schema.fields) == 2

expected_fields = [
StructField("COL1", LongType(), nullable=False),
StructField("COL2", StringType(2**24), nullable=True),
StructField("COL2", StringType(100), nullable=True),
]
assert df.schema.fields == expected_fields

Expand All @@ -6873,15 +6878,20 @@ def test_create_dataframe_implicit_struct_not_null_multiple(session):
assert result == expected_rows


def test_create_dataframe_implicit_struct_not_null_nested(session):
@pytest.mark.parametrize("upper_case", [False, True])
def test_create_dataframe_implicit_struct_not_null_nested(session, upper_case):
"""
Test a schema with nested array and a NOT NULL decimal field.
"""
data = [
[["1", "2"], Decimal("3.14")],
[["5", "6"], Decimal("2.72")],
]
schema_str = "arr: array<string>, val: decimal(10,2) NOT NULL"
# Only uppercase the types, not field names
if upper_case:
schema_str = "arr: ARRAY<STRING>, val: DECIMAL(10,2) NOT NULL"
else:
schema_str = "arr: array<string>, val: decimal(10,2) NOT NULL"

df = session.create_dataframe(data, schema=schema_str)
# Verify schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -961,13 +961,13 @@ def unknown_dbms_create_connection():
)


SQLITE3_DB_CUSTOM_SCHEMA_STRING = "id INTEGER, int_col INTEGER, real_col FLOAT, text_col STRING, blob_col BINARY, null_col STRING, ts_col TIMESTAMP, date_col DATE, time_col TIME, short_col SHORT, long_col LONG, double_col DOUBLE, decimal_col DECIMAL, map_col MAP, array_col ARRAY, var_col VARIANT"
SQLITE3_DB_CUSTOM_SCHEMA_STRING = "id INTEGER, int_col INTEGER, real_col FLOAT, text_col STRING, blob_col BINARY, null_col TEXT(200), ts_col TIMESTAMP, date_col DATE, time_col TIME, short_col SHORT, long_col LONG, double_col DOUBLE, decimal_col DECIMAL(25,8), map_col MAP, array_col ARRAY, var_col VARIANT"
SQLITE3_DB_CUSTOM_SCHEMA_STRUCT_TYPE = StructType(
[
StructField("id", IntegerType()),
StructField("int_col", IntegerType()),
StructField("real_col", FloatType()),
StructField("text_col", StringType()),
StructField("text_col", StringType(200)),
StructField("blob_col", BinaryType()),
StructField("null_col", NullType()),
StructField("ts_col", TimestampType()),
Expand All @@ -976,7 +976,7 @@ def unknown_dbms_create_connection():
StructField("short_col", ShortType()),
StructField("long_col", LongType()),
StructField("double_col", DoubleType()),
StructField("decimal_col", DecimalType()),
StructField("decimal_col", DecimalType(25, 8)),
StructField("map_col", MapType()),
StructField("array_col", ArrayType()),
StructField("var_col", VariantType()),
Expand Down
Loading
Loading