Skip to content

Commit f848623

Browse files
kushalbakshiclaude
andcommitted
fix: backend-agnostic improvements for multi-backend SQL support
Improves DataJoint's adapter abstraction to properly support non-MySQL backends. These changes benefit PostgreSQL and enable third-party adapter registration. Changes: - Add make_full_table_name() to adapter ABC — centralizes table name construction (was hardcoded in 5 files) - Add foreign_key_action_clause property — FK referential actions via adapter (was hardcoded in declare.py) - Use fetchone() instead of rowcount for existence checks — DBAPI2 does not guarantee rowcount for non-DML statements (table.py, schemas.py) - Guard transaction methods against empty SQL — supports backends without multi-table transactions (connection.py) - Use adapter.quote_identifier() for job metadata columns — fixes PostgreSQL which uses double-quotes (autopopulate.py) - Use adapter.split_full_table_name() for name parsing — backend- agnostic instead of manual split (declare.py) - Route lineage table check through adapter — catalog-qualified queries for backends with namespaced information_schema (lineage.py) - Add "bytes"/"binary" to blob type detection list — supports backends that use BINARY instead of longblob (codecs.py) - Extend backend Literal to accept third-party adapter names (settings.py) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 34acbbe commit f848623

File tree

10 files changed

+67
-37
lines changed

10 files changed

+67
-37
lines changed

src/datajoint/adapters/base.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,39 @@ def parameter_placeholder(self) -> str:
238238
"""
239239
...
240240

241+
def make_full_table_name(self, database: str, table_name: str) -> str:
242+
"""
243+
Construct a fully-qualified table name for this backend.
244+
245+
Default implementation produces a two-part name (``schema.table``).
246+
Backends that require additional namespace levels (e.g., Databricks
247+
``catalog.schema.table``) should override this method.
248+
249+
Parameters
250+
----------
251+
database : str
252+
Schema/database name.
253+
table_name : str
254+
Table name (including tier prefix).
255+
256+
Returns
257+
-------
258+
str
259+
Fully-qualified, quoted table name.
260+
"""
261+
return f"{self.quote_identifier(database)}.{self.quote_identifier(table_name)}"
262+
263+
@property
264+
def foreign_key_action_clause(self) -> str:
265+
"""
266+
Referential action clause appended to FOREIGN KEY declarations.
267+
268+
Default: ``ON UPDATE CASCADE ON DELETE RESTRICT`` (MySQL/PostgreSQL).
269+
Backends that don't support referential actions (e.g., Databricks)
270+
should override to return ``""``.
271+
"""
272+
return " ON UPDATE CASCADE ON DELETE RESTRICT"
273+
241274
# =========================================================================
242275
# Type Mapping
243276
# =========================================================================

src/datajoint/autopopulate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,9 +776,10 @@ def _update_job_metadata(self, key, start_time, duration, version):
776776
from .condition import make_condition
777777

778778
pk_condition = make_condition(self, key, set())
779+
q = self.connection.adapter.quote_identifier
779780
self.connection.query(
780781
f"UPDATE {self.full_table_name} SET "
781-
"`_job_start_time`=%s, `_job_duration`=%s, `_job_version`=%s "
782+
f"{q('_job_start_time')}=%s, {q('_job_duration')}=%s, {q('_job_version')}=%s "
782783
f"WHERE {pk_condition}",
783784
args=(start_time, duration, version[:64] if version else ""),
784785
)

src/datajoint/codecs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def decode_attribute(attr, data, squeeze: bool = False, connection=None):
557557
# psycopg2 auto-deserializes JSON to dict/list; only parse strings
558558
if isinstance(data, str):
559559
data = json.loads(data)
560-
elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob"):
560+
elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob", "bytes", "binary"):
561561
pass # Blob data is already bytes
562562
elif final_dtype.lower() == "binary(16)":
563563
data = uuid_module.UUID(bytes=data)

src/datajoint/connection.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -486,19 +486,25 @@ def start_transaction(self) -> None:
486486
"""
487487
if self.in_transaction:
488488
raise errors.DataJointError("Nested connections are not supported.")
489-
self.query(self.adapter.start_transaction_sql())
489+
sql = self.adapter.start_transaction_sql()
490+
if sql:
491+
self.query(sql)
490492
self._in_transaction = True
491493
logger.debug("Transaction started")
492494

493495
def cancel_transaction(self) -> None:
494496
"""Cancel the current transaction and roll back all changes."""
495-
self.query(self.adapter.rollback_sql())
497+
sql = self.adapter.rollback_sql()
498+
if sql:
499+
self.query(sql)
496500
self._in_transaction = False
497501
logger.debug("Transaction cancelled. Rolling back ...")
498502

499503
def commit_transaction(self) -> None:
500504
"""Commit all changes and close the transaction."""
501-
self.query(self.adapter.commit_sql())
505+
sql = self.adapter.commit_sql()
506+
if sql:
507+
self.query(sql)
502508
self._in_transaction = False
503509
logger.debug("Transaction committed and closed.")
504510

src/datajoint/declare.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,10 @@ def compile_foreign_key(
296296
parent_full_name = ref.support[0]
297297
# Parse as database.table using the adapter's quoting convention
298298
parts = adapter.split_full_table_name(parent_full_name)
299-
ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}"
299+
ref_table_name = adapter.make_full_table_name(parts[0], parts[1])
300300

301301
foreign_key_sql.append(
302-
f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}) ON UPDATE CASCADE ON DELETE RESTRICT"
302+
f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}){adapter.foreign_key_action_clause}"
303303
)
304304

305305
# declare unique index
@@ -432,16 +432,8 @@ def declare(
432432
DataJointError
433433
If table name exceeds max length or has no primary key.
434434
"""
435-
# Parse table name without assuming quote character
436-
# Extract schema.table from quoted name using adapter
437-
quote_char = adapter.quote_identifier("x")[0] # Get quote char from adapter
438-
parts = full_table_name.split(".")
439-
if len(parts) == 2:
440-
schema_name = parts[0].strip(quote_char)
441-
table_name = parts[1].strip(quote_char)
442-
else:
443-
schema_name = None
444-
table_name = parts[0].strip(quote_char)
435+
# Parse table name using adapter (handles 2-part and 3-part names)
436+
schema_name, table_name = adapter.split_full_table_name(full_table_name)
445437

446438
if len(table_name) > MAX_TABLE_NAME_LENGTH:
447439
raise DataJointError(
@@ -924,7 +916,7 @@ def compile_attribute(
924916
# Check for invalid default values on blob types (after type substitution)
925917
# Note: blob → longblob, so check for NATIVE_BLOB or longblob result
926918
final_type = match["type"].lower()
927-
if ("blob" in final_type) and match["default"] not in {"DEFAULT NULL", "NOT NULL"}:
919+
if ("blob" in final_type or final_type == "binary") and match["default"] not in {"DEFAULT NULL", "NOT NULL"}:
928920
raise DataJointError("The default value for blob attributes can only be NULL in:\n{line}".format(line=line))
929921

930922
# Use adapter to format column definition

src/datajoint/lineage.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,13 @@ def lineage_table_exists(connection, database):
7979
bool
8080
True if the table exists, False otherwise.
8181
"""
82-
result = connection.query(
83-
"""
84-
SELECT COUNT(*) FROM information_schema.tables
85-
WHERE table_schema = %s AND table_name = '~lineage'
86-
""",
87-
args=(database,),
88-
).fetchone()
89-
return result[0] > 0
82+
try:
83+
result = connection.query(
84+
connection.adapter.get_table_info_sql(database, "~lineage")
85+
).fetchone()
86+
return result is not None
87+
except Exception:
88+
return False
9089

9190

9291
def get_lineage(connection, database, table_name, attribute_name):

src/datajoint/schemas.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def make_classes(self, into: dict[str, Any] | None = None) -> None:
344344
tables = [
345345
row[0]
346346
for row in self.connection.query(self.connection.adapter.list_tables_sql(self.database))
347-
if lookup_class_name("`{db}`.`{tab}`".format(db=self.database, tab=row[0]), into, 0) is None
347+
if lookup_class_name(self.connection.adapter.make_full_table_name(self.database, row[0]), into, 0) is None
348348
]
349349
master_classes = (Lookup, Manual, Imported, Computed)
350350
part_tables = []
@@ -421,7 +421,8 @@ def exists(self) -> bool:
421421
"""
422422
if self.database is None:
423423
raise DataJointError("Schema must be activated first.")
424-
return bool(self.connection.query(self.connection.adapter.schema_exists_sql(self.database)).rowcount)
424+
result = self.connection.query(self.connection.adapter.schema_exists_sql(self.database))
425+
return result.fetchone() is not None
425426

426427
@property
427428
def lineage_table_exists(self) -> bool:
@@ -502,7 +503,7 @@ def jobs(self) -> list[Job]:
502503
# Iterate over auto-populated tables and check if their job table exists
503504
for table_name in self.list_tables():
504505
adapter = self.connection.adapter
505-
full_name = f"{adapter.quote_identifier(self.database)}." f"{adapter.quote_identifier(table_name)}"
506+
full_name = adapter.make_full_table_name(self.database, table_name)
506507
table = FreeTable(self.connection, full_name)
507508
tier = _get_tier(table.full_table_name)
508509
if tier in (Computed, Imported):
@@ -603,7 +604,7 @@ def get_table(self, name: str) -> FreeTable:
603604
raise DataJointError(f"Table `{name}` does not exist in schema `{self.database}`.")
604605

605606
adapter = self.connection.adapter
606-
full_name = f"{adapter.quote_identifier(self.database)}.{adapter.quote_identifier(table_name)}"
607+
full_name = adapter.make_full_table_name(self.database, table_name)
607608
return FreeTable(self.connection, full_name)
608609

609610
def __getitem__(self, name: str) -> FreeTable:

src/datajoint/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class DatabaseSettings(BaseSettings):
190190
host: str = Field(default="localhost", validation_alias="DJ_HOST")
191191
user: str | None = Field(default=None, validation_alias="DJ_USER")
192192
password: SecretStr | None = Field(default=None, validation_alias="DJ_PASS")
193-
backend: Literal["mysql", "postgresql"] = Field(
193+
backend: Literal["mysql", "postgresql", "databricks"] = Field(
194194
default="mysql",
195195
validation_alias="DJ_BACKEND",
196196
description="Database backend: 'mysql' or 'postgresql'",

src/datajoint/table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def is_declared(self):
457457
True if the table is declared in the schema.
458458
"""
459459
query = self.connection.adapter.get_table_info_sql(self.database, self.table_name)
460-
return self.connection.query(query).rowcount > 0
460+
return self.connection.query(query).fetchone() is not None
461461

462462
@property
463463
def full_table_name(self):
@@ -474,7 +474,7 @@ def full_table_name(self):
474474
f"Class {self.__class__.__name__} is not associated with a schema. "
475475
"Apply a schema decorator or use schema() to bind it."
476476
)
477-
return f"{self.adapter.quote_identifier(self.database)}.{self.adapter.quote_identifier(self.table_name)}"
477+
return self.adapter.make_full_table_name(self.database, self.table_name)
478478

479479
@property
480480
def adapter(self):

src/datajoint/user_tables.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ def full_table_name(cls):
106106
"""The fully qualified table name (quoted per backend)."""
107107
if cls.database is None:
108108
return None
109-
adapter = cls._connection.adapter
110-
return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}"
109+
return cls._connection.adapter.make_full_table_name(cls.database, cls.table_name)
111110

112111

113112
class UserTable(Table, metaclass=TableMeta):
@@ -186,8 +185,7 @@ def full_table_name(cls):
186185
"""The fully qualified table name (quoted per backend)."""
187186
if cls.database is None or cls.table_name is None:
188187
return None
189-
adapter = cls._connection.adapter
190-
return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}"
188+
return cls._connection.adapter.make_full_table_name(cls.database, cls.table_name)
191189

192190
@property
193191
def master(cls):

0 commit comments

Comments
 (0)