File tree Expand file tree Collapse file tree
src/databricks/sqlalchemy Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -223,6 +223,9 @@ def _build_multi_value_cast_plan(self, insert_stmt):
223223 Spark inline-table incompatibility for object columns that mix
224224 primitive families (e.g. INT + STRING).
225225 """
226+ if not self .dialect .enable_multirow_insert_casts :
227+ return {}
228+
226229 if not getattr (insert_stmt , "_multi_values" , None ):
227230 return {}
228231
Original file line number Diff line number Diff line change @@ -42,6 +42,12 @@ class DatabricksImpl(DefaultImpl):
4242logger = logging .getLogger (__name__ )
4343
4444
45+ def _parse_bool_url_param (value : Optional [str ], default : bool ) -> bool :
46+ if value is None :
47+ return default
48+ return value .lower () not in ("0" , "false" , "no" , "off" )
49+
50+
4551class DatabricksDialect (default .DefaultDialect ):
4652 """This dialect implements only those methods required to pass our e2e tests"""
4753
@@ -65,6 +71,7 @@ class DatabricksDialect(default.DefaultDialect):
6571 supports_server_side_cursors : bool = False
6672 supports_sequences : bool = False
6773 supports_native_boolean : bool = True
74+ enable_multirow_insert_casts : bool = True
6875
6976 colspecs = {
7077 sqlalchemy .types .DateTime : dialect_type_impl .TIMESTAMP_NTZ ,
@@ -117,6 +124,9 @@ def create_connect_args(self, url):
117124
118125 self .schema = kwargs ["schema" ]
119126 self .catalog = kwargs ["catalog" ]
127+ self .enable_multirow_insert_casts = _parse_bool_url_param (
128+ url .query .get ("enable_multirow_insert_casts" ), True
129+ )
120130
121131 self ._force_paramstyle_to_native_mode ()
122132
Original file line number Diff line number Diff line change @@ -458,6 +458,22 @@ def test_multi_values_casts_mixed_type_column(self):
458458 assert "CAST(:`name_m1` AS STRING)" not in sql
459459 assert "CAST(:`name_m2` AS STRING)" not in sql
460460
461+ def test_multi_value_casts_can_be_disabled_by_url_param (self ):
462+ engine = create_engine (
463+ "databricks://token:****@****"
464+ "?http_path=****&catalog=****&schema=****"
465+ "&enable_multirow_insert_casts=false"
466+ )
467+ metadata = MetaData ()
468+ table = Table ("t" , metadata , Column ("value" , String ()))
469+ stmt = insert (table ).values ([{"value" : 1 }, {"value" : 0 }, {"value" : "NE" }])
470+
471+ sql = str (stmt .compile (bind = engine ))
472+ assert "CAST(:`value_m0` AS STRING)" not in sql
473+ assert ":`value_m0`" in sql
474+ assert ":`value_m1`" in sql
475+ assert ":`value_m2`" in sql
476+
461477 def test_homogeneous_multi_values_are_not_cast (self ):
462478 metadata = MetaData ()
463479 table = Table ("t" , metadata , Column ("value" , String ()))
You can’t perform that action at this time.
0 commit comments