Skip to content

Commit b702520

Browse files
authored
Merge pull request #68 from databricks/fix/pecoblr-2746-adaptive-multirow-casts
[PECOBLR-2746] Fix pandas multi-row mixed-type inserts with adaptive casts
2 parents 7e74a3f + 4ae2f03 commit b702520

6 files changed

Lines changed: 738 additions & 18 deletions

File tree

README.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,42 @@ engine = create_engine(
4646
)
4747
```
4848

49+
### Connection URL parameters and `connect_args`
50+
51+
The Databricks SQLAlchemy dialect accepts dialect-specific options in the
52+
SQLAlchemy connection URL query string:
53+
54+
| Parameter | Required | Default | Description |
55+
|-|-|-|-|
56+
| `http_path` | Yes | | HTTP path for the Databricks SQL warehouse or compute resource. |
57+
| `catalog` | Yes | | Initial catalog for the connection. |
58+
| `schema` | Yes | | Initial schema for the connection. |
59+
| `enable_multirow_insert_casts` | No | `true` | Enables targeted casts for mixed scalar values in SQLAlchemy-generated multi-row `INSERT ... VALUES` statements. This avoids Spark inline-table type errors for pandas `to_sql(method="multi")` with mixed scalar/object columns. Set to `false` to disable this rewrite. |
60+
61+
For example, to disable targeted multi-row insert casts:
62+
63+
```python
64+
engine = create_engine(
65+
"databricks://token:dapi***@***.cloud.databricks.com"
66+
"?http_path=***&catalog=main&schema=test"
67+
"&enable_multirow_insert_casts=false"
68+
)
69+
```
70+
71+
Use SQLAlchemy's `connect_args` for DBAPI connection options that should be
72+
passed through to `databricks-sql-connector`, such as user-agent settings:
73+
74+
```python
75+
engine = create_engine(
76+
"databricks://token:dapi***@***.cloud.databricks.com"
77+
"?http_path=***&catalog=main&schema=test",
78+
connect_args={"user_agent_entry": "My SQLAlchemy App"},
79+
)
80+
```
81+
82+
Dialect URL parameters control SQLAlchemy compilation behavior and are not
83+
forwarded to the DBAPI connector.
84+
4985
## Types
5086

5187
The [SQLAlchemy type hierarchy](https://docs.sqlalchemy.org/en/20/core/type_basics.html) contains backend-agnostic type implementations (represented in CamelCase) and backend-specific types (represented in UPPERCASE). The majority of SQLAlchemy's [CamelCase](https://docs.sqlalchemy.org/en/20/core/type_basics.html#the-camelcase-datatypes) types are supported. This means that a SQLAlchemy application using these types should "just work" with Databricks.

src/databricks/sqlalchemy/_ddl.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import re
2+
from datetime import date, datetime, time
3+
from numbers import Number
4+
from uuid import UUID
25
from sqlalchemy.sql import compiler, sqltypes
36
import logging
47

@@ -165,6 +168,138 @@ def bindparam_string(self, name, **kw):
165168
return self._BIND_TEMPLATE % {"name": name.replace("`", "``")}
166169
return super().bindparam_string(name, **kw)
167170

171+
@staticmethod
172+
def _split_multivalue_bind_name(bind_name):
173+
"""Split SQLAlchemy's ``<col>_m<idx>`` bind names into (column, idx)."""
174+
match = re.match(r"^(?P<col>.+)_m(?P<idx>\d+)$", bind_name)
175+
if not match:
176+
return None
177+
return match.group("col"), int(match.group("idx"))
178+
179+
@staticmethod
180+
def _value_family(value):
181+
"""Return scalar value family; ``None`` means non-scalar/unsupported."""
182+
if value is None:
183+
return "null"
184+
if isinstance(value, bool):
185+
return "bool"
186+
if isinstance(value, Number):
187+
return "number"
188+
if isinstance(value, str):
189+
return "string"
190+
if isinstance(value, (bytes, bytearray, memoryview)):
191+
return "binary"
192+
if isinstance(value, (date, time, datetime)):
193+
return "temporal"
194+
if isinstance(value, UUID):
195+
return "uuid"
196+
return None
197+
198+
@staticmethod
199+
def _has_custom_bind_expression(type_engine):
200+
"""True if the type (or its impl) customizes bind-expression rendering."""
201+
type_cls = type(type_engine)
202+
if (
203+
getattr(type_cls, "bind_expression", None)
204+
is not sqltypes.TypeEngine.bind_expression
205+
):
206+
return True
207+
208+
impl = getattr(type_engine, "impl", None)
209+
if impl is not None:
210+
impl_cls = type(impl)
211+
if (
212+
getattr(impl_cls, "bind_expression", None)
213+
is not sqltypes.TypeEngine.bind_expression
214+
):
215+
return True
216+
return False
217+
218+
def _build_multi_value_cast_plan(self, insert_stmt):
219+
"""Return {bind_name: cast_sql_type} for multi-row VALUES insert binds.
220+
221+
Cast only *mixed scalar* multi-row bind groups whose SQLAlchemy target
222+
type compiles to STRING. This avoids silent data loss for non-string
223+
target columns and avoids breaking complex/custom bind types (e.g.
224+
ARRAY/MAP/VARIANT), while still fixing Spark inline-table
225+
incompatibility for object columns that mix primitive families into a
226+
string-like target column.
227+
"""
228+
if not self.dialect.enable_multirow_insert_casts:
229+
return {}
230+
231+
if not getattr(insert_stmt, "_multi_values", None):
232+
return {}
233+
234+
grouped_binds = {}
235+
for bind_name, bind_param in self.binds.items():
236+
split = self._split_multivalue_bind_name(bind_name)
237+
if split is None:
238+
continue
239+
column_name, _ = split
240+
grouped_binds.setdefault(column_name, []).append((bind_name, bind_param))
241+
242+
cast_plan = {}
243+
for bind_entries in grouped_binds.values():
244+
families = set()
245+
has_non_scalar = False
246+
has_custom_bind_expression = False
247+
248+
for _, bind_param in bind_entries:
249+
value_family = self._value_family(getattr(bind_param, "value", None))
250+
if value_family is None:
251+
has_non_scalar = True
252+
break
253+
if value_family != "null":
254+
families.add(value_family)
255+
256+
type_engine = getattr(bind_param, "type", None)
257+
if type_engine is not None and self._has_custom_bind_expression(
258+
type_engine
259+
):
260+
has_custom_bind_expression = True
261+
262+
if has_non_scalar or has_custom_bind_expression or len(families) <= 1:
263+
continue
264+
265+
bind_targets = []
266+
for bind_name, bind_param in bind_entries:
267+
type_engine = getattr(bind_param, "type", None)
268+
if type_engine is None or isinstance(type_engine, sqltypes.NullType):
269+
continue
270+
271+
dialect_type = type_engine._unwrapped_dialect_impl(self.dialect)
272+
target_type = self.dialect.type_compiler_instance.process(
273+
dialect_type, identifier_preparer=self.preparer
274+
)
275+
bind_targets.append((bind_name, target_type))
276+
277+
if not bind_targets or any(
278+
target_type.upper() != "STRING" for _, target_type in bind_targets
279+
):
280+
continue
281+
282+
for bind_name, target_type in bind_targets:
283+
cast_plan[bind_name] = target_type
284+
285+
return cast_plan
286+
287+
def _apply_multi_value_casts(self, sql_text, insert_stmt):
288+
"""Wrap selected ``:`name``` markers with ``CAST(... AS <type>)``."""
289+
cast_plan = self._build_multi_value_cast_plan(insert_stmt)
290+
if not cast_plan:
291+
return sql_text
292+
293+
rendered = sql_text
294+
for bind_name, target_type in cast_plan.items():
295+
marker = self._BIND_TEMPLATE % {"name": bind_name.replace("`", "``")}
296+
rendered = rendered.replace(marker, f"CAST({marker} AS {target_type})")
297+
return rendered
298+
299+
def visit_insert(self, insert_stmt, **kw):
    """Compile the INSERT with the parent compiler, then apply multi-row casts."""
    return self._apply_multi_value_casts(
        super().visit_insert(insert_stmt, **kw), insert_stmt
    )
302+
168303
def limit_clause(self, select, **kw):
169304
"""Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1,
170305
since Databricks SQL doesn't support the latter.

src/databricks/sqlalchemy/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ class DatabricksImpl(DefaultImpl):
4242
logger = logging.getLogger(__name__)
4343

4444

45+
def _parse_bool_url_param(value: Optional[str], default: bool) -> bool:
46+
if value is None:
47+
return default
48+
if value.lower() in ("1", "true", "yes", "on"):
49+
return True
50+
if value.lower() in ("0", "false", "no", "off"):
51+
return False
52+
return default
53+
54+
4555
class DatabricksDialect(default.DefaultDialect):
4656
"""This dialect implements only those methods required to pass our e2e tests"""
4757

@@ -65,6 +75,7 @@ class DatabricksDialect(default.DefaultDialect):
6575
supports_server_side_cursors: bool = False
6676
supports_sequences: bool = False
6777
supports_native_boolean: bool = True
78+
enable_multirow_insert_casts: bool = True
6879

6980
colspecs = {
7081
sqlalchemy.types.DateTime: dialect_type_impl.TIMESTAMP_NTZ,
@@ -117,6 +128,9 @@ def create_connect_args(self, url):
117128

118129
self.schema = kwargs["schema"]
119130
self.catalog = kwargs["catalog"]
131+
self.enable_multirow_insert_casts = _parse_bool_url_param(
132+
url.query.get("enable_multirow_insert_casts"), True
133+
)
120134

121135
self._force_paramstyle_to_native_mode()
122136

0 commit comments

Comments
 (0)