Skip to content

Commit 443ea63

Browse files
committed
Harden multi-row insert casting with deterministic, type-driven behavior.
Replace the runtime-value adaptive logic with deterministic casting that applies only to SQLAlchemy multi-row VALUES bind markers, and update the compiler tests to verify that multi-row inserts render casts and single-row inserts do not.
1 parent ffe0ddd commit 443ea63

2 files changed

Lines changed: 41 additions & 68 deletions

File tree

src/databricks/sqlalchemy/_ddl.py

Lines changed: 20 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import re
2-
from numbers import Number
32
from sqlalchemy.sql import compiler, sqltypes
43
import logging
54

@@ -166,21 +165,6 @@ def bindparam_string(self, name, **kw):
166165
return self._BIND_TEMPLATE % {"name": name.replace("`", "``")}
167166
return super().bindparam_string(name, **kw)
168167

169-
@staticmethod
170-
def _value_family(value):
171-
"""Return a coarse runtime family for adaptive multi-row cast decisions."""
172-
if value is None:
173-
return "null"
174-
if isinstance(value, bool):
175-
return "bool"
176-
if isinstance(value, str):
177-
return "string"
178-
if isinstance(value, (bytes, bytearray, memoryview)):
179-
return "binary"
180-
if isinstance(value, Number):
181-
return "number"
182-
return "other"
183-
184168
@staticmethod
185169
def _split_multivalue_bind_name(bind_name):
186170
"""Split SQLAlchemy's ``<col>_m<idx>`` bind names into (column, idx)."""
@@ -189,55 +173,37 @@ def _split_multivalue_bind_name(bind_name):
189173
return None
190174
return match.group("col"), int(match.group("idx"))
191175

192-
def _build_adaptive_cast_plan(self):
193-
"""Return {bind_name: cast_sql_type} for risky multi-row value groups.
176+
def _build_multi_value_cast_plan(self, insert_stmt):
177+
"""Return {bind_name: cast_sql_type} for multi-row VALUES insert binds.
194178
195-
We only target SQLAlchemy-generated multi-row binds (``*_mN``). For
196-
each logical column we inspect row values available at compile time and
197-
cast only when families are heterogeneous in a way that commonly causes
198-
Spark inline-table incompatibility (e.g., number + string).
179+
This is a deterministic fix for Spark inline-table type reconciliation:
180+
for SQLAlchemy-generated multi-row INSERT binds (``*_mN``), always cast
181+
the marker to the bind's dialect SQL type so each column position in the
182+
VALUES table has an explicit server-side type.
199183
"""
200-
column_bind_names = {}
201-
for bind_name, bind_param in self.binds.items():
202-
split = self._split_multivalue_bind_name(bind_name)
203-
if split is None:
204-
continue
205-
column_name, _ = split
206-
column_bind_names.setdefault(column_name, []).append((bind_name, bind_param))
184+
if not getattr(insert_stmt, "_multi_values", None):
185+
return {}
207186

208187
cast_plan = {}
209-
for bind_entries in column_bind_names.values():
210-
families = set()
211-
for _, bind_param in bind_entries:
212-
value = getattr(bind_param, "value", None)
213-
family = self._value_family(value)
214-
if family != "null":
215-
families.add(family)
216-
217-
if len(families) <= 1:
188+
for bind_name, bind_param in self.binds.items():
189+
if self._split_multivalue_bind_name(bind_name) is None:
218190
continue
219191

220-
# Numeric + numeric is safe for Spark inline tables and does not
221-
# need explicit casting.
222-
if families == {"number"}:
192+
type_engine = getattr(bind_param, "type", None)
193+
if type_engine is None or isinstance(type_engine, sqltypes.NullType):
223194
continue
224195

225-
for bind_name, bind_param in bind_entries:
226-
type_engine = getattr(bind_param, "type", None)
227-
if type_engine is None or isinstance(type_engine, sqltypes.NullType):
228-
continue
229-
230-
dialect_type = type_engine._unwrapped_dialect_impl(self.dialect)
231-
target_type = self.dialect.type_compiler_instance.process(
232-
dialect_type, identifier_preparer=self.preparer
233-
)
234-
cast_plan[bind_name] = target_type
196+
dialect_type = type_engine._unwrapped_dialect_impl(self.dialect)
197+
target_type = self.dialect.type_compiler_instance.process(
198+
dialect_type, identifier_preparer=self.preparer
199+
)
200+
cast_plan[bind_name] = target_type
235201

236202
return cast_plan
237203

238-
def _apply_adaptive_multi_value_casts(self, sql_text):
204+
def _apply_multi_value_casts(self, sql_text, insert_stmt):
239205
"""Wrap selected ``:`name``` markers with ``CAST(... AS <type>)``."""
240-
cast_plan = self._build_adaptive_cast_plan()
206+
cast_plan = self._build_multi_value_cast_plan(insert_stmt)
241207
if not cast_plan:
242208
return sql_text
243209

@@ -249,7 +215,7 @@ def _apply_adaptive_multi_value_casts(self, sql_text):
249215

250216
def visit_insert(self, insert_stmt, **kw):
251217
sql_text = super().visit_insert(insert_stmt, **kw)
252-
return self._apply_adaptive_multi_value_casts(sql_text)
218+
return self._apply_multi_value_casts(sql_text, insert_stmt)
253219

254220
def limit_clause(self, select, **kw):
255221
"""Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1,

tests/test_local/test_ddl.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,8 @@ def test_in_clause_expansion_renders_backticked_markers(self):
427427
assert ":`col-name_1_3`" in expanded.statement
428428

429429

430-
class TestAdaptiveMultiRowInsertCasts(DDLTestBase):
431-
def test_mixed_runtime_families_in_multi_values_are_cast(self):
430+
class TestMultiRowInsertCasts(DDLTestBase):
431+
def test_multi_values_casts_mixed_type_column(self):
432432
metadata = MetaData()
433433
table = Table("t", metadata, Column("name", String()), Column("value", String()))
434434
stmt = insert(table).values(
@@ -444,31 +444,38 @@ def test_mixed_runtime_families_in_multi_values_are_cast(self):
444444
assert "CAST(:`value_m0` AS STRING)" in sql
445445
assert "CAST(:`value_m1` AS STRING)" in sql
446446
assert "CAST(:`value_m2` AS STRING)" in sql
447-
# Name values are already all string/null and should remain untouched.
448-
assert "CAST(:`name_m0` AS STRING)" not in sql
449-
assert "CAST(:`name_m1` AS STRING)" not in sql
450-
assert "CAST(:`name_m2` AS STRING)" not in sql
447+
assert "CAST(:`name_m0` AS STRING)" in sql
448+
assert "CAST(:`name_m1` AS STRING)" in sql
449+
assert "CAST(:`name_m2` AS STRING)" in sql
451450

452-
def test_homogeneous_multi_values_are_not_cast(self):
451+
def test_homogeneous_multi_values_are_cast(self):
453452
metadata = MetaData()
454453
table = Table("t", metadata, Column("value", String()))
455454
stmt = insert(table).values(
456455
[{"value": "A"}, {"value": "B"}, {"value": "C"}]
457456
)
458457

459458
sql = str(stmt.compile(bind=self.engine))
460-
assert "CAST(:`value_m0` AS STRING)" not in sql
461-
assert "CAST(:`value_m1` AS STRING)" not in sql
462-
assert "CAST(:`value_m2` AS STRING)" not in sql
459+
assert "CAST(:`value_m0` AS STRING)" in sql
460+
assert "CAST(:`value_m1` AS STRING)" in sql
461+
assert "CAST(:`value_m2` AS STRING)" in sql
463462

464-
def test_numeric_family_multi_values_are_not_cast(self):
463+
def test_numeric_family_multi_values_are_cast(self):
465464
metadata = MetaData()
466465
table = Table("t", metadata, Column("score", Numeric()))
467466
stmt = insert(table).values(
468467
[{"score": 1}, {"score": 2.5}, {"score": 3}]
469468
)
470469

471470
sql = str(stmt.compile(bind=self.engine))
472-
assert "CAST(:`score_m0` AS DECIMAL)" not in sql
473-
assert "CAST(:`score_m1` AS DECIMAL)" not in sql
474-
assert "CAST(:`score_m2` AS DECIMAL)" not in sql
471+
assert "CAST(:`score_m0` AS DECIMAL)" in sql
472+
assert "CAST(:`score_m1` AS DECIMAL)" in sql
473+
assert "CAST(:`score_m2` AS DECIMAL)" in sql
474+
475+
def test_single_row_insert_does_not_render_casts(self):
476+
metadata = MetaData()
477+
table = Table("t", metadata, Column("value", String()))
478+
stmt = insert(table).values({"value": "A"})
479+
480+
sql = str(stmt.compile(bind=self.engine))
481+
assert "CAST(:`value` AS STRING)" not in sql

0 commit comments

Comments
 (0)