Skip to content

Commit ffe0ddd

Browse files
committed
Fix mixed-type pandas multi-row inserts with adaptive bind casts.
Target cast rendering only for heterogeneous INSERT ... VALUES multi-row bind groups to prevent Spark inline-table type incompatibility, and add compile-time plus e2e tests to validate behavior in main.default.
1 parent 3ab18ed commit ffe0ddd

3 files changed

Lines changed: 216 additions & 0 deletions

File tree

src/databricks/sqlalchemy/_ddl.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from numbers import Number
23
from sqlalchemy.sql import compiler, sqltypes
34
import logging
45

@@ -165,6 +166,91 @@ def bindparam_string(self, name, **kw):
165166
return self._BIND_TEMPLATE % {"name": name.replace("`", "``")}
166167
return super().bindparam_string(name, **kw)
167168

169+
@staticmethod
def _value_family(value):
    """Classify ``value`` into a coarse runtime family name.

    Used to decide whether the values bound to one column across a
    multi-row INSERT are heterogeneous.

    Returns one of ``"null"``, ``"bool"``, ``"string"``, ``"binary"``,
    ``"number"`` or ``"other"``.
    """
    if value is None:
        return "null"

    # Order matters: ``bool`` is a subclass of ``int`` (and therefore a
    # ``Number``), so it has to be tested before the numeric family.
    ordered_families = (
        ("bool", bool),
        ("string", str),
        ("binary", (bytes, bytearray, memoryview)),
        ("number", Number),
    )
    for family, python_types in ordered_families:
        if isinstance(value, python_types):
            return family
    return "other"
183+
184+
@staticmethod
def _split_multivalue_bind_name(bind_name):
    """Split SQLAlchemy's ``<col>_m<idx>`` bind names into ``(column, idx)``.

    The column part is matched greedily, so only the last ``_m<digits>``
    suffix is treated as the row index (``"a_m1_m2"`` -> ``("a_m1", 2)``).

    Returns ``None`` when ``bind_name`` does not have that shape.
    """
    # ``fullmatch`` instead of ``match`` with a ``$`` anchor: ``$`` also
    # matches just before a trailing newline, which would let a name such as
    # "col_m1\n" be mistaken for a multi-row bind.
    match = re.fullmatch(r"(?P<col>.+)_m(?P<idx>\d+)", bind_name)
    if match is None:
        return None
    return match.group("col"), int(match.group("idx"))
191+
192+
def _build_adaptive_cast_plan(self):
    """Return ``{bind_name: cast_sql_type}`` for risky multi-row value groups.

    Only SQLAlchemy-generated multi-row binds (``*_mN``) are considered. For
    each logical column the row values available at compile time are
    classified with ``_value_family``; the column's binds are cast only when
    more than one non-null family is present (e.g. number + string), which
    is the mix that commonly causes Spark inline-table incompatibility.

    Every numeric value maps to the single ``"number"`` family, so a column
    holding only numbers (int, float, ...) is never cast. Binds whose type is
    missing or ``NullType`` are left out of the plan because there is no
    target type to render.
    """
    # Group binds by the column part of their ``<col>_m<idx>`` name.
    column_bind_names = {}
    for bind_name, bind_param in self.binds.items():
        split = self._split_multivalue_bind_name(bind_name)
        if split is None:
            continue
        column_name, _ = split
        column_bind_names.setdefault(column_name, []).append((bind_name, bind_param))

    cast_plan = {}
    for bind_entries in column_bind_names.values():
        families = {
            self._value_family(getattr(bind_param, "value", None))
            for _, bind_param in bind_entries
        }
        # NULLs are compatible with every family and never force a cast.
        families.discard("null")

        # Zero or one family means the group is homogeneous.
        if len(families) <= 1:
            continue

        for bind_name, bind_param in bind_entries:
            type_engine = getattr(bind_param, "type", None)
            if type_engine is None or isinstance(type_engine, sqltypes.NullType):
                continue

            # Render the column's declared type with the dialect's type
            # compiler and use it as the CAST target.
            dialect_type = type_engine._unwrapped_dialect_impl(self.dialect)
            target_type = self.dialect.type_compiler_instance.process(
                dialect_type, identifier_preparer=self.preparer
            )
            cast_plan[bind_name] = target_type

    return cast_plan
237+
238+
def _apply_adaptive_multi_value_casts(self, sql_text):
    """Wrap selected ``:`name``` markers with ``CAST(... AS <type>)``.

    The markers to wrap and their target types come from
    ``_build_adaptive_cast_plan``. ``sql_text`` is returned unchanged when
    the plan is empty.

    The rewrite is a single pass and is idempotent: a marker that is already
    directly preceded by ``CAST(`` is left alone, so applying this method to
    its own output does not produce ``CAST(CAST(...))``.
    """
    cast_plan = self._build_adaptive_cast_plan()
    if not cast_plan:
        return sql_text

    # Rendered marker -> target SQL type.
    casts_by_marker = {
        self._BIND_TEMPLATE % {"name": bind_name.replace("`", "``")}: target_type
        for bind_name, target_type in cast_plan.items()
    }

    # Longest markers first so no alternative can shadow a longer one.
    alternatives = "|".join(
        re.escape(marker)
        for marker in sorted(casts_by_marker, key=len, reverse=True)
    )
    marker_pattern = re.compile(r"(?<!CAST\()(?:" + alternatives + ")")

    def wrap_marker(match):
        marker = match.group(0)
        return f"CAST({marker} AS {casts_by_marker[marker]})"

    # A callable replacement keeps backslashes in type names literal.
    return marker_pattern.sub(wrap_marker, sql_text)
249+
250+
def visit_insert(self, insert_stmt, **kw):
    """Compile the INSERT with the parent compiler, then post-process the
    rendered SQL through ``_apply_adaptive_multi_value_casts``."""
    return self._apply_adaptive_multi_value_casts(
        super().visit_insert(insert_stmt, **kw)
    )
253+
168254
def limit_clause(self, select, **kw):
169255
"""Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1,
170256
since Databricks SQL doesn't support the latter.
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import uuid
2+
3+
import pandas as pd
4+
import pytest
5+
from sqlalchemy import create_engine, text
6+
from sqlalchemy.engine import Engine
7+
8+
9+
@pytest.fixture
def db_engine(connection_details) -> Engine:
    """Yield an Engine built from ``connection_details`` and dispose of it
    once the test using it has finished."""
    host, http_path, access_token, catalog, schema = (
        connection_details[key]
        for key in ("host", "http_path", "access_token", "catalog", "schema")
    )

    url = (
        f"databricks://token:{access_token}@{host}"
        f"?http_path={http_path}&catalog={catalog}&schema={schema}"
    )
    engine = create_engine(
        url,
        connect_args={"_user_agent_entry": "SQLAlchemy pandas e2e tests"},
    )
    try:
        yield engine
    finally:
        # Dispose even when the test body raised.
        engine.dispose()
28+
29+
30+
def test_pandas_to_sql_multi_mixed_object_column_succeeds(db_engine: Engine):
    """Round-trip a DataFrame whose ``value`` column mixes ints and a string
    through ``to_sql(method="multi")`` and check every stored cell.

    The table is created under a unique name and dropped again in ``finally``.
    """
    table_name = f"pecoblr_2746_e2e_{uuid.uuid4().hex[:8]}"
    fq_table_name = f"`main`.`default`.`{table_name}`"
    frame = pd.DataFrame(
        {
            "name": ["alice", "bob", None],
            "value": [1, 0, "NE"],
            "score": [9.5, 8.1, None],
            "active": [True, None, False],
        }
    )
    create_sql = (
        f"CREATE TABLE {fq_table_name} "
        "(name STRING, value STRING, score DOUBLE, active BOOLEAN) "
        "USING DELTA"
    )
    # NULL names sort last so the row order is deterministic.
    select_sql = (
        f"SELECT name, value, score, active FROM {fq_table_name} "
        "ORDER BY CASE WHEN name IS NULL THEN 1 ELSE 0 END, name"
    )

    try:
        with db_engine.begin() as setup_conn:
            setup_conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}"))
            setup_conn.execute(text(create_sql))

        # This is the failing path from PECOBLR-2746 before the adaptive cast fix.
        frame.to_sql(
            table_name,
            db_engine,
            schema="default",
            if_exists="append",
            index=False,
            method="multi",
        )

        with db_engine.begin() as read_conn:
            rows = read_conn.execute(text(select_sql)).fetchall()

        assert len(rows) == 3
        alice, bob, unnamed = rows

        assert alice[0] == "alice"
        assert alice[1] == "1"
        assert alice[2] == pytest.approx(9.5)
        assert alice[3] is True

        assert bob[0] == "bob"
        assert bob[1] == "0"
        assert bob[2] == pytest.approx(8.1)
        assert bob[3] is None

        assert unnamed[0] is None
        assert unnamed[1] == "NE"
        assert unnamed[2] is None
        assert unnamed[3] is False
    finally:
        with db_engine.begin() as cleanup_conn:
            cleanup_conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}"))

tests/test_local/test_ddl.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,50 @@ def test_in_clause_expansion_renders_backticked_markers(self):
425425
assert ":`col-name_1_1`" in expanded.statement
426426
assert ":`col-name_1_2`" in expanded.statement
427427
assert ":`col-name_1_3`" in expanded.statement
428+
429+
430+
class TestAdaptiveMultiRowInsertCasts(DDLTestBase):
    """Compile-time checks for adaptive CAST rendering in multi-row INSERTs."""

    def test_mixed_runtime_families_in_multi_values_are_cast(self):
        """A column mixing ints and a string gets every one of its binds cast."""
        metadata = MetaData()
        table = Table("t", metadata, Column("name", String()), Column("value", String()))
        stmt = insert(table).values(
            [
                {"name": "alice", "value": 1},
                {"name": "bob", "value": 0},
                {"name": None, "value": "NE"},
            ]
        )

        sql = str(stmt.compile(bind=self.engine))

        assert "CAST(:`value_m0` AS STRING)" in sql
        assert "CAST(:`value_m1` AS STRING)" in sql
        assert "CAST(:`value_m2` AS STRING)" in sql
        # Name values are already all string/null and should remain untouched.
        # The check is on the marker prefix only, so it holds whatever type
        # name a cast would have been rendered with.
        assert "CAST(:`name_m0`" not in sql
        assert "CAST(:`name_m1`" not in sql
        assert "CAST(:`name_m2`" not in sql

    def test_homogeneous_multi_values_are_not_cast(self):
        """All-string values must not introduce any CAST."""
        metadata = MetaData()
        table = Table("t", metadata, Column("value", String()))
        stmt = insert(table).values(
            [{"value": "A"}, {"value": "B"}, {"value": "C"}]
        )

        sql = str(stmt.compile(bind=self.engine))
        # Type-agnostic: no cast of any kind may appear in the statement.
        assert "CAST(" not in sql

    def test_numeric_family_multi_values_are_not_cast(self):
        """Mixed int/float values share one family and must not be cast."""
        metadata = MetaData()
        table = Table("t", metadata, Column("score", Numeric()))
        stmt = insert(table).values(
            [{"score": 1}, {"score": 2.5}, {"score": 3}]
        )

        sql = str(stmt.compile(bind=self.engine))
        # Matching on "AS DECIMAL)" would pass vacuously if the type were
        # rendered with precision/scale, so reject any CAST at all.
        assert "CAST(" not in sql

0 commit comments

Comments
 (0)