11import re
2- from numbers import Number
32from sqlalchemy .sql import compiler , sqltypes
43import logging
54
@@ -166,21 +165,6 @@ def bindparam_string(self, name, **kw):
166165 return self ._BIND_TEMPLATE % {"name" : name .replace ("`" , "``" )}
167166 return super ().bindparam_string (name , ** kw )
168167
169- @staticmethod
170- def _value_family (value ):
171- """Return a coarse runtime family for adaptive multi-row cast decisions."""
172- if value is None :
173- return "null"
174- if isinstance (value , bool ):
175- return "bool"
176- if isinstance (value , str ):
177- return "string"
178- if isinstance (value , (bytes , bytearray , memoryview )):
179- return "binary"
180- if isinstance (value , Number ):
181- return "number"
182- return "other"
183-
184168 @staticmethod
185169 def _split_multivalue_bind_name (bind_name ):
186170 """Split SQLAlchemy's ``<col>_m<idx>`` bind names into (column, idx)."""
@@ -189,55 +173,37 @@ def _split_multivalue_bind_name(bind_name):
189173 return None
190174 return match .group ("col" ), int (match .group ("idx" ))
191175
192- def _build_adaptive_cast_plan (self ):
193- """Return {bind_name: cast_sql_type} for risky multi-row value groups .
176+ def _build_multi_value_cast_plan (self , insert_stmt ):
177+ """Return {bind_name: cast_sql_type} for multi-row VALUES insert binds .
194178
195- We only target SQLAlchemy-generated multi-row binds (``*_mN``). For
196- each logical column we inspect row values available at compile time and
197- cast only when families are heterogeneous in a way that commonly causes
198- Spark inline- table incompatibility (e.g., number + string) .
179+ This is a deterministic fix for Spark inline-table type reconciliation:
180+ for SQLAlchemy-generated multi- row INSERT binds (``*_mN``), always cast
181+ the marker to the bind's dialect SQL type so each column position in the
182+ VALUES table has an explicit server-side type .
199183 """
200- column_bind_names = {}
201- for bind_name , bind_param in self .binds .items ():
202- split = self ._split_multivalue_bind_name (bind_name )
203- if split is None :
204- continue
205- column_name , _ = split
206- column_bind_names .setdefault (column_name , []).append ((bind_name , bind_param ))
184+ if not getattr (insert_stmt , "_multi_values" , None ):
185+ return {}
207186
208187 cast_plan = {}
209- for bind_entries in column_bind_names .values ():
210- families = set ()
211- for _ , bind_param in bind_entries :
212- value = getattr (bind_param , "value" , None )
213- family = self ._value_family (value )
214- if family != "null" :
215- families .add (family )
216-
217- if len (families ) <= 1 :
188+ for bind_name , bind_param in self .binds .items ():
189+ if self ._split_multivalue_bind_name (bind_name ) is None :
218190 continue
219191
220- # Numeric + numeric is safe for Spark inline tables and does not
221- # need explicit casting.
222- if families == {"number" }:
192+ type_engine = getattr (bind_param , "type" , None )
193+ if type_engine is None or isinstance (type_engine , sqltypes .NullType ):
223194 continue
224195
225- for bind_name , bind_param in bind_entries :
226- type_engine = getattr (bind_param , "type" , None )
227- if type_engine is None or isinstance (type_engine , sqltypes .NullType ):
228- continue
229-
230- dialect_type = type_engine ._unwrapped_dialect_impl (self .dialect )
231- target_type = self .dialect .type_compiler_instance .process (
232- dialect_type , identifier_preparer = self .preparer
233- )
234- cast_plan [bind_name ] = target_type
196+ dialect_type = type_engine ._unwrapped_dialect_impl (self .dialect )
197+ target_type = self .dialect .type_compiler_instance .process (
198+ dialect_type , identifier_preparer = self .preparer
199+ )
200+ cast_plan [bind_name ] = target_type
235201
236202 return cast_plan
237203
238- def _apply_adaptive_multi_value_casts (self , sql_text ):
204+ def _apply_multi_value_casts (self , sql_text , insert_stmt ):
239205 """Wrap selected ``:`name``` markers with ``CAST(... AS <type>)``."""
240- cast_plan = self ._build_adaptive_cast_plan ( )
206+ cast_plan = self ._build_multi_value_cast_plan ( insert_stmt )
241207 if not cast_plan :
242208 return sql_text
243209
@@ -249,7 +215,7 @@ def _apply_adaptive_multi_value_casts(self, sql_text):
249215
250216 def visit_insert (self , insert_stmt , ** kw ):
251217 sql_text = super ().visit_insert (insert_stmt , ** kw )
252- return self ._apply_adaptive_multi_value_casts (sql_text )
218+ return self ._apply_multi_value_casts (sql_text , insert_stmt )
253219
254220 def limit_clause (self , select , ** kw ):
255221 """Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1,
0 commit comments