Skip to content

Commit 9b885cd

Browse files
refactor: make dj.Schema a proper class subclassing _Schema
Rename schemas.Schema → schemas._Schema (internal) and define dj.Schema as a class that inherits from _Schema with a thread-safety check in __init__. This eliminates the confusing function-vs-class duality where dj.Schema was a function wrapper and schemas.Schema was the class. Now dj.Schema is a real class: isinstance, hasattr, and subclass checks all work naturally. The test for rebuild_lineage can use dj.Schema directly instead of importing from datajoint.schemas. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b88915c commit 9b885cd

File tree

6 files changed

+48
-44
lines changed

6 files changed

+48
-44
lines changed

src/datajoint/__init__.py

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
from .instance import Instance, _ConfigProxy, _get_singleton_connection, _global_config, _check_thread_safe
8181
from .logging import logger
8282
from .objectref import ObjectRef
83-
from .schemas import Schema as _Schema, VirtualModule, list_schemas, virtual_schema
83+
from .schemas import _Schema, VirtualModule, list_schemas, virtual_schema
8484
from .table import FreeTable as _FreeTable, Table, ValidationResult
8585
from .user_tables import Computed, Imported, Lookup, Manual, Part
8686
from .version import __version__
@@ -166,26 +166,20 @@ def conn(
166166
return _get_singleton_connection()
167167

168168

169-
def Schema(
170-
schema_name: str | None = None,
171-
context: dict | None = None,
172-
*,
173-
connection: Connection | None = None,
174-
create_schema: bool = True,
175-
create_tables: bool | None = None,
176-
add_objects: dict | None = None,
177-
) -> _Schema:
169+
class Schema(_Schema):
178170
"""
179-
Create a Schema for binding table classes to a database schema.
171+
Decorator that binds table classes to a database schema.
180172
181173
When connection is not provided, uses the singleton connection.
174+
In thread-safe mode (``DJ_THREAD_SAFE=true``), a connection must be
175+
provided explicitly or use ``dj.Instance().Schema()`` instead.
182176
183177
Parameters
184178
----------
185179
schema_name : str, optional
186-
Database schema name.
180+
Database schema name. If omitted, call ``activate()`` later.
187181
context : dict, optional
188-
Namespace for foreign key lookup.
182+
Namespace for foreign key lookup. None uses caller's context.
189183
connection : Connection, optional
190184
Database connection. Defaults to singleton connection.
191185
create_schema : bool, optional
@@ -195,29 +189,41 @@ def Schema(
195189
add_objects : dict, optional
196190
Additional objects for declaration context.
197191
198-
Returns
199-
-------
200-
Schema
201-
A Schema bound to the specified connection.
202-
203192
Raises
204193
------
205194
ThreadSafetyError
206-
If thread_safe mode is enabled and using singleton.
195+
If thread_safe mode is enabled and no connection is provided.
196+
197+
Examples
198+
--------
199+
>>> schema = dj.Schema('my_schema')
200+
>>> @schema
201+
... class Session(dj.Manual):
202+
... definition = '''
203+
... session_id : int
204+
... '''
207205
"""
208-
if connection is None:
209-
# Use singleton connection - will raise ThreadSafetyError if thread_safe=True
210-
_check_thread_safe()
211-
connection = _get_singleton_connection()
212-
213-
return _Schema(
214-
schema_name,
215-
context=context,
216-
connection=connection,
217-
create_schema=create_schema,
218-
create_tables=create_tables,
219-
add_objects=add_objects,
220-
)
206+
207+
def __init__(
208+
self,
209+
schema_name: str | None = None,
210+
context: dict | None = None,
211+
*,
212+
connection: Connection | None = None,
213+
create_schema: bool = True,
214+
create_tables: bool | None = None,
215+
add_objects: dict | None = None,
216+
) -> None:
217+
if connection is None:
218+
_check_thread_safe()
219+
super().__init__(
220+
schema_name,
221+
context=context,
222+
connection=connection,
223+
create_schema=create_schema,
224+
create_tables=create_tables,
225+
add_objects=add_objects,
226+
)
221227

222228

223229
def FreeTable(conn_or_name, full_table_name: str | None = None) -> _FreeTable:

src/datajoint/gc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from .errors import DataJointError
4545

4646
if TYPE_CHECKING:
47-
from .schemas import Schema
47+
from .schemas import _Schema as Schema
4848

4949
logger = logging.getLogger(__name__.split(".")[0])
5050

src/datajoint/instance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .settings import Config, _create_config, config as _settings_config
1616

1717
if TYPE_CHECKING:
18-
from .schemas import Schema as SchemaClass
18+
from .schemas import _Schema as SchemaClass
1919
from .table import FreeTable as FreeTableClass
2020

2121

@@ -140,9 +140,9 @@ def Schema(
140140
Schema
141141
A Schema using this instance's connection.
142142
"""
143-
from .schemas import Schema
143+
from .schemas import _Schema
144144

145-
return Schema(
145+
return _Schema(
146146
schema_name,
147147
context=context,
148148
connection=self.connection,

src/datajoint/migrate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
)
4040

4141
if TYPE_CHECKING:
42-
from .schemas import Schema
42+
from .schemas import _Schema as Schema
4343

4444
logger = logging.getLogger(__name__.split(".")[0])
4545

@@ -653,7 +653,7 @@ def add_job_metadata_columns(target, dry_run: bool = True) -> dict:
653653
- Future populate() calls will fill in metadata for new rows
654654
- This does NOT retroactively populate metadata for existing rows
655655
"""
656-
from .schemas import Schema
656+
from .schemas import _Schema
657657
from .table import Table
658658

659659
result = {
@@ -664,7 +664,7 @@ def add_job_metadata_columns(target, dry_run: bool = True) -> dict:
664664
}
665665

666666
# Determine tables to process
667-
if isinstance(target, Schema):
667+
if isinstance(target, _Schema):
668668
schema = target
669669
# Get all user tables in the schema
670670
tables_query = """

src/datajoint/schemas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def ordered_dir(class_: type) -> list[str]:
5353
return attr_list
5454

5555

56-
class Schema:
56+
class _Schema:
5757
"""
5858
Decorator that binds table classes to a database schema.
5959
@@ -832,7 +832,7 @@ def __init__(
832832
Additional objects to add to the module namespace.
833833
"""
834834
super(VirtualModule, self).__init__(name=module_name)
835-
_schema = Schema(
835+
_schema = _Schema(
836836
schema_name,
837837
create_schema=create_schema,
838838
create_tables=create_tables,

tests/integration/test_semantic_matching.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,7 @@ class TestRebuildLineageUtility:
325325

326326
def test_rebuild_lineage_method_exists(self):
327327
"""The rebuild_lineage method should exist on Schema."""
328-
from datajoint.schemas import Schema as _Schema
329-
330-
assert hasattr(_Schema, "rebuild_lineage")
328+
assert hasattr(dj.Schema, "rebuild_lineage")
331329

332330
def test_rebuild_lineage_populates_table(self, schema_semantic):
333331
"""schema.rebuild_lineage() should populate the ~lineage table."""

0 commit comments

Comments
 (0)