Skip to content

Commit 5bae415

Browse files
dimitri-yatsenkokushalbakshiclaude
committed
refactor: backend-agnostic abstractions for multi-backend support
Centralize patterns that were duplicated or hardcoded for MySQL: - Add make_full_table_name() to adapter ABC — consolidates quoted name construction from 7 call sites into one overridable method. Backends with additional namespace levels can override. - Add foreign_key_action_clause property to adapter ABC — FK referential actions via adapter instead of hardcoded in declare.py. Backends without referential action support can override. - Use adapter.split_full_table_name() in declare() — replaces fragile manual quote-char detection. - Guard transaction methods against empty SQL — supports backends without multi-table transaction semantics. - Add "bytes"/"binary" to blob type detection — supports backends that use BINARY instead of longblob. - Route lineage table check through adapter.get_table_info_sql() — replaces hardcoded information_schema query. Co-Authored-By: Kushal Bakshi <kushal.bakshi@datajoint.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b563047 commit 5bae415

File tree

9 files changed

+939
-242
lines changed

9 files changed

+939
-242
lines changed

pixi.lock

Lines changed: 880 additions & 203 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/datajoint/adapters/base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,38 @@ def parameter_placeholder(self) -> str:
238238
"""
239239
...
240240

241+
def make_full_table_name(self, database: str, table_name: str) -> str:
242+
"""
243+
Construct a fully-qualified table name for this backend.
244+
245+
Default implementation produces a two-part name (``schema.table``).
246+
Backends that require additional namespace levels can override.
247+
248+
Parameters
249+
----------
250+
database : str
251+
Schema/database name.
252+
table_name : str
253+
Table name (including tier prefix).
254+
255+
Returns
256+
-------
257+
str
258+
Fully-qualified, quoted table name.
259+
"""
260+
return f"{self.quote_identifier(database)}.{self.quote_identifier(table_name)}"
261+
262+
@property
263+
def foreign_key_action_clause(self) -> str:
264+
"""
265+
Referential action clause appended to FOREIGN KEY declarations.
266+
267+
Default: ``ON UPDATE CASCADE ON DELETE RESTRICT`` (MySQL/PostgreSQL).
268+
Backends that don't support referential actions can override to
269+
return ``""``.
270+
"""
271+
return " ON UPDATE CASCADE ON DELETE RESTRICT"
272+
241273
# =========================================================================
242274
# Type Mapping
243275
# =========================================================================

src/datajoint/codecs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def decode_attribute(attr, data, squeeze: bool = False, connection=None):
557557
# psycopg2 auto-deserializes JSON to dict/list; only parse strings
558558
if isinstance(data, str):
559559
data = json.loads(data)
560-
elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob"):
560+
elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob", "bytes", "binary"):
561561
pass # Blob data is already bytes
562562
elif final_dtype.lower() == "binary(16)":
563563
data = uuid_module.UUID(bytes=data)

src/datajoint/connection.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -486,19 +486,25 @@ def start_transaction(self) -> None:
486486
"""
487487
if self.in_transaction:
488488
raise errors.DataJointError("Nested connections are not supported.")
489-
self.query(self.adapter.start_transaction_sql())
489+
sql = self.adapter.start_transaction_sql()
490+
if sql:
491+
self.query(sql)
490492
self._in_transaction = True
491493
logger.debug("Transaction started")
492494

493495
def cancel_transaction(self) -> None:
494496
"""Cancel the current transaction and roll back all changes."""
495-
self.query(self.adapter.rollback_sql())
497+
sql = self.adapter.rollback_sql()
498+
if sql:
499+
self.query(sql)
496500
self._in_transaction = False
497501
logger.debug("Transaction cancelled. Rolling back ...")
498502

499503
def commit_transaction(self) -> None:
500504
"""Commit all changes and close the transaction."""
501-
self.query(self.adapter.commit_sql())
505+
sql = self.adapter.commit_sql()
506+
if sql:
507+
self.query(sql)
502508
self._in_transaction = False
503509
logger.debug("Transaction committed and closed.")
504510

src/datajoint/declare.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,10 @@ def compile_foreign_key(
296296
parent_full_name = ref.support[0]
297297
# Parse as database.table using the adapter's quoting convention
298298
parts = adapter.split_full_table_name(parent_full_name)
299-
ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}"
299+
ref_table_name = adapter.make_full_table_name(parts[0], parts[1])
300300

301301
foreign_key_sql.append(
302-
f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}) ON UPDATE CASCADE ON DELETE RESTRICT"
302+
f"FOREIGN KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}){adapter.foreign_key_action_clause}"
303303
)
304304

305305
# declare unique index
@@ -432,16 +432,8 @@ def declare(
432432
DataJointError
433433
If table name exceeds max length or has no primary key.
434434
"""
435-
# Parse table name without assuming quote character
436-
# Extract schema.table from quoted name using adapter
437-
quote_char = adapter.quote_identifier("x")[0] # Get quote char from adapter
438-
parts = full_table_name.split(".")
439-
if len(parts) == 2:
440-
schema_name = parts[0].strip(quote_char)
441-
table_name = parts[1].strip(quote_char)
442-
else:
443-
schema_name = None
444-
table_name = parts[0].strip(quote_char)
435+
# Parse table name using adapter (handles backend-specific quoting)
436+
schema_name, table_name = adapter.split_full_table_name(full_table_name)
445437

446438
if len(table_name) > MAX_TABLE_NAME_LENGTH:
447439
raise DataJointError(
@@ -924,7 +916,7 @@ def compile_attribute(
924916
# Check for invalid default values on blob types (after type substitution)
925917
# Note: blob → longblob, so check for NATIVE_BLOB or longblob result
926918
final_type = match["type"].lower()
927-
if ("blob" in final_type) and match["default"] not in {"DEFAULT NULL", "NOT NULL"}:
919+
if ("blob" in final_type or final_type == "binary") and match["default"] not in {"DEFAULT NULL", "NOT NULL"}:
928920
raise DataJointError("The default value for blob attributes can only be NULL in:\n{line}".format(line=line))
929921

930922
# Use adapter to format column definition

src/datajoint/lineage.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,12 @@ def lineage_table_exists(connection, database):
7979
bool
8080
True if the table exists, False otherwise.
8181
"""
82-
result = connection.query(
83-
"""
84-
SELECT COUNT(*) FROM information_schema.tables
85-
WHERE table_schema = %s AND table_name = '~lineage'
86-
""",
87-
args=(database,),
88-
).fetchone()
89-
return result[0] > 0
82+
try:
83+
result = connection.query(connection.adapter.get_table_info_sql(database, "~lineage")).fetchone()
84+
return result is not None
85+
except Exception:
86+
# Schema or catalog query may fail on some backends
87+
return False
9088

9189

9290
def get_lineage(connection, database, table_name, attribute_name):

src/datajoint/schemas.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -345,12 +345,7 @@ def make_classes(self, into: dict[str, Any] | None = None) -> None:
345345
tables = [
346346
row[0]
347347
for row in self.connection.query(adapter.list_tables_sql(self.database))
348-
if lookup_class_name(
349-
f"{adapter.quote_identifier(self.database)}.{adapter.quote_identifier(row[0])}",
350-
into,
351-
0,
352-
)
353-
is None
348+
if lookup_class_name(adapter.make_full_table_name(self.database, row[0]), into, 0) is None
354349
]
355350
master_classes = (Lookup, Manual, Imported, Computed)
356351
part_tables = []
@@ -508,7 +503,7 @@ def jobs(self) -> list[Job]:
508503
# Iterate over auto-populated tables and check if their job table exists
509504
for table_name in self.list_tables():
510505
adapter = self.connection.adapter
511-
full_name = f"{adapter.quote_identifier(self.database)}." f"{adapter.quote_identifier(table_name)}"
506+
full_name = adapter.make_full_table_name(self.database, table_name)
512507
table = FreeTable(self.connection, full_name)
513508
tier = _get_tier(table.full_table_name)
514509
if tier in (Computed, Imported):
@@ -608,8 +603,7 @@ def get_table(self, name: str) -> FreeTable:
608603
if table_name is None:
609604
raise DataJointError(f"Table `{name}` does not exist in schema `{self.database}`.")
610605

611-
adapter = self.connection.adapter
612-
full_name = f"{adapter.quote_identifier(self.database)}.{adapter.quote_identifier(table_name)}"
606+
full_name = self.connection.adapter.make_full_table_name(self.database, table_name)
613607
return FreeTable(self.connection, full_name)
614608

615609
def __getitem__(self, name: str) -> FreeTable:

src/datajoint/table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def full_table_name(self):
474474
f"Class {self.__class__.__name__} is not associated with a schema. "
475475
"Apply a schema decorator or use schema() to bind it."
476476
)
477-
return f"{self.adapter.quote_identifier(self.database)}.{self.adapter.quote_identifier(self.table_name)}"
477+
return self.adapter.make_full_table_name(self.database, self.table_name)
478478

479479
@property
480480
def adapter(self):

src/datajoint/user_tables.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ def full_table_name(cls):
106106
"""The fully qualified table name (quoted per backend)."""
107107
if cls.database is None:
108108
return None
109-
adapter = cls._connection.adapter
110-
return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}"
109+
return cls._connection.adapter.make_full_table_name(cls.database, cls.table_name)
111110

112111

113112
class UserTable(Table, metaclass=TableMeta):
@@ -186,8 +185,7 @@ def full_table_name(cls):
186185
"""The fully qualified table name (quoted per backend)."""
187186
if cls.database is None or cls.table_name is None:
188187
return None
189-
adapter = cls._connection.adapter
190-
return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}"
188+
return cls._connection.adapter.make_full_table_name(cls.database, cls.table_name)
191189

192190
@property
193191
def master(cls):

0 commit comments

Comments
 (0)