Skip to content

Commit aa0d69b

Browse files
committed
Add opt-out gate for multi-row insert casts.
Allow users to disable targeted multi-row insert cast rendering with enable_multirow_insert_casts=false in the Databricks SQLAlchemy engine URL while keeping the PECOBLR-2746 fix enabled by default. Signed-off-by: Madhavendra Rathore <madhavendra.rathore@databricks.com>
1 parent b1d7ab2 commit aa0d69b

3 files changed

Lines changed: 29 additions & 0 deletions

File tree

src/databricks/sqlalchemy/_ddl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ def _build_multi_value_cast_plan(self, insert_stmt):
223223
Spark inline-table incompatibility for object columns that mix
224224
primitive families (e.g. INT + STRING).
225225
"""
226+
if not self.dialect.enable_multirow_insert_casts:
227+
return {}
228+
226229
if not getattr(insert_stmt, "_multi_values", None):
227230
return {}
228231

src/databricks/sqlalchemy/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ class DatabricksImpl(DefaultImpl):
4242
logger = logging.getLogger(__name__)
4343

4444

45+
def _parse_bool_url_param(value: Optional[str], default: bool) -> bool:
46+
if value is None:
47+
return default
48+
return value.lower() not in ("0", "false", "no", "off")
49+
50+
4551
class DatabricksDialect(default.DefaultDialect):
4652
"""This dialect implements only those methods required to pass our e2e tests"""
4753

@@ -65,6 +71,7 @@ class DatabricksDialect(default.DefaultDialect):
6571
supports_server_side_cursors: bool = False
6672
supports_sequences: bool = False
6773
supports_native_boolean: bool = True
74+
enable_multirow_insert_casts: bool = True
6875

6976
colspecs = {
7077
sqlalchemy.types.DateTime: dialect_type_impl.TIMESTAMP_NTZ,
@@ -117,6 +124,9 @@ def create_connect_args(self, url):
117124

118125
self.schema = kwargs["schema"]
119126
self.catalog = kwargs["catalog"]
127+
self.enable_multirow_insert_casts = _parse_bool_url_param(
128+
url.query.get("enable_multirow_insert_casts"), True
129+
)
120130

121131
self._force_paramstyle_to_native_mode()
122132

tests/test_local/test_ddl.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,22 @@ def test_multi_values_casts_mixed_type_column(self):
458458
assert "CAST(:`name_m1` AS STRING)" not in sql
459459
assert "CAST(:`name_m2` AS STRING)" not in sql
460460

461+
def test_multi_value_casts_can_be_disabled_by_url_param(self):
462+
engine = create_engine(
463+
"databricks://token:****@****"
464+
"?http_path=****&catalog=****&schema=****"
465+
"&enable_multirow_insert_casts=false"
466+
)
467+
metadata = MetaData()
468+
table = Table("t", metadata, Column("value", String()))
469+
stmt = insert(table).values([{"value": 1}, {"value": 0}, {"value": "NE"}])
470+
471+
sql = str(stmt.compile(bind=engine))
472+
assert "CAST(:`value_m0` AS STRING)" not in sql
473+
assert ":`value_m0`" in sql
474+
assert ":`value_m1`" in sql
475+
assert ":`value_m2`" in sql
476+
461477
def test_homogeneous_multi_values_are_not_cast(self):
462478
metadata = MetaData()
463479
table = Table("t", metadata, Column("value", String()))

0 commit comments

Comments
 (0)