Skip to content

Commit 35634c4

Browse files
committed
Bind UUIDs in canonical hyphenated form (fixes #50)
SQLAlchemy's default Uuid type binds parameters as the 32-char hex form without dashes on backends that lack native UUID support, so equality against canonically-stored UUIDs returns 0 rows. Add a DatabricksUUID subclass that always emits the canonical 8-4-4-4-12 form on the wire, validates input through uuid.UUID (rejecting non-UUID strings instead of injecting them into SQL), and reconstructs UUIDs via .int so subclassed __str__ cannot escape the literal quotes. Wire it via colspecs so plain Column(Uuid) picks it up automatically. Signed-off-by: Sreekanth Vadigi <sreekanth.vadigi@databricks.com>
1 parent 106e087 commit 35634c4

3 files changed

Lines changed: 202 additions & 1 deletion

File tree

src/databricks/sqlalchemy/_types.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime, time, timezone
22
from itertools import product
33
from typing import Any, Union, Optional
4+
from uuid import UUID
45

56
import sqlalchemy
67
from sqlalchemy.engine.interfaces import Dialect
@@ -315,6 +316,60 @@ def process(value):
315316
return process
316317

317318

319+
class DatabricksUUID(sqlalchemy.types.Uuid):
    """Bind UUIDs in their canonical 8-4-4-4-12 hyphenated form.

    Databricks has no native UUID type, so SQLAlchemy's default ``Uuid``
    bind/literal processors render the 32-character hex form without dashes
    (e.g. ``1daa91d78d35468486d63fa89042c1f4``). That breaks equality against
    UUIDs stored as canonical strings in Databricks. We coerce every input
    through ``uuid.UUID`` so the wire value is always the canonical hyphenated
    form regardless of whether the caller passed a ``UUID``, a hyphenated
    string, or a dash-less hex string. The ``UUID(...)`` round-trip also
    validates the input: any non-UUID string raises ``ValueError`` instead of
    being silently injected into SQL, which is critical for ``literal_binds``
    rendering safety.

    Reads are not overridden here; they go through the inherited
    ``result_processor``, which accepts both hyphenated and dash-less hex
    column values. With the default ``as_uuid=True`` it returns a ``UUID``
    object; with ``as_uuid=False`` it returns a string that is likewise
    normalized to the canonical hyphenated form. Reads of legacy hex-stored
    rows therefore keep working in both modes. Binding is stricter than
    reading: parameters and literals are always sent hyphenated.
    """

    # The type carries no per-instance state beyond what ``Uuid`` already
    # declares, so compiled statements that use it are safe to cache.
    cache_ok = True

    @staticmethod
    def _canonical(value: Any) -> str:
        """Return the canonical hyphenated string for ``value``.

        For UUID instances we rebuild a stdlib ``UUID`` from ``.int`` so a
        subclass cannot smuggle an arbitrary string through an overridden
        ``__str__``. Anything else is converted with ``str()`` and parsed by
        ``uuid.UUID``, so only text that is a well-formed UUID survives.

        Raises:
            ValueError: if ``value`` is not a ``UUID`` and its string form
                cannot be parsed as one.
        """
        if isinstance(value, UUID):
            return str(UUID(int=value.int))
        return str(UUID(str(value)))

    def bind_processor(self, dialect):
        """Return a converter from a bound parameter to its wire string.

        The converter passes ``None`` through unchanged and maps every other
        value to the canonical hyphenated form via ``_canonical``.
        """

        def process(value):
            if value is None:
                return None
            return self._canonical(value)

        return process

    def literal_processor(self, dialect):
        """Return a converter that renders a value as an inline SQL literal.

        ``None`` renders as ``NULL``; every other value renders as the
        canonical hyphenated UUID wrapped in single quotes. Because the text
        between the quotes always comes from ``_canonical``, it can only
        contain hex digits and dashes.
        """

        def process(value):
            if value is None:
                return "NULL"
            return f"'{self._canonical(value)}'"

        return process
371+
372+
318373
class TINYINT(sqlalchemy.types.TypeDecorator):
319374
"""Represents 1-byte signed integers
320375

src/databricks/sqlalchemy/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class DatabricksDialect(default.DefaultDialect):
7070
sqlalchemy.types.DateTime: dialect_type_impl.TIMESTAMP_NTZ,
7171
sqlalchemy.types.Time: dialect_type_impl.DatabricksTimeType,
7272
sqlalchemy.types.String: dialect_type_impl.DatabricksStringType,
73+
sqlalchemy.types.Uuid: dialect_type_impl.DatabricksUUID,
7374
}
7475

7576
# SQLAlchemy requires that a table with no primary key

tests/test_local/test_types.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import enum
2+
from uuid import UUID
23

34
import pytest
45
import sqlalchemy
6+
from sqlalchemy import Column, MetaData, Table, select
57

68
from databricks.sqlalchemy.base import DatabricksDialect
7-
from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ, DatabricksVariant
9+
from databricks.sqlalchemy._types import (
10+
DatabricksUUID,
11+
DatabricksVariant,
12+
TINYINT,
13+
TIMESTAMP,
14+
TIMESTAMP_NTZ,
15+
)
816

917

1018
class DatabricksDataType(enum.Enum):
@@ -161,3 +169,140 @@ def test_array_string_renders_as_array_of_string(self):
161169
self._assert_compiled_value_explicit(
162170
sqlalchemy.types.ARRAY(sqlalchemy.types.String), "ARRAY<STRING>"
163171
)
172+
173+
174+
class TestDatabricksUUID:
    """Regression tests for github.com/databricks/databricks-sqlalchemy/issues/50.

    On backends without a native UUID type, SQLAlchemy's stock ``Uuid`` emits
    the dash-less 32-character hex form, so comparisons against UUIDs stored
    as canonical 8-4-4-4-12 strings in Databricks find nothing.
    """

    dialect = DatabricksDialect()
    HYPHENATED = "1daa91d7-8d35-4684-86d6-3fa89042c1f4"
    HEX = "1daa91d78d35468486d63fa89042c1f4"
    sample = UUID(HYPHENATED)

    def _select_users_by_sample_id(self):
        """Build a SELECT filtered on ``id == sample`` over a plain ``Uuid`` column."""
        metadata = MetaData()
        users = Table(
            "users", metadata, Column("id", sqlalchemy.types.Uuid, primary_key=True)
        )
        return select(users).where(users.c.id == self.sample)

    def test_bind_processor_preserves_hyphenated_form(self):
        to_wire = DatabricksUUID().bind_processor(self.dialect)
        assert to_wire(self.sample) == self.HYPHENATED

    def test_bind_processor_handles_none(self):
        to_wire = DatabricksUUID().bind_processor(self.dialect)
        assert to_wire(None) is None

    def test_literal_processor_renders_hyphenated_form(self):
        to_literal = DatabricksUUID().literal_processor(self.dialect)
        assert to_literal(self.sample) == f"'{self.HYPHENATED}'"

    def test_literal_processor_handles_none(self):
        to_literal = DatabricksUUID().literal_processor(self.dialect)
        assert to_literal(None) == "NULL"

    def test_result_processor_accepts_both_forms(self):
        from_wire = DatabricksUUID().result_processor(self.dialect, None)
        assert from_wire(self.HYPHENATED) == self.sample
        assert from_wire(self.HEX) == self.sample
        assert from_wire(None) is None

    def test_dialect_routes_uuid_to_databricks_uuid(self):
        """A plain ``Uuid`` column reaches our impl through the colspecs entry."""
        registered = self.dialect.colspecs[sqlalchemy.types.Uuid]
        assert registered is DatabricksUUID

    def test_uuid_where_clause_renders_with_dashes(self):
        compiled = self._select_users_by_sample_id().compile(
            dialect=self.dialect, compile_kwargs={"literal_binds": True}
        )
        literal_sql = str(compiled)

        assert f"'{self.HYPHENATED}'" in literal_sql
        assert self.HEX not in literal_sql

    def test_uuid_bound_param_wire_value_has_dashes(self):
        compiled = self._select_users_by_sample_id().compile(dialect=self.dialect)

        # Run each raw parameter through its bind processor, when it has one,
        # to obtain the value that would be sent to the driver.
        wire_values = []
        for key, value in compiled.construct_params().items():
            if key in compiled._bind_processors:
                wire_values.append(compiled._bind_processors[key](value))
            else:
                wire_values.append(value)

        assert self.HYPHENATED in wire_values

    def test_bind_processor_normalizes_hex_string_to_canonical(self):
        """A dash-less 32-character hex string is coerced to canonical form."""
        to_wire = DatabricksUUID().bind_processor(self.dialect)
        assert to_wire(self.HEX) == self.HYPHENATED

    def test_bind_processor_normalizes_hyphenated_string(self):
        """An already-canonical hyphenated string comes out unchanged."""
        to_wire = DatabricksUUID().bind_processor(self.dialect)
        assert to_wire(self.HYPHENATED) == self.HYPHENATED

    def test_bind_processor_rejects_non_uuid_string(self):
        """Malformed input raises rather than being written to the column."""
        to_wire = DatabricksUUID().bind_processor(self.dialect)
        with pytest.raises(ValueError):
            to_wire("not-a-uuid")

    def test_literal_processor_rejects_injection_attempt(self):
        """A hostile string cannot break out of the quoted literal.

        Without normalization through ``UUID(...)``, a value such as
        ``abc' OR '1'='1`` would render as ``WHERE id = 'abc' OR '1'='1'``
        under ``literal_binds=True``; here it must raise instead.
        """
        to_literal = DatabricksUUID().literal_processor(self.dialect)
        with pytest.raises(ValueError):
            to_literal("abc' OR '1'='1")

    def test_literal_processor_rejects_injection_via_uuid_subclass(self):
        """A UUID subclass with a hostile ``__str__`` is rendered safely.

        ``_canonical`` rebuilds every UUID via ``UUID(int=value.int)``, so the
        overridden string conversion never reaches the SQL text.
        """

        class EvilUUID(UUID):
            def __str__(self):  # type: ignore[override]
                return "abc' OR '1'='1"

        to_literal = DatabricksUUID().literal_processor(self.dialect)
        rendered = to_literal(EvilUUID(self.HYPHENATED))

        assert rendered == f"'{self.HYPHENATED}'"
        assert "OR" not in rendered

    def test_literal_processor_rejects_injection_with_as_uuid_false(self):
        """``as_uuid=False`` goes through ``_canonical`` too; pin it explicitly."""
        string_mode = DatabricksUUID(as_uuid=False)
        to_literal = string_mode.literal_processor(self.dialect)
        with pytest.raises(ValueError):
            to_literal("abc' OR '1'='1")

    def test_as_uuid_false_round_trip_normalizes_hex_input(self):
        """With ``as_uuid=False`` both string forms normalize in both directions.

        Callers in string mode may pass or read back dash-less hex values; the
        bind and result processors must each yield the canonical form.
        """
        string_mode = DatabricksUUID(as_uuid=False)
        to_wire = string_mode.bind_processor(self.dialect)
        from_wire = string_mode.result_processor(self.dialect, None)

        assert to_wire(self.HEX) == self.HYPHENATED
        assert to_wire(self.HYPHENATED) == self.HYPHENATED
        assert from_wire(self.HYPHENATED) == self.HYPHENATED
        assert from_wire(self.HEX) == self.HYPHENATED

0 commit comments

Comments
 (0)