Skip to content

Commit bf7c442

Browse files
feat: add backend parameter to Instance and cross-connection validation
Instance now accepts backend="mysql"|"postgresql" to explicitly set the database backend, with automatic port default derivation (3306 for MySQL, 5432 for PostgreSQL).

Join, restriction, and union operators now validate that both operands use the same connection, raising DataJointError with a clear message when expressions from different Instances are combined.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent da43ed3 commit bf7c442

File tree

4 files changed: +141 −4 lines changed (141 additions, 4 deletions)

src/datajoint/condition.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,13 @@ def assert_join_compatibility(
244244
if isinstance(expr1, U) or isinstance(expr2, U):
245245
return
246246

247+
# Check that both expressions use the same connection
248+
if expr1.connection is not expr2.connection:
249+
raise DataJointError(
250+
"Cannot operate on expressions from different connections. "
251+
"Ensure both operands use the same dj.Instance or global connection."
252+
)
253+
247254
if semantic_check:
248255
# Check if lineage tracking is available for both expressions
249256
if not expr1.heading.lineage_available or not expr2.heading.lineage_available:

src/datajoint/expression.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1414,8 +1414,11 @@ def create(cls, arg1, arg2):
14141414
arg2 = arg2() # instantiate if a class
14151415
if not isinstance(arg2, QueryExpression):
14161416
raise DataJointError("A QueryExpression can only be unioned with another QueryExpression")
1417-
if arg1.connection != arg2.connection:
1418-
raise DataJointError("Cannot operate on QueryExpressions originating from different connections.")
1417+
if arg1.connection is not arg2.connection:
1418+
raise DataJointError(
1419+
"Cannot operate on expressions from different connections. "
1420+
"Ensure both operands use the same dj.Instance or global connection."
1421+
)
14191422
if set(arg1.primary_key) != set(arg2.primary_key):
14201423
raise DataJointError("The operands of a union must share the same primary key.")
14211424
if set(arg1.heading.secondary_attributes) & set(arg2.heading.secondary_attributes):

src/datajoint/instance.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
import os
11-
from typing import TYPE_CHECKING, Any
11+
from typing import TYPE_CHECKING, Any, Literal
1212

1313
from .connection import Connection
1414
from .errors import ThreadSafetyError
@@ -54,9 +54,11 @@ class Instance:
5454
password : str
5555
Database password.
5656
port : int, optional
57-
Database port. Default from config or 3306.
57+
Database port. Defaults to 3306 for MySQL, 5432 for PostgreSQL.
5858
use_tls : bool or dict, optional
5959
TLS configuration.
60+
backend : str, optional
61+
Database backend: ``"mysql"`` or ``"postgresql"``. Default from config.
6062
**kwargs : Any
6163
Additional config overrides applied to this instance's config.
6264
@@ -81,11 +83,19 @@ def __init__(
8183
password: str,
8284
port: int | None = None,
8385
use_tls: bool | dict | None = None,
86+
backend: Literal["mysql", "postgresql"] | None = None,
8487
**kwargs: Any,
8588
) -> None:
8689
# Create fresh config with defaults loaded from env/file
8790
self.config = _create_config()
8891

92+
# Apply backend override before other kwargs (port default depends on it)
93+
if backend is not None:
94+
self.config.database.backend = backend
95+
# Re-derive port default since _create_config resolved it before backend was set
96+
if port is None and "database__port" not in kwargs:
97+
self.config.database.port = 5432 if backend == "postgresql" else 3306
98+
8999
# Apply any config overrides from kwargs
90100
for key, value in kwargs.items():
91101
if hasattr(self.config, key):

tests/unit/test_thread_safe.py

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,123 @@ def test_instance_always_allowed_in_thread_safe_mode(self, monkeypatch):
155155
assert callable(Instance)
156156

157157

158+
class TestInstanceBackend:
159+
"""Test Instance backend parameter."""
160+
161+
def test_instance_backend_sets_config(self, monkeypatch):
162+
"""Instance(backend=...) sets config.database.backend."""
163+
monkeypatch.setenv("DJ_THREAD_SAFE", "false")
164+
from datajoint.instance import Instance
165+
from unittest.mock import patch
166+
167+
with patch("datajoint.instance.Connection"):
168+
inst = Instance(
169+
host="localhost", user="root", password="secret",
170+
backend="postgresql",
171+
)
172+
assert inst.config.database.backend == "postgresql"
173+
174+
def test_instance_backend_default_from_config(self, monkeypatch):
175+
"""Instance without backend uses config default."""
176+
monkeypatch.setenv("DJ_THREAD_SAFE", "false")
177+
from datajoint.instance import Instance
178+
from unittest.mock import patch
179+
180+
with patch("datajoint.instance.Connection"):
181+
inst = Instance(
182+
host="localhost", user="root", password="secret",
183+
)
184+
assert inst.config.database.backend == "mysql"
185+
186+
def test_instance_backend_affects_port_default(self, monkeypatch):
187+
"""Instance(backend='postgresql') uses port 5432 by default."""
188+
monkeypatch.setenv("DJ_THREAD_SAFE", "false")
189+
from datajoint.instance import Instance
190+
from unittest.mock import patch, call
191+
192+
with patch("datajoint.instance.Connection") as MockConn:
193+
Instance(
194+
host="localhost", user="root", password="secret",
195+
backend="postgresql",
196+
)
197+
# Connection should be called with port 5432 (PostgreSQL default)
198+
args, kwargs = MockConn.call_args
199+
assert args[3] == 5432 # port is the 4th positional arg
200+
201+
202+
class TestCrossConnectionValidation:
203+
"""Test that cross-connection operations are rejected."""
204+
205+
def test_join_different_connections_raises(self):
206+
"""Join of expressions from different connections raises DataJointError."""
207+
from datajoint.expression import QueryExpression
208+
from datajoint.errors import DataJointError
209+
from unittest.mock import MagicMock
210+
211+
expr1 = QueryExpression()
212+
expr1._connection = MagicMock()
213+
expr1._heading = MagicMock()
214+
expr1._heading.names = []
215+
216+
expr2 = QueryExpression()
217+
expr2._connection = MagicMock() # different connection object
218+
expr2._heading = MagicMock()
219+
expr2._heading.names = []
220+
221+
with pytest.raises(DataJointError, match="different connections"):
222+
expr1 * expr2
223+
224+
def test_join_same_connection_allowed(self):
225+
"""Join of expressions from the same connection does not raise."""
226+
from datajoint.condition import assert_join_compatibility
227+
from datajoint.expression import QueryExpression
228+
from unittest.mock import MagicMock
229+
230+
shared_conn = MagicMock()
231+
232+
expr1 = QueryExpression()
233+
expr1._connection = shared_conn
234+
expr1._heading = MagicMock()
235+
expr1._heading.names = []
236+
expr1._heading.lineage_available = False
237+
238+
expr2 = QueryExpression()
239+
expr2._connection = shared_conn
240+
expr2._heading = MagicMock()
241+
expr2._heading.names = []
242+
expr2._heading.lineage_available = False
243+
244+
# Should not raise
245+
assert_join_compatibility(expr1, expr2)
246+
247+
def test_restriction_different_connections_raises(self):
248+
"""Restriction by expression from different connection raises DataJointError."""
249+
from datajoint.expression import QueryExpression
250+
from datajoint.errors import DataJointError
251+
from unittest.mock import MagicMock
252+
253+
expr1 = QueryExpression()
254+
expr1._connection = MagicMock()
255+
expr1._heading = MagicMock()
256+
expr1._heading.names = ["a"]
257+
expr1._heading.__getitem__ = MagicMock()
258+
expr1._heading.new_attributes = set()
259+
expr1._support = ["`db`.`t1`"]
260+
expr1._restriction = []
261+
expr1._restriction_attributes = set()
262+
expr1._joins = []
263+
expr1._top = None
264+
expr1._original_heading = expr1._heading
265+
266+
expr2 = QueryExpression()
267+
expr2._connection = MagicMock() # different connection
268+
expr2._heading = MagicMock()
269+
expr2._heading.names = ["a"]
270+
271+
with pytest.raises(DataJointError, match="different connections"):
272+
expr1 & expr2
273+
274+
158275
class TestThreadSafetyError:
159276
"""Test ThreadSafetyError exception."""
160277

0 commit comments

Comments
 (0)