|
1 | 1 | import re |
| 2 | +from datetime import date, datetime, time |
| 3 | +from numbers import Number |
| 4 | +from uuid import UUID |
2 | 5 | from sqlalchemy.sql import compiler, sqltypes |
3 | 6 | import logging |
4 | 7 |
|
@@ -165,6 +168,138 @@ def bindparam_string(self, name, **kw): |
165 | 168 | return self._BIND_TEMPLATE % {"name": name.replace("`", "``")} |
166 | 169 | return super().bindparam_string(name, **kw) |
167 | 170 |
|
| 171 | + @staticmethod |
| 172 | + def _split_multivalue_bind_name(bind_name): |
| 173 | + """Split SQLAlchemy's ``<col>_m<idx>`` bind names into (column, idx).""" |
| 174 | + match = re.match(r"^(?P<col>.+)_m(?P<idx>\d+)$", bind_name) |
| 175 | + if not match: |
| 176 | + return None |
| 177 | + return match.group("col"), int(match.group("idx")) |
| 178 | + |
| 179 | + @staticmethod |
| 180 | + def _value_family(value): |
| 181 | + """Return scalar value family; ``None`` means non-scalar/unsupported.""" |
| 182 | + if value is None: |
| 183 | + return "null" |
| 184 | + if isinstance(value, bool): |
| 185 | + return "bool" |
| 186 | + if isinstance(value, Number): |
| 187 | + return "number" |
| 188 | + if isinstance(value, str): |
| 189 | + return "string" |
| 190 | + if isinstance(value, (bytes, bytearray, memoryview)): |
| 191 | + return "binary" |
| 192 | + if isinstance(value, (date, time, datetime)): |
| 193 | + return "temporal" |
| 194 | + if isinstance(value, UUID): |
| 195 | + return "uuid" |
| 196 | + return None |
| 197 | + |
@staticmethod
def _has_custom_bind_expression(type_engine):
    """Report whether a type, or its ``impl``, overrides ``bind_expression``.

    The check compares the ``bind_expression`` attribute found on the class
    of ``type_engine`` (and, when present, on the class of its ``impl``
    attribute) against ``sqltypes.TypeEngine.bind_expression`` by identity.

    Returns:
        ``True`` if either class exposes a different ``bind_expression``
        (or none at all), ``False`` otherwise.
    """
    # NOTE(review): any class that defines its own ``bind_expression`` --
    # including intermediate wrapper base classes -- is reported as custom;
    # confirm that is the intended, conservative behaviour.
    stock_bind_expression = sqltypes.TypeEngine.bind_expression

    def overrides_default(candidate):
        found = getattr(type(candidate), "bind_expression", None)
        return found is not stock_bind_expression

    if overrides_default(type_engine):
        return True

    wrapped_impl = getattr(type_engine, "impl", None)
    return wrapped_impl is not None and overrides_default(wrapped_impl)
| 217 | + |
| 218 | + def _build_multi_value_cast_plan(self, insert_stmt): |
| 219 | + """Return {bind_name: cast_sql_type} for multi-row VALUES insert binds. |
| 220 | +
|
| 221 | + Cast only *mixed scalar* multi-row bind groups whose SQLAlchemy target |
| 222 | + type compiles to STRING. This avoids silent data loss for non-string |
| 223 | + target columns and avoids breaking complex/custom bind types (e.g. |
| 224 | + ARRAY/MAP/VARIANT), while still fixing Spark inline-table |
| 225 | + incompatibility for object columns that mix primitive families into a |
| 226 | + string-like target column. |
| 227 | + """ |
| 228 | + if not self.dialect.enable_multirow_insert_casts: |
| 229 | + return {} |
| 230 | + |
| 231 | + if not getattr(insert_stmt, "_multi_values", None): |
| 232 | + return {} |
| 233 | + |
| 234 | + grouped_binds = {} |
| 235 | + for bind_name, bind_param in self.binds.items(): |
| 236 | + split = self._split_multivalue_bind_name(bind_name) |
| 237 | + if split is None: |
| 238 | + continue |
| 239 | + column_name, _ = split |
| 240 | + grouped_binds.setdefault(column_name, []).append((bind_name, bind_param)) |
| 241 | + |
| 242 | + cast_plan = {} |
| 243 | + for bind_entries in grouped_binds.values(): |
| 244 | + families = set() |
| 245 | + has_non_scalar = False |
| 246 | + has_custom_bind_expression = False |
| 247 | + |
| 248 | + for _, bind_param in bind_entries: |
| 249 | + value_family = self._value_family(getattr(bind_param, "value", None)) |
| 250 | + if value_family is None: |
| 251 | + has_non_scalar = True |
| 252 | + break |
| 253 | + if value_family != "null": |
| 254 | + families.add(value_family) |
| 255 | + |
| 256 | + type_engine = getattr(bind_param, "type", None) |
| 257 | + if type_engine is not None and self._has_custom_bind_expression( |
| 258 | + type_engine |
| 259 | + ): |
| 260 | + has_custom_bind_expression = True |
| 261 | + |
| 262 | + if has_non_scalar or has_custom_bind_expression or len(families) <= 1: |
| 263 | + continue |
| 264 | + |
| 265 | + bind_targets = [] |
| 266 | + for bind_name, bind_param in bind_entries: |
| 267 | + type_engine = getattr(bind_param, "type", None) |
| 268 | + if type_engine is None or isinstance(type_engine, sqltypes.NullType): |
| 269 | + continue |
| 270 | + |
| 271 | + dialect_type = type_engine._unwrapped_dialect_impl(self.dialect) |
| 272 | + target_type = self.dialect.type_compiler_instance.process( |
| 273 | + dialect_type, identifier_preparer=self.preparer |
| 274 | + ) |
| 275 | + bind_targets.append((bind_name, target_type)) |
| 276 | + |
| 277 | + if not bind_targets or any( |
| 278 | + target_type.upper() != "STRING" for _, target_type in bind_targets |
| 279 | + ): |
| 280 | + continue |
| 281 | + |
| 282 | + for bind_name, target_type in bind_targets: |
| 283 | + cast_plan[bind_name] = target_type |
| 284 | + |
| 285 | + return cast_plan |
| 286 | + |
| 287 | + def _apply_multi_value_casts(self, sql_text, insert_stmt): |
| 288 | + """Wrap selected ``:`name``` markers with ``CAST(... AS <type>)``.""" |
| 289 | + cast_plan = self._build_multi_value_cast_plan(insert_stmt) |
| 290 | + if not cast_plan: |
| 291 | + return sql_text |
| 292 | + |
| 293 | + rendered = sql_text |
| 294 | + for bind_name, target_type in cast_plan.items(): |
| 295 | + marker = self._BIND_TEMPLATE % {"name": bind_name.replace("`", "``")} |
| 296 | + rendered = rendered.replace(marker, f"CAST({marker} AS {target_type})") |
| 297 | + return rendered |
| 298 | + |
def visit_insert(self, insert_stmt, **kw):
    """Compile an INSERT with the parent compiler, then apply bind casts.

    The SQL produced by ``super().visit_insert`` is passed through
    ``_apply_multi_value_casts`` and that result is returned.
    """
    compiled_sql = super().visit_insert(insert_stmt, **kw)
    return self._apply_multi_value_casts(compiled_sql, insert_stmt)
| 302 | + |
168 | 303 | def limit_clause(self, select, **kw): |
169 | 304 | """Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1, |
170 | 305 | since Databricks SQL doesn't support the latter. |
|
0 commit comments