Skip to content

Commit 1ada05b

Browse files
refactor: Consolidate ConnectionSettings defaults with Config
- Renamed ConnectionConfig to ConnectionSettings - Removed duplicate _DEFAULTS dict - Added _get_default() to read defaults from Config Pydantic field definitions - Changed constructor signature: ConnectionSettings(values=dict, use_global_fallback=bool) - Updated tests to use new constructor pattern This eliminates duplication of setting definitions, validation, and defaults between Config and ConnectionSettings classes. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f03b599 commit 1ada05b

File tree

2 files changed

+93
-113
lines changed

2 files changed

+93
-113
lines changed

src/datajoint/connection.py

Lines changed: 66 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -28,77 +28,31 @@
2828
cache_key = "query_cache" # the key to lookup the query_cache folder in dj.config
2929

3030

31-
class ConnectionConfig:
31+
class ConnectionSettings:
3232
"""
33-
Connection-scoped configuration (read/write).
33+
Connection-scoped settings accessor.
3434
35-
Provides access to settings that can vary per connection. Behavior depends on
36-
how the connection was created:
35+
Provides read/write access to settings that can vary per connection.
36+
Defaults are read from the ``Config`` class - no duplication.
3737
38-
- **New API** (``Connection.from_config()``): Uses explicit values or defaults.
39-
Never accesses global config. Works identically with ``thread_safe`` on or off.
38+
Behavior depends on how the connection was created:
39+
40+
- **New API** (``Connection.from_config()``): Uses explicit values or
41+
defaults from ``Config`` class. Never accesses global config.
4042
- **Legacy API** (``dj.conn()``): Forwards unset values to global ``dj.config``
4143
for backward compatibility.
4244
43-
Parameters
44-
----------
45-
**explicit_values : Any
46-
Explicitly provided configuration values. These take precedence over
47-
global config and defaults.
48-
49-
Attributes
50-
----------
51-
safemode : bool
52-
Require confirmation for destructive operations.
53-
database_prefix : str
54-
Prefix for schema names.
55-
stores : dict
56-
Blob storage configuration.
57-
cache : Path or None
58-
Local cache directory.
59-
query_cache : Path or None
60-
Query cache directory.
61-
reconnect : bool
62-
Auto-reconnect on lost connection.
63-
display_limit : int
64-
Max rows to display.
65-
display_width : int
66-
Column width for display.
67-
show_tuple_count : bool
68-
Show tuple count in repr.
69-
loglevel : str
70-
Logging level.
71-
filepath_checksum_size_limit : int or None
72-
Max file size for checksum.
73-
7445
Examples
7546
--------
76-
Access settings through a connection:
77-
7847
>>> conn = dj.Connection.from_config(host="localhost", user="root", password="pw")
7948
>>> conn.config.safemode
8049
True
8150
>>> conn.config.safemode = False # Disable for this connection only
8251
>>> conn.config.display_limit = 25
8352
"""
8453

85-
_DEFAULTS: dict[str, Any] = {
86-
"safemode": True,
87-
"database_prefix": "",
88-
"stores": {},
89-
"cache": None,
90-
"query_cache": None,
91-
"reconnect": True,
92-
"create_tables": True,
93-
"display_limit": 12,
94-
"display_width": 14,
95-
"show_tuple_count": True,
96-
"loglevel": "INFO",
97-
"filepath_checksum_size_limit": None,
98-
}
99-
100-
# Mapping from ConnectionConfig names to global config paths
101-
_GLOBAL_CONFIG_MAP: dict[str, str] = {
54+
# Map attribute names to global config paths (also defines valid settings)
55+
_CONFIG_PATHS: dict[str, str] = {
10256
"safemode": "safemode",
10357
"database_prefix": "database.database_prefix",
10458
"stores": "stores",
@@ -113,12 +67,9 @@ class ConnectionConfig:
11367
"filepath_checksum_size_limit": "filepath_checksum_size_limit",
11468
}
11569

116-
def __init__(self, **explicit_values: Any) -> None:
117-
object.__setattr__(self, "_values", {}) # Mutable storage for this connection
118-
# If True, forward unset values to global config (legacy API behavior)
119-
# If False, use defaults only (new API behavior)
120-
object.__setattr__(self, "_use_global_fallback", explicit_values.pop("_use_global_fallback", False))
121-
self._values.update(explicit_values)
70+
def __init__(self, values: dict[str, Any] | None = None, use_global_fallback: bool = False) -> None:
71+
object.__setattr__(self, "_values", values.copy() if values else {})
72+
object.__setattr__(self, "_use_global_fallback", use_global_fallback)
12273

12374
def __getattr__(self, name: str) -> Any:
12475
if name.startswith("_"):
@@ -130,26 +81,58 @@ def __getattr__(self, name: str) -> Any:
13081

13182
# Legacy API: forward to global config for backward compatibility
13283
if self._use_global_fallback:
133-
global_path = self._GLOBAL_CONFIG_MAP.get(name)
134-
if global_path:
135-
return config[global_path]
84+
path = self._CONFIG_PATHS.get(name)
85+
if path:
86+
return config[path]
13687

137-
# New API: use defaults only (no global config access)
138-
return self._DEFAULTS.get(name)
88+
# New API: use defaults from Config class (no duplication)
89+
return self._get_default(name)
13990

14091
def __setattr__(self, name: str, value: Any) -> None:
14192
if name.startswith("_"):
14293
return object.__setattr__(self, name, value)
143-
144-
# Store in connection-local values
14594
self._values[name] = value
14695

14796
def __repr__(self) -> str:
148-
items = []
149-
for name in self._DEFAULTS:
150-
value = getattr(self, name)
151-
items.append(f"{name}={value!r}")
152-
return f"ConnectionConfig({', '.join(items)})"
97+
items = [f"{name}={getattr(self, name)!r}" for name in self._CONFIG_PATHS]
98+
return f"ConnectionSettings({', '.join(items)})"
99+
100+
@classmethod
101+
def _get_default(cls, name: str) -> Any:
102+
"""Get default value from Config class field definitions."""
103+
from pydantic_core import PydanticUndefined
104+
105+
from .settings import Config, DatabaseSettings, DisplaySettings
106+
107+
path = cls._CONFIG_PATHS.get(name)
108+
if not path:
109+
raise AttributeError(f"Unknown connection setting: {name}")
110+
111+
parts = path.split(".")
112+
if len(parts) == 1:
113+
# Top-level field like 'safemode', 'stores'
114+
field = Config.model_fields.get(parts[0])
115+
if field is None:
116+
return None
117+
default = field.default
118+
# Handle default_factory (default is PydanticUndefined when factory is used)
119+
if (default is None or default is PydanticUndefined) and field.default_factory is not None:
120+
return field.default_factory()
121+
return default
122+
else:
123+
# Nested field like 'display.limit' or 'database.reconnect'
124+
group_name, field_name = parts
125+
group_field = Config.model_fields.get(group_name)
126+
if group_field is None:
127+
return None
128+
# Get the nested model class
129+
group_cls = {"database": DatabaseSettings, "display": DisplaySettings}.get(group_name)
130+
if group_cls is None:
131+
return None
132+
nested_field = group_cls.model_fields.get(field_name)
133+
if nested_field is None:
134+
return None
135+
return nested_field.default
153136

154137
def get_store_spec(self, store_name: str) -> dict:
155138
"""
@@ -176,7 +159,7 @@ def get_store_spec(self, store_name: str) -> dict:
176159
return stores[store_name]
177160

178161
@contextmanager
179-
def override(self, **kwargs: Any) -> Iterator["ConnectionConfig"]:
162+
def override(self, **kwargs: Any) -> Iterator["ConnectionSettings"]:
180163
"""
181164
Temporarily override configuration values for this connection.
182165
@@ -187,7 +170,7 @@ def override(self, **kwargs: Any) -> Iterator["ConnectionConfig"]:
187170
188171
Yields
189172
------
190-
ConnectionConfig
173+
ConnectionSettings
191174
The config instance with overridden values.
192175
193176
Examples
@@ -204,7 +187,7 @@ def override(self, **kwargs: Any) -> Iterator["ConnectionConfig"]:
204187
for key, value in kwargs.items():
205188
if key in self._values:
206189
backup[key] = deepcopy(self._values[key])
207-
elif key in self._DEFAULTS:
190+
elif key in self._CONFIG_PATHS:
208191
backup[key] = None # Marker for "was not set"
209192

210193
try:
@@ -368,7 +351,7 @@ class Connection:
368351
369352
Attributes
370353
----------
371-
config : ConnectionConfig
354+
config : ConnectionSettings
372355
Connection-scoped configuration settings.
373356
schemas : dict
374357
Registered schema objects.
@@ -385,7 +368,7 @@ def __init__(
385368
use_tls: bool | dict | None = None,
386369
backend: str | None = None,
387370
*,
388-
_config: ConnectionConfig | None = None,
371+
_config: ConnectionSettings | None = None,
389372
) -> None:
390373
if ":" in host:
391374
# the port in the hostname overrides the port argument
@@ -434,7 +417,7 @@ def __init__(
434417

435418
# Connection-scoped configuration
436419
# Legacy API (dj.conn()) uses global fallback for backward compatibility
437-
self.config = _config if _config is not None else ConnectionConfig(_use_global_fallback=True)
420+
self.config = _config if _config is not None else ConnectionSettings(use_global_fallback=True)
438421

439422
@classmethod
440423
def from_config(
@@ -556,7 +539,7 @@ def from_config(
556539
effective_backend = "mysql"
557540
effective_use_tls = None
558541

559-
# Connection-scoped settings (will be passed to ConnectionConfig)
542+
# Connection-scoped settings (will be passed to ConnectionSettings)
560543
config_kwargs: dict[str, Any] = {}
561544

562545
# Override with cfg dict if provided
@@ -651,11 +634,8 @@ def from_config(
651634
"Database password is required. " "Provide password= argument or include 'password' in config dict."
652635
)
653636

654-
# Create ConnectionConfig - new API never falls back to global config
655-
conn_config = ConnectionConfig(
656-
_use_global_fallback=False,
657-
**config_kwargs,
658-
)
637+
# Create ConnectionSettings - new API never falls back to global config
638+
conn_config = ConnectionSettings(values=config_kwargs, use_global_fallback=False)
659639

660640
# Create connection with explicit backend parameter and config
661641
connection = cls(

0 commit comments

Comments
 (0)