Skip to content

Commit b1d7ab2

Browse files
committed
Fix mixed-type pandas multi-row inserts with targeted casts.
Cast only mixed scalar bind groups in SQLAlchemy-generated multi-row INSERT VALUES statements so Spark resolves inline table columns consistently. Keep homogeneous, complex, and custom bind-expression types unchanged, and add regression coverage for PECOBLR-2746 plus advertised SQLAlchemy scalar and complex types.

Signed-off-by: Madhavendra Rathore <madhavendra.rathore@databricks.com>
1 parent 3ab18ed commit b1d7ab2

4 files changed

Lines changed: 582 additions & 18 deletions

File tree

src/databricks/sqlalchemy/_ddl.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import re
2+
from datetime import date, datetime, time
3+
from numbers import Number
4+
from uuid import UUID
25
from sqlalchemy.sql import compiler, sqltypes
36
import logging
47

@@ -165,6 +168,124 @@ def bindparam_string(self, name, **kw):
165168
return self._BIND_TEMPLATE % {"name": name.replace("`", "``")}
166169
return super().bindparam_string(name, **kw)
167170

171+
@staticmethod
def _split_multivalue_bind_name(bind_name):
    """Split SQLAlchemy's ``<col>_m<idx>`` bind names into ``(column, idx)``.

    Args:
        bind_name: Name of a bind parameter, e.g. ``"price_m3"``.

    Returns:
        A ``(column_name, row_index)`` tuple such as ``("price", 3)``, or
        ``None`` when the name does not end in ``_m<digits>``.
    """
    # fullmatch (rather than ``^...$``) so a trailing newline is not silently
    # accepted as part of a valid bind name. ``.+`` is greedy, so only the
    # last ``_m<digits>`` suffix is treated as the row index.
    match = re.fullmatch(r"(?P<col>.+)_m(?P<idx>\d+)", bind_name)
    if match is None:
        return None
    return match.group("col"), int(match.group("idx"))
178+
179+
@staticmethod
def _value_family(value):
    """Return the scalar family of ``value``.

    Args:
        value: A Python value bound to an INSERT parameter.

    Returns:
        One of ``"null"``, ``"bool"``, ``"number"``, ``"string"``,
        ``"binary"``, ``"temporal"`` or ``"uuid"``; ``None`` means the value
        is non-scalar/unsupported.
    """
    if value is None:
        return "null"

    # Order matters: ``bool`` is a subclass of ``int`` (and therefore a
    # ``Number``), so it has to be tested before the generic numeric check.
    family_by_types = (
        (bool, "bool"),
        (Number, "number"),
        (str, "string"),
        ((bytes, bytearray, memoryview), "binary"),
        ((date, time, datetime), "temporal"),
        (UUID, "uuid"),
    )
    for accepted_types, family in family_by_types:
        if isinstance(value, accepted_types):
            return family
    return None
197+
198+
@staticmethod
def _has_custom_bind_expression(type_engine):
    """True if the type (or its impl) customizes bind-expression rendering.

    A type counts as "custom" when the ``bind_expression`` attribute found on
    its class is not the base ``sqltypes.TypeEngine.bind_expression``. The
    same test is applied to ``type_engine.impl`` when that attribute exists
    and is not ``None``.

    Args:
        type_engine: The SQLAlchemy type object attached to a bind parameter.

    Returns:
        ``True`` if either class resolves ``bind_expression`` to something
        other than the ``TypeEngine`` default, else ``False``.
    """
    # NOTE(review): SQLAlchemy's TypeDecorator appears to define its own
    # ``bind_expression`` wrapper, which would make every TypeDecorator
    # subclass count as "custom" here - confirm that this is intended.
    default_bind_expression = sqltypes.TypeEngine.bind_expression

    candidates = [type_engine]
    impl = getattr(type_engine, "impl", None)
    if impl is not None:
        candidates.append(impl)

    return any(
        getattr(type(candidate), "bind_expression", None)
        is not default_bind_expression
        for candidate in candidates
    )
217+
218+
def _build_multi_value_cast_plan(self, insert_stmt):
    """Return ``{bind_name: cast_sql_type}`` for multi-row VALUES insert binds.

    Cast only *mixed scalar* multi-row bind groups. This avoids breaking
    complex/custom bind types (e.g. ARRAY/MAP/VARIANT) while still fixing
    Spark inline-table incompatibility for object columns that mix
    primitive families (e.g. INT + STRING).

    Args:
        insert_stmt: The insert construct being compiled; only its
            ``_multi_values`` attribute is inspected.

    Returns:
        A dict mapping bind names to the compiled SQL type they should be
        cast to. Empty when the statement has no ``_multi_values`` or no
        bind group qualifies.
    """
    if not getattr(insert_stmt, "_multi_values", None):
        return {}

    # Bucket the ``<col>_m<idx>`` binds by column so each column's values
    # across all rows can be inspected together.
    grouped_binds = {}
    for bind_name, bind_param in self.binds.items():
        split = self._split_multivalue_bind_name(bind_name)
        if split is None:
            continue
        column_name, _ = split
        grouped_binds.setdefault(column_name, []).append((bind_name, bind_param))

    cast_plan = {}
    for bind_entries in grouped_binds.values():
        families = set()
        castable = True

        for _, bind_param in bind_entries:
            value_family = self._value_family(getattr(bind_param, "value", None))
            if value_family is None:
                # A non-scalar value disqualifies the whole column.
                castable = False
                break
            if value_family != "null":
                # NULLs are compatible with every family; ignore them.
                families.add(value_family)

            type_engine = getattr(bind_param, "type", None)
            if type_engine is not None and self._has_custom_bind_expression(
                type_engine
            ):
                # The type renders its own bind expression; leave the group
                # alone. Nothing later in the group can change that verdict.
                castable = False
                break

        # Homogeneous (or all-NULL) columns need no cast.
        if not castable or len(families) <= 1:
            continue

        for bind_name, bind_param in bind_entries:
            type_engine = getattr(bind_param, "type", None)
            if type_engine is None or isinstance(type_engine, sqltypes.NullType):
                # No usable target type to cast to.
                continue

            dialect_type = type_engine._unwrapped_dialect_impl(self.dialect)
            cast_plan[bind_name] = self.dialect.type_compiler_instance.process(
                dialect_type, identifier_preparer=self.preparer
            )

    return cast_plan
272+
273+
def _apply_multi_value_casts(self, sql_text, insert_stmt):
    """Wrap selected ``:`name``` markers with ``CAST(... AS <type>)``.

    Args:
        sql_text: The INSERT statement text already rendered by the compiler.
        insert_stmt: The insert construct, forwarded to
            ``_build_multi_value_cast_plan`` to decide which binds to cast.

    Returns:
        ``sql_text`` with every planned bind marker wrapped in a ``CAST``;
        the input string unchanged when the plan is empty.
    """
    cast_plan = self._build_multi_value_cast_plan(insert_stmt)
    if not cast_plan:
        return sql_text

    # NOTE(review): this is a plain textual replace over the rendered SQL, so
    # it presumably relies on each marker appearing only as a bind
    # placeholder - confirm for statements that embed another INSERT.
    for bind_name, target_type in cast_plan.items():
        # Render the marker the same way ``bindparam_string`` escapes names.
        marker = self._BIND_TEMPLATE % {"name": bind_name.replace("`", "``")}
        sql_text = sql_text.replace(marker, f"CAST({marker} AS {target_type})")
    return sql_text
284+
285+
def visit_insert(self, insert_stmt, **kw):
    """Compile an INSERT, then cast mixed-type multi-row VALUES binds.

    The statement is rendered by the parent compiler first; the resulting
    text is then passed through ``_apply_multi_value_casts``, which returns
    it unchanged unless the cast plan for ``insert_stmt`` is non-empty.
    """
    sql_text = super().visit_insert(insert_stmt, **kw)
    return self._apply_multi_value_casts(sql_text, insert_stmt)
288+
168289
def limit_clause(self, select, **kw):
169290
"""Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1,
170291
since Databricks SQL doesn't support the latter.

tests/test_local/e2e/test_complex_types.py

Lines changed: 116 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
DateTime,
1212
)
1313
from collections.abc import Sequence
14-
from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap, DatabricksVariant
14+
from databricks.sqlalchemy import (
15+
TIMESTAMP,
16+
TINYINT,
17+
DatabricksArray,
18+
DatabricksMap,
19+
DatabricksVariant,
20+
)
1521
from sqlalchemy.orm import DeclarativeBase, Session
1622
from sqlalchemy import select
1723
from datetime import date, datetime, time, timedelta, timezone
@@ -20,6 +26,7 @@
2026
import decimal
2127
import json
2228

29+
2330
class TestComplexTypes(TestSetup):
2431
def _parse_to_common_type(self, value):
2532
"""
@@ -175,8 +182,8 @@ class VariantTable(Base):
175182
"number": 123,
176183
"boolean": True,
177184
"array": [1, 2, 3],
178-
"object": {"nested": "value"}
179-
}
185+
"object": {"nested": "value"},
186+
},
180187
}
181188

182189
return VariantTable, sample_data
@@ -239,6 +246,44 @@ def test_map_table_creation_pandas(self):
239246
df_result = pd.read_sql(stmt, engine)
240247
assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
241248

249+
def test_array_table_creation_pandas_multi(self):
    """Round-trip two ARRAY rows through ``DataFrame.to_sql(method="multi")``."""
    table, sample_data = self.sample_array_table()
    # Second row differs only in ``int_col`` so ORDER BY gives a stable order.
    expected_second = sample_data | {"int_col": 2}

    with self.table_context(table) as engine:
        df = pd.DataFrame([sample_data, expected_second])
        df.to_sql(
            table.__tablename__,
            engine,
            if_exists="append",
            index=False,
            method="multi",
        )

        stmt = select(table).order_by(table.int_col)
        df_result = pd.read_sql(stmt, engine)
        assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
        assert self._recursive_compare(df_result.iloc[1].to_dict(), expected_second)
267+
268+
def test_map_table_creation_pandas_multi(self):
    """Round-trip two MAP rows through ``DataFrame.to_sql(method="multi")``."""
    table, sample_data = self.sample_map_table()
    # Second row differs only in ``int_col`` so ORDER BY gives a stable order.
    expected_second = sample_data | {"int_col": 2}

    with self.table_context(table) as engine:
        df = pd.DataFrame([sample_data, expected_second])
        df.to_sql(
            table.__tablename__,
            engine,
            if_exists="append",
            index=False,
            method="multi",
        )

        stmt = select(table).order_by(table.int_col)
        df_result = pd.read_sql(stmt, engine)
        assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
        assert self._recursive_compare(df_result.iloc[1].to_dict(), expected_second)
286+
242287
def test_insert_variant_table_sqlalchemy(self):
243288
table, sample_data = self.sample_variant_table()
244289

@@ -253,7 +298,12 @@ def test_insert_variant_table_sqlalchemy(self):
253298
result = session.scalar(stmt)
254299
compare = {key: getattr(result, key) for key in sample_data.keys()}
255300
# Parse JSON values back to original format for comparison
256-
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
301+
for key in [
302+
"variant_simple_col",
303+
"variant_nested_col",
304+
"variant_array_col",
305+
"variant_mixed_col",
306+
]:
257307
if compare[key] is not None:
258308
compare[key] = json.loads(compare[key])
259309

@@ -263,26 +313,76 @@ def test_variant_table_creation_pandas(self):
263313
table, sample_data = self.sample_variant_table()
264314

265315
with self.table_context(table) as engine:
266-
316+
267317
df = pd.DataFrame([sample_data])
268318
dtype_mapping = {
269319
"variant_simple_col": DatabricksVariant,
270320
"variant_nested_col": DatabricksVariant,
271321
"variant_array_col": DatabricksVariant,
272-
"variant_mixed_col": DatabricksVariant
322+
"variant_mixed_col": DatabricksVariant,
273323
}
274-
df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping)
275-
324+
df.to_sql(
325+
table.__tablename__,
326+
engine,
327+
if_exists="append",
328+
index=False,
329+
dtype=dtype_mapping,
330+
)
331+
276332
stmt = select(table)
277333
df_result = pd.read_sql(stmt, engine)
278334
result_dict = df_result.iloc[0].to_dict()
279335
# Parse JSON values back to original format for comparison
280-
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
336+
for key in [
337+
"variant_simple_col",
338+
"variant_nested_col",
339+
"variant_array_col",
340+
"variant_mixed_col",
341+
]:
281342
if result_dict[key] is not None:
282343
result_dict[key] = json.loads(result_dict[key])
283344

284345
assert result_dict == sample_data
285346

347+
def test_variant_table_creation_pandas_multi(self):
    """Round-trip two VARIANT rows through ``DataFrame.to_sql(method="multi")``."""
    table, sample_data = self.sample_variant_table()
    variant_columns = [
        "variant_simple_col",
        "variant_nested_col",
        "variant_array_col",
        "variant_mixed_col",
    ]

    with self.table_context(table) as engine:
        # Second row differs only in ``int_col`` so ORDER BY gives a stable order.
        second = sample_data | {"int_col": 2}
        df = pd.DataFrame([sample_data, second])
        dtype_mapping = {column: DatabricksVariant for column in variant_columns}
        df.to_sql(
            table.__tablename__,
            engine,
            if_exists="append",
            index=False,
            dtype=dtype_mapping,
            method="multi",
        )

        stmt = select(table).order_by(table.int_col)
        df_result = pd.read_sql(stmt, engine)
        first_row = df_result.iloc[0].to_dict()
        second_row = df_result.iloc[1].to_dict()

        # Parse JSON values back to original format for comparison
        for row in (first_row, second_row):
            for key in variant_columns:
                if row[key] is not None:
                    row[key] = json.loads(row[key])

        assert first_row == sample_data
        assert second_row == second
385+
286386
def test_variant_literal_processor(self):
287387
table, sample_data = self.sample_variant_table()
288388

@@ -291,8 +391,7 @@ def test_variant_literal_processor(self):
291391

292392
try:
293393
compiled = stmt.compile(
294-
dialect=engine.dialect,
295-
compile_kwargs={"literal_binds": True}
394+
dialect=engine.dialect, compile_kwargs={"literal_binds": True}
296395
)
297396
sql_str = str(compiled)
298397

@@ -311,7 +410,12 @@ def test_variant_literal_processor(self):
311410
compare = {key: getattr(result, key) for key in sample_data.keys()}
312411

313412
# Parse JSON values back to original Python objects
314-
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
413+
for key in [
414+
"variant_simple_col",
415+
"variant_nested_col",
416+
"variant_array_col",
417+
"variant_mixed_col",
418+
]:
315419
if compare[key] is not None:
316420
compare[key] = json.loads(compare[key])
317421

0 commit comments

Comments
 (0)