Skip to content

Commit bdbec08

Browse files
kushalbakshiclaude
andcommitted
refactor: extract _build_connect_kwargs, add dbname to Connection.__init__, add tests
- Extract duplicated connect kwargs construction into _build_connect_kwargs() - Add dbname as explicit keyword argument to Connection.__init__() for programmatic use (explicit arg overrides config value) - Add 5 unit tests for dbname settings (default, env var, config file, dict access, override context manager) - Bump version to 2.2.1 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 58de0f3 commit bdbec08

File tree

3 files changed

+74
-24
lines changed

3 files changed

+74
-24
lines changed

src/datajoint/connection.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def __init__(
168168
port: int | None = None,
169169
use_tls: bool | dict | None = None,
170170
*,
171+
dbname: str | None = None,
171172
backend: str | None = None,
172173
config_override: "Config | None" = None,
173174
) -> None:
@@ -180,7 +181,8 @@ def __init__(
180181
port = int(port)
181182
elif port is None:
182183
port = self._config["database.port"]
183-
dbname = self._config.get("database.dbname")
184+
if dbname is None:
185+
dbname = self._config.get("database.dbname")
184186
self.conn_info = dict(host=host, port=port, user=user, passwd=password, dbname=dbname)
185187
if use_tls is not False:
186188
# use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config)
@@ -219,23 +221,26 @@ def __repr__(self):
219221
connected = "connected" if self.is_connected else "disconnected"
220222
return "DataJoint connection ({connected}) {user}@{host}:{port}".format(connected=connected, **self.conn_info)
221223

224+
def _build_connect_kwargs(self, use_tls=None):
225+
"""Build kwargs dict for adapter.connect()."""
226+
kwargs = dict(
227+
host=self.conn_info["host"],
228+
port=self.conn_info["port"],
229+
user=self.conn_info["user"],
230+
password=self.conn_info["passwd"],
231+
charset=self._config["connection.charset"],
232+
use_tls=use_tls if use_tls is not None else self.conn_info.get("ssl"),
233+
)
234+
if self.conn_info.get("dbname"):
235+
kwargs["dbname"] = self.conn_info["dbname"]
236+
return kwargs
237+
222238
def connect(self) -> None:
223239
"""Establish or re-establish connection to the database server."""
224240
with warnings.catch_warnings():
225241
warnings.filterwarnings("ignore", ".*deprecated.*")
226242
try:
227-
# Use adapter to create connection
228-
connect_kwargs = dict(
229-
host=self.conn_info["host"],
230-
port=self.conn_info["port"],
231-
user=self.conn_info["user"],
232-
password=self.conn_info["passwd"],
233-
charset=self._config["connection.charset"],
234-
use_tls=self.conn_info.get("ssl"),
235-
)
236-
if self.conn_info.get("dbname"):
237-
connect_kwargs["dbname"] = self.conn_info["dbname"]
238-
self._conn = self.adapter.connect(**connect_kwargs)
243+
self._conn = self.adapter.connect(**self._build_connect_kwargs())
239244
except Exception as ssl_error:
240245
# If SSL fails, retry without SSL (if it was auto-detected)
241246
if self.conn_info.get("ssl_input") is None:
@@ -244,17 +249,9 @@ def connect(self) -> None:
244249
"To require SSL, set use_tls=True explicitly.",
245250
ssl_error,
246251
)
247-
connect_kwargs = dict(
248-
host=self.conn_info["host"],
249-
port=self.conn_info["port"],
250-
user=self.conn_info["user"],
251-
password=self.conn_info["passwd"],
252-
charset=self._config["connection.charset"],
253-
use_tls=False, # Explicitly disable SSL for fallback
252+
self._conn = self.adapter.connect(
253+
**self._build_connect_kwargs(use_tls=False)
254254
)
255-
if self.conn_info.get("dbname"):
256-
connect_kwargs["dbname"] = self.conn_info["dbname"]
257-
self._conn = self.adapter.connect(**connect_kwargs)
258255
else:
259256
raise
260257
self._is_closed = False # Mark as connected after successful connection

src/datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# version bump auto managed by Github Actions:
22
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
33
# manually set this version will be eventually overwritten by the above actions
4-
__version__ = "2.2.0"
4+
__version__ = "2.2.1"

tests/unit/test_settings.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,59 @@ def test_similar_prefix_names_allowed(self):
750750
dj.config.stores.update(original_stores)
751751

752752

753+
class TestDbnameConfiguration:
754+
"""Test database.dbname configuration."""
755+
756+
def test_dbname_default_is_none(self):
757+
"""Dbname defaults to None when not configured."""
758+
from datajoint.settings import DatabaseSettings
759+
760+
s = DatabaseSettings()
761+
assert s.dbname is None
762+
763+
def test_dbname_env_var(self, monkeypatch):
764+
"""DJ_DBNAME environment variable sets dbname."""
765+
from datajoint.settings import DatabaseSettings
766+
767+
monkeypatch.setenv("DJ_DBNAME", "my_database")
768+
s = DatabaseSettings()
769+
assert s.dbname == "my_database"
770+
771+
def test_dbname_from_config_file(self, tmp_path, monkeypatch):
772+
"""Load dbname from config file."""
773+
import json
774+
775+
from datajoint.settings import Config
776+
777+
config_file = tmp_path / "test_config.json"
778+
config_file.write_text(json.dumps({
779+
"database": {"dbname": "custom_db", "host": "localhost"}
780+
}))
781+
782+
monkeypatch.delenv("DJ_DBNAME", raising=False)
783+
monkeypatch.delenv("DJ_HOST", raising=False)
784+
785+
cfg = Config()
786+
cfg.load(config_file)
787+
assert cfg.database.dbname == "custom_db"
788+
789+
def test_dbname_dict_access(self):
790+
"""Dict-style access reads and writes dbname."""
791+
original = dj.config.database.dbname
792+
try:
793+
dj.config.database.dbname = "test_db"
794+
assert dj.config["database.dbname"] == "test_db"
795+
finally:
796+
dj.config.database.dbname = original
797+
798+
def test_dbname_override_context_manager(self):
799+
"""Override context manager temporarily sets dbname."""
800+
original = dj.config.database.dbname
801+
with dj.config.override(database__dbname="override_db"):
802+
assert dj.config.database.dbname == "override_db"
803+
assert dj.config.database.dbname == original
804+
805+
753806
class TestBackendConfiguration:
754807
"""Test database backend configuration and port auto-detection."""
755808

0 commit comments

Comments
 (0)