diff --git a/docs/design/thread-safe-mode.md b/docs/design/thread-safe-mode.md new file mode 100644 index 000000000..334ac3452 --- /dev/null +++ b/docs/design/thread-safe-mode.md @@ -0,0 +1,585 @@ +# Thread-Safe Mode Specification + +**Status:** Draft +**Version:** 0.1 +**Target Release:** DataJoint 2.2 +**Authors:** Dimitri Yatsenko, Claude + +## Overview + +Thread-safe mode enables DataJoint to operate in multi-tenant environments (web applications, serverless functions, multi-threaded services) where multiple users or requests share the same Python process. When enabled, global mutable state is disabled, and all connection-specific configuration becomes scoped to individual `Connection` objects. + +## Motivation + +Traditional DataJoint usage relies on global state: +- `dj.config` — singleton configuration object +- `dj.conn()` — singleton database connection + +This model works well for single-user scripts and notebooks but creates problems in: +- **Web applications** — concurrent requests from different users/tenants +- **Serverless functions** — shared runtime across invocations +- **Multi-threaded workers** — parallel processing with different credentials +- **Agentic workflows** — AI agents managing multiple database contexts + +## Design Principles + +1. **Deployment-time decision** — Thread-safe mode is set via environment or config file, not programmatically +2. **Explicit over implicit** — All connection parameters must be explicitly provided +3. **No hidden global state** — Connection behavior is fully determined by its configuration +4. **Backward compatible** — Existing code works unchanged when `thread_safe=False` + +--- + +## Configuration Categories + +### Global Config (`dj.config`) + +In thread-safe mode, `dj.config` access is blocked except for `thread_safe` itself: + +| Setting | `thread_safe=False` | `thread_safe=True` | +|---------|---------------------|-------------------| +| `thread_safe` | Read-only (set via env var or config file only) | Read-only | +| All other settings | Read/write | Raises `ThreadSafetyError` | + +### Connection-Scoped Settings (`conn.config`) + +All settings become connection-scoped and are accessed via `conn.config` (read/write): + +| Setting | Type | Default | Description | +|---------|------|---------|-------------| +| `safemode` | bool | True | Require confirmation for destructive ops | +| `database_prefix` | str | "" | Schema name prefix | +| `stores` | dict | {} | Blob storage configuration | +| `cache` | Path | None | Local cache directory | +| `query_cache` | Path | None | Query cache directory | +| `reconnect` | bool | True | Auto-reconnect on lost connection | +| `display_limit` | int | 12 | Max rows to display | +| `display_width` | int | 14 | Column width for display | +| `show_tuple_count` | bool | True | Show tuple count in repr | +| `loglevel` | str | "INFO" | Logging level | +| `filepath_checksum_size_limit` | int | None | Max file size for checksum | + +Connection parameters (set at creation, read-only after): + +| Setting | Type | Default | Description | +|---------|------|---------|-------------| +| `host` | str | "localhost" | Database hostname | +| `port` | int | 3306/5432 | Database port | +| `user` | str | *required* | Database username | +| `password` | str | *required* | Database password | +| `backend` | str | "mysql" | Database backend | +| `use_tls` | bool/dict | None | TLS configuration | + +--- + +## API Specification + +### Enabling Thread-Safe Mode + +Thread-safe mode is read-only after initialization and can only be set via environment variable or config file: + +```bash +# Method 1: Environment variable +DJ_THREAD_SAFE=true python app.py +``` + +```json +// Method 2: Config file (datajoint.json) +{ "thread_safe": true } +``` + +Programmatic setting is not allowed: +```python +dj.config.thread_safe = True # Raises ThreadSafetyError +``` + +### Global Config Access in Thread-Safe Mode + +```python +# With DJ_THREAD_SAFE=true set in environment + +# Only thread_safe is accessible +dj.config.thread_safe # OK (returns True) + +# Everything else raises ThreadSafetyError +dj.config.database.host # Raises ThreadSafetyError +dj.config.display.width # Raises ThreadSafetyError +dj.config.safemode # Raises ThreadSafetyError +``` + +### Creating Connections + +```python +# Thread-safe connection creation +conn = dj.Connection.from_config( + host="db.example.com", + user="tenant_user", + password="tenant_password", + port=3306, + backend="mysql", + safemode=False, + stores={ + "raw": { + "protocol": "s3", + "endpoint": "s3.amazonaws.com", + "bucket": "tenant-data", + "access_key": "...", + "secret_key": "...", + } + }, +) + +# Or from a configuration dict +tenant_config = { + "host": "db.example.com", + "user": request.tenant.db_user, + "password": request.tenant.db_password, + "stores": request.tenant.stores, +} +conn = dj.Connection.from_config(tenant_config) +``` + +### Accessing Connection-Scoped Settings + +Settings are accessed through `connection.config`. The connection is available on: +- `Connection` objects directly +- `Schema.connection` +- `Table.connection` (any table class: Manual, Lookup, Imported, Computed) + +```python +conn = dj.Connection.from_config(...) +schema = dj.Schema("my_pipeline", connection=conn) + +@schema +class Subject(dj.Manual): + definition = "subject_id : int" + +# All of these access the same connection-scoped config: +conn.config.safemode # Via connection directly +schema.connection.config.safemode # Via schema +Subject.connection.config.safemode # Via table class + +# Read settings +conn.config.safemode # True (default) +conn.config.database_prefix # "" +conn.config.stores # {} +conn.config.display_limit # 12 + +# Modify settings for this connection +conn.config.safemode = False +conn.config.display_limit = 25 +conn.config.stores = {"raw": {"protocol": "file", "location": "/data"}} +``` + +### Using Schemas with Connections + +```python +conn = dj.Connection.from_config(tenant_config) + +# Explicit connection binding +schema = dj.Schema("my_pipeline", connection=conn) + +@schema +class Subject(dj.Manual): + definition = """ + subject_id : int + """ + +# All operations use connection-scoped settings +Subject.insert([{"subject_id": 1}]) # Uses conn.config.safemode + +# Access config through schema or table +schema.connection.config.display_limit = 50 +Subject.connection.config.safemode # Same as conn.config.safemode +``` + +--- + +## API Compatibility Matrix + +| API | `thread_safe=False` | `thread_safe=True` | +|-----|---------------------|-------------------| +| `dj.conn()` | Works | Raises `ThreadSafetyError` | +| `dj.config.thread_safe` | Read/write | Read-only | +| `dj.config.*` (all else) | Read/write | Raises `ThreadSafetyError` | +| `Schema()` without connection | Works | Raises `ThreadSafetyError` | +| **`Connection.from_config()`** | **Works** | **Works** | +| **`conn.config.*`** | **Read/write** (forwards to global) | **Read/write** (connection-scoped) | + +The new API (`Connection.from_config()` and `conn.config`) is the **universal API** that works in both modes. + +## Backward Compatibility + +### Legacy API (thread_safe=False only) + +Existing code continues to work unchanged when `thread_safe=False`: + +```python +import datajoint as dj + +# Global config access - works +dj.config["database.host"] = "localhost" +dj.config["database.user"] = "root" +dj.config["database.password"] = "secret" + +# Singleton connection - works +conn = dj.conn() + +# Schema without explicit connection - works +schema = dj.Schema("my_schema") # Uses dj.conn() +``` + +### New API (works in both modes) + +The new API works identically whether `thread_safe` is on or off: + +```python +import datajoint as dj + +# Works with thread_safe=False OR thread_safe=True +conn = dj.Connection.from_config( + host="localhost", + user="root", + password="secret", + safemode=False, + stores={"raw": {...}}, +) + +# Access settings through connection - works in both modes +conn.config.safemode # False +conn.config.stores # {"raw": {...}} +conn.config.database_prefix # "" + +# Schema with explicit connection - works in both modes +schema = dj.Schema("my_schema", connection=conn) +``` + +### Connection.config Behavior + +The behavior of `conn.config` depends on **which API created the connection**, not on the `thread_safe` setting: + +**New API** (`Connection.from_config()`) — Uses explicit values or defaults. Never accesses global config. Works identically with `thread_safe` on or off: + +```python +dj.config.safemode = False # Set in global config +dj.config.database_prefix = "dev_" + +conn = dj.Connection.from_config( + host="localhost", + user="root", + password="secret", + # safemode not specified - uses default, NOT global config +) + +conn.config.safemode # True (default, not global) +conn.config.database_prefix # "" (default, not global) +``` + +**Legacy API** (`dj.conn()`) — Forwards unset values to global config for backward compatibility: + +```python +dj.config.safemode = False +dj.config.database_prefix = "dev_" + +conn = dj.conn() # Legacy API + +conn.config.safemode # False (from dj.config) +conn.config.database_prefix # "dev_" (from dj.config) +``` + +This design ensures that code using `Connection.from_config()` is portable and behaves identically whether `thread_safe` is enabled or not. + +## Migration Path + +Migration is immediate — adopt the new API and your code works in both modes: + +```python +# Before (legacy API - only works with thread_safe=False) +dj.config["database.host"] = "localhost" +dj.config["database.user"] = "root" +dj.config["database.password"] = "secret" +conn = dj.conn() +schema = dj.Schema("pipeline") + +# After (new API - works with thread_safe=False AND thread_safe=True) +conn = dj.Connection.from_config( + host="localhost", + user="root", + password="secret", +) +schema = dj.Schema("pipeline", connection=conn) +``` + +Once migrated to the new API, enabling `thread_safe=True` requires no code changes. + +--- + +## Implementation Details + +### Config Class Changes + +```python +class Config(BaseSettings): + def __getattribute__(self, name): + # Allow private attributes + if name.startswith("_"): + return object.__getattribute__(self, name) + + # Allow Pydantic internals + if name.startswith("model_"): + return object.__getattribute__(self, name) + + # Always allow checking thread_safe itself + if name == "thread_safe": + return object.__getattribute__(self, name) + + # Block everything else in thread-safe mode + if object.__getattribute__(self, "thread_safe"): + raise ThreadSafetyError( + f"Setting '{name}' is connection-scoped in thread-safe mode. " + "Access it via connection.config instead." + ) + + return object.__getattribute__(self, name) + + def __setattr__(self, name, value): + # Allow private attributes + if name.startswith("_"): + return object.__setattr__(self, name, value) + + # thread_safe is read-only after initialization + if name == "thread_safe": + try: + object.__getattribute__(self, "thread_safe") + # If we get here, thread_safe already exists - block the set + raise ThreadSafetyError( + "thread_safe cannot be set programmatically. " + "Set DJ_THREAD_SAFE=true in environment or datajoint.json." + ) + except AttributeError: + pass # First time setting during __init__ - allow it + return object.__setattr__(self, name, value) + + # Block everything else in thread-safe mode + if object.__getattribute__(self, "thread_safe"): + raise ThreadSafetyError( + "Global config is inaccessible in thread-safe mode. " + "Use Connection.from_config() with explicit configuration." + ) + + return object.__setattr__(self, name, value) +``` + +### Connection Class Changes + +```python +class Connection: + def __init__(self, host, user, password, port=None, + use_tls=None, backend=None, *, _config=None): + # ... existing connection setup ... + + # Connection-scoped configuration + # Legacy API (dj.conn()) uses global fallback for backward compatibility + self.config = _config if _config is not None else ConnectionConfig(_use_global_fallback=True) + + @classmethod + def from_config(cls, cfg=None, *, host=None, user=None, password=None, + port=None, backend=None, safemode=None, stores=None, + database_prefix=None, cache=None, query_cache=None, + reconnect=None, use_tls=None) -> "Connection": + """ + Create connection with explicit configuration. + + Works in both thread_safe=False and thread_safe=True modes. + """ + # ... merge cfg dict with kwargs ... + # ... validate required fields (host, user, password) ... + + # Build ConnectionConfig - new API never falls back to global config + conn_config = ConnectionConfig( + _use_global_fallback=False, + **({"safemode": safemode} if safemode is not None else {}), + **({"stores": stores} if stores is not None else {}), + **({"database_prefix": database_prefix} if database_prefix is not None else {}), + **({"cache": cache} if cache is not None else {}), + **({"query_cache": query_cache} if query_cache is not None else {}), + **({"reconnect": reconnect} if reconnect is not None else {}), + ) + + return cls( + host=effective_host, + user=effective_user, + password=effective_password, + port=effective_port, + use_tls=effective_use_tls, + backend=effective_backend, + _config=conn_config, + ) +``` + +### ConnectionConfig Class + +```python +class ConnectionConfig: + """ + Connection-scoped configuration (read/write). + + Behavior depends on how the connection was created: + - New API (from_config): Uses explicit values or defaults. Never accesses global config. + - Legacy API (dj.conn): Forwards unset values to global dj.config. + """ + + _DEFAULTS = { + "safemode": True, + "database_prefix": "", + "stores": {}, + "cache": None, + "query_cache": None, + "reconnect": True, + "display_limit": 12, + "display_width": 14, + "show_tuple_count": True, + "loglevel": "INFO", + "filepath_checksum_size_limit": None, + } + + def __init__(self, **explicit_values): + self._values = {} # Mutable storage for this connection + # If True, forward unset values to global config (legacy API behavior) + # If False, use defaults only (new API behavior) + self._use_global_fallback = explicit_values.pop("_use_global_fallback", False) + self._values.update(explicit_values) + + def __getattr__(self, name): + if name.startswith("_"): + return object.__getattribute__(self, name) + + # If set on this connection, return that value + if name in self._values: + return self._values[name] + + # Legacy API: forward to global config for backward compatibility + if self._use_global_fallback: + from .settings import config + return getattr(config, name, self._DEFAULTS.get(name)) + + # New API: use defaults only (no global config access) + return self._DEFAULTS.get(name) + + def __setattr__(self, name, value): + if name.startswith("_"): + return object.__setattr__(self, name, value) + + # Store in connection-local values + self._values[name] = value + + def get_store_spec(self, store_name: str) -> dict: + """Get store specification by name.""" + stores = self.stores + if store_name not in stores: + raise DataJointError(f"Store '{store_name}' is not configured.") + return stores[store_name] +``` + +--- + +## Error Handling + +### ThreadSafetyError + +```python +class ThreadSafetyError(DataJointError): + """ + Raised when global state is accessed in thread-safe mode. + + This error indicates that code is attempting to use global + configuration or connections that are not thread-safe. + """ +``` + +### Error Messages + +```python +# Reading blocked config +dj.config.safemode +ThreadSafetyError: Setting 'safemode' is connection-scoped in thread-safe mode. +Access it via connection.config instead. + +# Writing blocked config +dj.config.display_limit = 20 +ThreadSafetyError: Setting 'display_limit' is connection-scoped in thread-safe mode. +Modify it via connection.config instead. + +# Using dj.conn() +dj.conn() +ThreadSafetyError: dj.conn() is disabled in thread-safe mode. +Use Connection.from_config() with explicit configuration. + +# Setting thread-safe mode programmatically +dj.config.thread_safe = True +ThreadSafetyError: thread_safe cannot be set programmatically. +Set DJ_THREAD_SAFE=true in environment or datajoint.json. + +# Schema without connection +dj.Schema("my_schema") +ThreadSafetyError: Schema requires explicit connection in thread-safe mode. +Use Schema('name', connection=conn). +``` + +--- + +## Testing Strategy + +### Unit Tests + +1. **Global config in thread-safe mode** + - Verify only `thread_safe` is accessible (read-only) + - Verify all other settings raise ThreadSafetyError (read and write) + - Verify thread_safe cannot be set programmatically (only via env var or config file) + +2. **Connection.from_config()** + - Verify all parameters are accepted + - Verify defaults are applied correctly + - Verify cfg dict merging with kwargs + - Verify works in both thread_safe modes + +3. **ConnectionConfig** + - Verify read/write access to all settings + - Verify forwarding to global config when thread_safe=False + - Verify defaults used when thread_safe=True + - Verify store spec resolution + +### Integration Tests + +1. **Multi-tenant simulation** + - Create multiple connections with different configs + - Verify isolation between connections + - Verify correct store resolution per connection + +2. **Schema binding** + - Verify schemas use connection's config + - Verify safemode behavior per connection + +--- + +## Future Considerations + +### Potential Extensions + +1. **Connection pooling** — Pool of connections per tenant configuration +2. **Async support** — Async connection management for async frameworks +3. **Context managers** — Temporary connection context for specific operations + +### Out of Scope + +1. **Thread-local storage** — Rejected in favor of explicit connection passing +2. **Automatic credential rotation** — Application responsibility +3. **Multi-database transactions** — Not supported by underlying backends + +--- + +## References + +- [DataJoint Python Documentation](https://docs.datajoint.com) +- [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/) +- [WSGI/ASGI Thread Safety](https://peps.python.org/pep-3333/) diff --git a/src/datajoint/__init__.py b/src/datajoint/__init__.py index 7f809487d..ada1ef685 100644 --- a/src/datajoint/__init__.py +++ b/src/datajoint/__init__.py @@ -52,6 +52,7 @@ "errors", "migrate", "DataJointError", + "ThreadSafetyError", "logger", "cli", "ValidationResult", @@ -73,7 +74,7 @@ ) from .blob import MatCell, MatStruct from .connection import Connection, conn -from .errors import DataJointError +from .errors import DataJointError, ThreadSafetyError from .expression import AndList, Not, Top, U from .logging import logger from .objectref import ObjectRef diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py index 88339335f..bdd96b867 100644 --- a/src/datajoint/adapters/mysql.py +++ b/src/datajoint/adapters/mysql.py @@ -75,23 +75,19 @@ def connect( Password for authentication. **kwargs : Any Additional MySQL-specific parameters: - - init_command: SQL initialization command - ssl: TLS/SSL configuration dict (deprecated, use use_tls) - use_tls: bool or dict - DataJoint's SSL parameter (preferred) - - charset: Character set (default from kwargs) Returns ------- pymysql.Connection MySQL connection object. """ - init_command = kwargs.get("init_command") # Handle both ssl (old) and use_tls (new) parameter names ssl_config = kwargs.get("use_tls", kwargs.get("ssl")) # Convert boolean True to dict for PyMySQL (PyMySQL expects dict or SSLContext) if ssl_config is True: ssl_config = {} # Enable SSL with default settings - charset = kwargs.get("charset", "") # Prepare connection parameters conn_params = { @@ -99,10 +95,8 @@ def connect( "port": port, "user": user, "passwd": password, - "init_command": init_command, "sql_mode": "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", - "charset": charset, "autocommit": True, # DataJoint manages transactions explicitly } diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py index 21b48e638..23beea8c2 100644 --- a/src/datajoint/connection.py +++ b/src/datajoint/connection.py @@ -11,7 +11,8 @@ import re import warnings from contextlib import contextmanager -from typing import Callable +from pathlib import Path +from typing import Any, Callable, Iterator from . import errors from .adapters import get_adapter @@ -27,6 +28,183 @@ cache_key = "query_cache" # the key to lookup the query_cache folder in dj.config +class ConnectionSettings: + """ + Connection-scoped settings accessor. + + Provides read/write access to settings that can vary per connection. + Defaults are read from the ``Config`` class - no duplication. + + Behavior depends on how the connection was created: + + - **New API** (``Connection.from_config()``): Uses explicit values or + defaults from ``Config`` class. Never accesses global config. + - **Legacy API** (``dj.conn()``): Forwards unset values to global ``dj.config`` + for backward compatibility. + + Examples + -------- + >>> conn = dj.Connection.from_config(host="localhost", user="root", password="pw") + >>> conn.config.safemode + True + >>> conn.config.safemode = False # Disable for this connection only + >>> conn.config.display_limit = 25 + """ + + # Map attribute names to global config paths (also defines valid settings) + _CONFIG_PATHS: dict[str, str] = { + "safemode": "safemode", + "database_prefix": "database.database_prefix", + "stores": "stores", + "cache": "cache", + "query_cache": "query_cache", + "reconnect": "database.reconnect", + "create_tables": "database.create_tables", + "display_limit": "display.limit", + "display_width": "display.width", + "show_tuple_count": "display.show_tuple_count", + "loglevel": "loglevel", + "filepath_checksum_size_limit": "filepath_checksum_size_limit", + } + + def __init__(self, values: dict[str, Any] | None = None, use_global_fallback: bool = False) -> None: + object.__setattr__(self, "_values", values.copy() if values else {}) + object.__setattr__(self, "_use_global_fallback", use_global_fallback) + + def __getattr__(self, name: str) -> Any: + if name.startswith("_"): + return object.__getattribute__(self, name) + + # If set on this connection, return that value + if name in self._values: + return self._values[name] + + # Legacy API: forward to global config for backward compatibility + if self._use_global_fallback: + path = self._CONFIG_PATHS.get(name) + if path: + return config[path] + + # New API: use defaults from Config class (no duplication) + return self._get_default(name) + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith("_"): + return object.__setattr__(self, name, value) + self._values[name] = value + + def __repr__(self) -> str: + items = [f"{name}={getattr(self, name)!r}" for name in self._CONFIG_PATHS] + return f"ConnectionSettings({', '.join(items)})" + + @classmethod + def _get_default(cls, name: str) -> Any: + """Get default value from Config class field definitions.""" + from pydantic_core import PydanticUndefined + + from .settings import Config, DatabaseSettings, DisplaySettings + + path = cls._CONFIG_PATHS.get(name) + if not path: + raise AttributeError(f"Unknown connection setting: {name}") + + parts = path.split(".") + if len(parts) == 1: + # Top-level field like 'safemode', 'stores' + field = Config.model_fields.get(parts[0]) + if field is None: + return None + default = field.default + # Handle default_factory (default is PydanticUndefined when factory is used) + if (default is None or default is PydanticUndefined) and field.default_factory is not None: + return field.default_factory() + return default + else: + # Nested field like 'display.limit' or 'database.reconnect' + group_name, field_name = parts + group_field = Config.model_fields.get(group_name) + if group_field is None: + return None + # Get the nested model class + group_cls = {"database": DatabaseSettings, "display": DisplaySettings}.get(group_name) + if group_cls is None: + return None + nested_field = group_cls.model_fields.get(field_name) + if nested_field is None: + return None + return nested_field.default + + def get_store_spec(self, store_name: str) -> dict: + """ + Get store specification by name. + + Parameters + ---------- + store_name : str + Name of the store to retrieve. + + Returns + ------- + dict + Store specification dictionary. + + Raises + ------ + DataJointError + If the store is not configured. + """ + stores = self.stores + if store_name not in stores: + raise errors.DataJointError(f"Store '{store_name}' is not configured.") + return stores[store_name] + + @contextmanager + def override(self, **kwargs: Any) -> Iterator["ConnectionSettings"]: + """ + Temporarily override configuration values for this connection. + + Parameters + ---------- + **kwargs : Any + Settings to override. + + Yields + ------ + ConnectionSettings + The config instance with overridden values. + + Examples + -------- + >>> with conn.config.override(safemode=False, display_limit=50): + ... # conn.config.safemode is False here + ... pass + >>> # conn.config.safemode is restored + """ + from copy import deepcopy + + # Save original values + backup = {} + for key, value in kwargs.items(): + if key in self._values: + backup[key] = deepcopy(self._values[key]) + elif key in self._CONFIG_PATHS: + backup[key] = None # Marker for "was not set" + + try: + # Apply overrides + for key, value in kwargs.items(): + self._values[key] = value + yield self + finally: + # Restore original values + for key, original in backup.items(): + if original is None: + # Was not set before, remove it + self._values.pop(key, None) + else: + self._values[key] = original + + def translate_query_error(client_error: Exception, query: str, adapter) -> Exception: """ Translate client error to the corresponding DataJoint exception. @@ -55,16 +233,22 @@ def conn( user: str | None = None, password: str | None = None, *, - init_fun: Callable | None = None, reset: bool = False, use_tls: bool | dict | None = None, -) -> Connection: +) -> "Connection": """ Return a persistent connection object shared by multiple modules. If the connection is not yet established or reset=True, a new connection is set up. If connection information is not provided, it is taken from config. + .. warning:: + + This function uses global state and is not suitable for multi-tenant + applications. When ``config.thread_safe`` is True, this function raises + :exc:`~datajoint.errors.ThreadSafetyError`. Use + :meth:`Connection.from_config` instead for thread-safe connection management. + Parameters ---------- host : str, optional @@ -73,8 +257,6 @@ def conn( Database username. Required if not set in config. password : str, optional Database password. Required if not set in config. - init_fun : callable, optional - Initialization function called after connection. reset : bool, optional If True, reset existing connection. Default False. use_tls : bool or dict, optional @@ -90,7 +272,19 @@ def conn( ------ DataJointError If user or password is not provided and not set in config. + ThreadSafetyError + If ``config.thread_safe`` is True. + + See Also + -------- + Connection.from_config : Thread-safe connection creation. """ + # Check thread-safe mode + if config.thread_safe: + raise errors.ThreadSafetyError( + "dj.conn() is disabled in thread-safe mode. " "Use Connection.from_config() with explicit configuration." + ) + if not hasattr(conn, "connection") or reset: host = host if host is not None else config["database.host"] user = user if user is not None else config["database.user"] @@ -103,9 +297,8 @@ def conn( raise errors.DataJointError( "Database password not configured. Set datajoint.config['database.password'] or pass password= argument." ) - init_fun = init_fun if init_fun is not None else config["connection.init_function"] use_tls = use_tls if use_tls is not None else config["database.use_tls"] - conn.connection = Connection(host, user, password, None, init_fun, use_tls) + conn.connection = Connection(host, user, password, None, use_tls) return conn.connection @@ -150,13 +343,16 @@ class Connection: Database password. port : int, optional Port number. Overridden if specified in host. - init_fun : str, optional - SQL initialization command. use_tls : bool or dict, optional TLS encryption option. + backend : str, optional + Database backend ('mysql' or 'postgresql'). If not provided, + uses the value from global config. Attributes ---------- + config : ConnectionSettings + Connection-scoped configuration settings. schemas : dict Registered schema objects. dependencies : Dependencies @@ -169,15 +365,22 @@ def __init__( user: str, password: str, port: int | None = None, - init_fun: str | None = None, use_tls: bool | dict | None = None, + backend: str | None = None, + *, + _config: ConnectionSettings | None = None, ) -> None: if ":" in host: # the port in the hostname overrides the port argument host, port = host.split(":") port = int(port) elif port is None: - port = config["database.port"] + # In thread-safe mode, config is inaccessible, so we must use defaults + if config.thread_safe: + # Default based on backend + port = 5432 if backend == "postgresql" else 3306 + else: + port = config.database.port self.conn_info = dict(host=host, port=port, user=user, passwd=password) if use_tls is not False: # use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config) @@ -190,13 +393,16 @@ def __init__( # use_tls=True: enable SSL with default settings self.conn_info["ssl"] = True self.conn_info["ssl_input"] = use_tls - self.init_fun = init_fun self._conn = None self._query_cache = None self._is_closed = True # Mark as closed until connect() succeeds - # Select adapter based on configured backend - backend = config["database.backend"] + # Select adapter based on backend + if backend is None: + if config.thread_safe: + backend = "mysql" # Default in thread-safe mode + else: + backend = config.database.backend self.adapter = get_adapter(backend) self.connect() @@ -209,6 +415,241 @@ def __init__( self.schemas = dict() self.dependencies = Dependencies(self) + # Connection-scoped configuration + # Legacy API (dj.conn()) uses global fallback for backward compatibility + self.config = _config if _config is not None else ConnectionSettings(use_global_fallback=True) + + @classmethod + def from_config( + cls, + cfg: dict | None = None, + *, + host: str | None = None, + user: str | None = None, + password: str | None = None, + port: int | None = None, + backend: str | None = None, + use_tls: bool | dict | None = None, + # Connection-scoped settings + safemode: bool | None = None, + database_prefix: str | None = None, + stores: dict | None = None, + cache: Path | str | None = None, + query_cache: Path | str | None = None, + reconnect: bool | None = None, + display_limit: int | None = None, + display_width: int | None = None, + show_tuple_count: bool | None = None, + loglevel: str | None = None, + filepath_checksum_size_limit: int | None = None, + ) -> "Connection": + """ + Create a connection from explicit configuration. + + This is the recommended method for creating connections. It works in both + ``thread_safe=False`` and ``thread_safe=True`` modes. Unlike :func:`conn`, + this method does not require global state. + + Configuration can be provided via a dict or keyword arguments. + Keyword arguments take precedence over dict values. + + Parameters + ---------- + cfg : dict, optional + Configuration dict with connection and settings keys. + host : str, optional + Database hostname. Overrides cfg['host']. Default: 'localhost'. + user : str, optional + Database username. Overrides cfg['user']. Required. + password : str, optional + Database password. Overrides cfg['password']. Required. + port : int, optional + Database port. Overrides cfg['port']. Default: 3306 (MySQL) or 5432 (PostgreSQL). + backend : str, optional + Database backend ('mysql' or 'postgresql'). Overrides cfg['backend']. + Default: 'mysql'. + use_tls : bool or dict, optional + TLS encryption option. Overrides cfg['use_tls']. + safemode : bool, optional + Require confirmation for destructive operations. Default: True. + database_prefix : str, optional + Prefix for schema names. Default: ''. + stores : dict, optional + Blob storage configuration. Default: {}. + cache : Path or str, optional + Local cache directory. Default: None. + query_cache : Path or str, optional + Query cache directory. Default: None. + reconnect : bool, optional + Auto-reconnect on lost connection. Default: True. + display_limit : int, optional + Max rows to display. Default: 12. + display_width : int, optional + Column width for display. Default: 14. + show_tuple_count : bool, optional + Show tuple count in repr. Default: True. + loglevel : str, optional + Logging level. Default: 'INFO'. + filepath_checksum_size_limit : int, optional + Max file size for checksum. Default: None. + + Returns + ------- + Connection + A new database connection. + + Raises + ------ + DataJointError + If required parameters (user, password) are not provided. + + Examples + -------- + Create connection with explicit parameters: + + >>> conn = Connection.from_config( + ... host='localhost', + ... user='myuser', + ... password='mypassword' + ... ) + + Create connection from a config dict (e.g., from request context): + + >>> tenant_config = { + ... 'host': 'db.example.com', + ... 'user': request.user.db_user, + ... 'password': request.user.db_password, + ... } + >>> conn = Connection.from_config(tenant_config) + + Use with Schema for thread-safe pipeline access: + + >>> conn = Connection.from_config(tenant_config) + >>> schema = dj.Schema('my_pipeline', connection=conn) + + See Also + -------- + conn : Singleton connection (not available in thread-safe mode). + """ + # Start with defaults (no global config access) + effective_host = "localhost" + effective_user = None + effective_password = None + effective_port = None # Will be set based on backend + effective_backend = "mysql" + effective_use_tls = None + + # Connection-scoped settings (will be passed to ConnectionSettings) + config_kwargs: dict[str, Any] = {} + + # Override with cfg dict if provided + if cfg is not None: + # Connection parameters + if "host" in cfg: + effective_host = cfg["host"] + if "user" in cfg: + effective_user = cfg["user"] + if "password" in cfg: + effective_password = cfg["password"] + if "port" in cfg: + effective_port = cfg["port"] + if "backend" in cfg: + effective_backend = cfg["backend"] + if "use_tls" in cfg: + effective_use_tls = cfg["use_tls"] + + # Connection-scoped settings from cfg dict + if "safemode" in cfg: + config_kwargs["safemode"] = cfg["safemode"] + if "database_prefix" in cfg: + config_kwargs["database_prefix"] = cfg["database_prefix"] + if "stores" in cfg: + config_kwargs["stores"] = cfg["stores"] + if "cache" in cfg: + config_kwargs["cache"] = cfg["cache"] + if "query_cache" in cfg: + config_kwargs["query_cache"] = cfg["query_cache"] + if "reconnect" in cfg: + config_kwargs["reconnect"] = cfg["reconnect"] + if "display_limit" in cfg: + config_kwargs["display_limit"] = cfg["display_limit"] + if "display_width" in cfg: + config_kwargs["display_width"] = cfg["display_width"] + if "show_tuple_count" in cfg: + config_kwargs["show_tuple_count"] = cfg["show_tuple_count"] + if "loglevel" in cfg: + config_kwargs["loglevel"] = cfg["loglevel"] + if "filepath_checksum_size_limit" in cfg: + config_kwargs["filepath_checksum_size_limit"] = cfg["filepath_checksum_size_limit"] + + # Override with explicit keyword arguments (connection params) + if host is not None: + effective_host = host + if user is not None: + effective_user = user + if password is not None: + effective_password = password + if port is not None: + effective_port = port + if backend is not None: + effective_backend = backend + if use_tls is not None: + effective_use_tls = use_tls + + # Override with explicit keyword arguments (connection-scoped settings) + if safemode is not None: + config_kwargs["safemode"] = safemode + if database_prefix is not None: + config_kwargs["database_prefix"] = database_prefix + if stores is not None: + config_kwargs["stores"] = stores + if cache is not None: + config_kwargs["cache"] = cache + if query_cache is not None: + config_kwargs["query_cache"] = query_cache + if reconnect is not None: + config_kwargs["reconnect"] = reconnect + if display_limit is not None: + config_kwargs["display_limit"] = display_limit + if display_width is not None: + config_kwargs["display_width"] = display_width + if show_tuple_count is not None: + config_kwargs["show_tuple_count"] = show_tuple_count + if loglevel is not None: + config_kwargs["loglevel"] = loglevel + if filepath_checksum_size_limit is not None: + config_kwargs["filepath_checksum_size_limit"] = filepath_checksum_size_limit + + # Set default port based on backend if not specified + if effective_port is None: + effective_port = 5432 if effective_backend == "postgresql" else 3306 + + # Validate required fields + if effective_user is None: + raise errors.DataJointError( + "Database user is required. " "Provide user= argument or include 'user' in config dict." + ) + if effective_password is None: + raise errors.DataJointError( + "Database password is required. " "Provide password= argument or include 'password' in config dict." + ) + + # Create ConnectionSettings - new API never falls back to global config + conn_config = ConnectionSettings(values=config_kwargs, use_global_fallback=False) + + # Create connection with explicit backend parameter and config + connection = cls( + host=effective_host, + user=effective_user, + password=effective_password, + port=effective_port, + use_tls=effective_use_tls, + backend=effective_backend, + _config=conn_config, + ) + + return connection + def __eq__(self, other): return self.conn_info == other.conn_info @@ -227,8 +668,6 @@ def connect(self) -> None: port=self.conn_info["port"], user=self.conn_info["user"], password=self.conn_info["passwd"], - init_command=self.init_fun, - charset=config["connection.charset"], use_tls=self.conn_info.get("ssl"), ) except Exception as ssl_error: @@ -244,8 +683,6 @@ def connect(self) -> None: port=self.conn_info["port"], user=self.conn_info["user"], password=self.conn_info["passwd"], - init_command=self.init_fun, - charset=config["connection.charset"], use_tls=False, # Explicitly disable SSL for fallback ) else: @@ -271,6 +708,9 @@ def set_query_cache(self, query_cache: str | None = None) -> None: def purge_query_cache(self) -> None: """Delete all cached query results.""" + if config.thread_safe: + # Query caching requires global config; not supported in thread-safe mode + return if isinstance(config.get(cache_key), str) and pathlib.Path(config[cache_key]).is_dir(): for path in pathlib.Path(config[cache_key]).iterdir(): if not path.is_dir(): @@ -426,7 +866,7 @@ def query( return EmulatedCursor(unpack(buffer)) if reconnect is None: - reconnect = config["database.reconnect"] + reconnect = self.config.reconnect logger.debug("Executing SQL:" + query[:query_log_max_length]) cursor = self.adapter.get_cursor(self._conn, as_dict=as_dict) try: diff --git a/src/datajoint/errors.py b/src/datajoint/errors.py index 7e10f021d..35eab65ec 100644 --- a/src/datajoint/errors.py +++ b/src/datajoint/errors.py @@ -72,3 +72,13 @@ class MissingExternalFile(DataJointError): class BucketInaccessible(DataJointError): """S3 bucket is inaccessible.""" + + +class ThreadSafetyError(DataJointError): + """ + Raised when global state is accessed in thread-safe mode. + + When ``config.thread_safe`` is True, accessing global configuration + or the singleton connection via ``dj.conn()`` raises this error. + Use ``Connection.from_config()`` instead for thread-safe connection management. + """ diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index 2955fd67d..e94edcb5b 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any from .connection import conn -from .errors import AccessError, DataJointError +from .errors import AccessError, DataJointError, ThreadSafetyError if TYPE_CHECKING: from .connection import Connection @@ -68,7 +68,7 @@ class Schema: context : dict, optional Namespace for foreign key lookup. None uses caller's context. connection : Connection, optional - Database connection. Defaults to ``dj.conn()``. + Database connection. Defaults to ``dj.conn()``. Required in thread-safe mode. create_schema : bool, optional If False, raise error if schema doesn't exist. Default True. create_tables : bool, optional @@ -85,6 +85,11 @@ class Schema: ... definition = ''' ... session_id : int ... ''' + + In thread-safe mode, connection must be explicit: + + >>> conn = dj.Connection.from_config(host='localhost', user='root', password='pw') + >>> schema = dj.Schema('my_schema', connection=conn) """ def __init__( @@ -120,7 +125,7 @@ def __init__( self.database = None self.context = context self.create_schema = create_schema - self.create_tables = create_tables if create_tables is not None else config.database.create_tables + self._create_tables = create_tables # Store explicit value (may be None) self.add_objects = add_objects self.declare_list = [] if schema_name: @@ -130,6 +135,20 @@ def is_activated(self) -> bool: """Check if the schema has been activated.""" return self.database is not None + @property + def create_tables(self) -> bool: + """Whether to create tables automatically when accessed.""" + if self._create_tables is not None: + return self._create_tables + if self.connection is None: + raise DataJointError("Cannot access create_tables before schema has a connection.") + return self.connection.config.create_tables + + @create_tables.setter + def create_tables(self, value: bool) -> None: + """Set explicit create_tables value.""" + self._create_tables = value + def activate( self, schema_name: str | None = None, @@ -174,6 +193,12 @@ def activate( if connection is not None: self.connection = connection if self.connection is None: + if config.thread_safe: + raise ThreadSafetyError( + "Schema requires explicit connection in thread-safe mode. " + "Use Schema('name', connection=conn) where conn is created via " + "Connection.from_config()." + ) self.connection = conn() self.database = schema_name if create_schema is not None: diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index e373ca38f..6b384dcca 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -50,7 +50,7 @@ from pydantic import Field, SecretStr, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from .errors import DataJointError +from .errors import DataJointError, ThreadSafetyError CONFIG_FILENAME = "datajoint.json" SECRETS_DIRNAME = ".secrets" @@ -219,15 +219,6 @@ def set_default_port_from_backend(self) -> "DatabaseSettings": return self -class ConnectionSettings(BaseSettings): - """Connection behavior settings.""" - - model_config = SettingsConfigDict(extra="forbid", validate_assignment=True) - - init_function: str | None = None - charset: str = "" # pymysql uses '' as default - - class DisplaySettings(BaseSettings): """Display and preview settings.""" @@ -326,7 +317,6 @@ class Config(BaseSettings): # Nested settings groups database: DatabaseSettings = Field(default_factory=DatabaseSettings) - connection: ConnectionSettings = Field(default_factory=ConnectionSettings) display: DisplaySettings = Field(default_factory=DisplaySettings) jobs: JobsSettings = Field(default_factory=JobsSettings) @@ -341,7 +331,13 @@ class Config(BaseSettings): # Top-level settings loglevel: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(default="INFO", validation_alias="DJ_LOG_LEVEL") safemode: bool = True - enable_python_native_blobs: bool = True + thread_safe: bool = Field( + default=False, + validation_alias="DJ_THREAD_SAFE", + description="Thread-safe mode. When True, global config access is blocked. " + "Read-only after initialization: set via DJ_THREAD_SAFE env var or datajoint.json. " + "Use Connection.from_config() for thread-safe connections.", + ) filepath_checksum_size_limit: int | None = None # Cache paths @@ -640,7 +636,11 @@ def _update_from_flat_dict(self, data: dict[str, Any]) -> None: if env_var and os.environ.get(env_var): logger.debug(f"Skipping {key} from file (env var {env_var} takes precedence)") continue - setattr(self, key, value) + # thread_safe is read-only after init, but we allow setting from config file + if key == "thread_safe": + object.__setattr__(self, key, value) + else: + setattr(self, key, value) elif len(parts) == 2: group, attr = parts if hasattr(self, group): @@ -818,10 +818,6 @@ def save_template( "reconnect": True, "use_tls": None, }, - "connection": { - "init_function": None, - "charset": "", - }, "display": { "limit": 12, "width": 14, @@ -844,7 +840,6 @@ def save_template( }, "loglevel": "INFO", "safemode": True, - "enable_python_native_blobs": True, "cache": None, "query_cache": None, "download_path": ".", @@ -883,9 +878,85 @@ def save_template( return filepath.absolute() + def __getattribute__(self, name: str) -> Any: + """ + Override attribute access to block all access in thread-safe mode. + + Raises + ------ + ThreadSafetyError + If thread-safe mode is enabled (except for 'thread_safe' itself). + """ + # Always allow access to dunder methods and private attributes + if name.startswith("_"): + return object.__getattribute__(self, name) + + # Always allow Pydantic model methods (model_post_init, model_dump, etc.) + if name.startswith("model_"): + return object.__getattribute__(self, name) + + # Always allow checking thread_safe itself (to know which mode we're in) + if name == "thread_safe": + return object.__getattribute__(self, name) + + # Check thread-safe mode (use object.__getattribute__ to avoid recursion) + if object.__getattribute__(self, "thread_safe"): + raise ThreadSafetyError( + "Global config is inaccessible in thread-safe mode. " + "Use Connection.from_config() with explicit configuration." + ) + + return object.__getattribute__(self, name) + + def __setattr__(self, name: str, value: Any) -> None: + """ + Override attribute setting to enforce thread-safe mode rules. + + Raises + ------ + ThreadSafetyError + If thread-safe mode is enabled and trying to modify config, + or if trying to set thread_safe programmatically. + """ + # Always allow setting private attributes (pydantic internals) + if name.startswith("_"): + return object.__setattr__(self, name, value) + + # thread_safe is read-only after initialization + if name == "thread_safe": + # Allow setting during __init__ (when attribute doesn't exist yet) + try: + object.__getattribute__(self, "thread_safe") + # If we get here, thread_safe already exists - block the set + raise ThreadSafetyError( + "thread_safe cannot be set programmatically. " "Set DJ_THREAD_SAFE=true in environment or datajoint.json." + ) + except AttributeError: + pass # First time setting during __init__ - allow it + return object.__setattr__(self, name, value) + + # Block all other modifications in thread-safe mode + try: + if object.__getattribute__(self, "thread_safe"): + raise ThreadSafetyError( + "Global config is inaccessible in thread-safe mode. " + "Use Connection.from_config() with explicit configuration." + ) + except AttributeError: + pass # thread_safe not set yet (during __init__) + + # Use super().__setattr__ to preserve Pydantic validation + return super().__setattr__(name, value) + # Dict-like access for convenience def __getitem__(self, key: str) -> Any: """Get setting by dot-notation key (e.g., 'database.host').""" + # Allow checking thread_safe itself + if key != "thread_safe" and self.thread_safe: + raise ThreadSafetyError( + "Global config is inaccessible in thread-safe mode. " + "Use Connection.from_config() with explicit configuration." + ) parts = key.split(".") obj: Any = self for part in parts: @@ -902,6 +973,17 @@ def __getitem__(self, key: str) -> Any: def __setitem__(self, key: str, value: Any) -> None: """Set setting by dot-notation key (e.g., 'database.host').""" + # thread_safe is read-only - cannot be set programmatically + if key == "thread_safe": + raise ThreadSafetyError( + "thread_safe cannot be set programmatically. " "Set DJ_THREAD_SAFE=true in environment or datajoint.json." + ) + + if self.thread_safe: + raise ThreadSafetyError( + "Global config is inaccessible in thread-safe mode. " + "Use Connection.from_config() with explicit configuration." + ) parts = key.split(".") if len(parts) == 1: if hasattr(self, key): diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py index 20fa3233d..5a9203dca 100644 --- a/tests/integration/test_jobs.py +++ b/tests/integration/test_jobs.py @@ -108,10 +108,9 @@ def test_sigterm(clean_jobs, schema_any): def test_suppress_dj_errors(clean_jobs, schema_any): - """Test that DataJoint errors are suppressible without native py blobs.""" + """Test that DataJoint errors are suppressible.""" error_class = schema.ErrorClass() - with dj.config.override(enable_python_native_blobs=False): - error_class.populate(reserve_jobs=True, suppress_errors=True) + error_class.populate(reserve_jobs=True, suppress_errors=True) assert len(schema.DjExceptionName()) == len(error_class.jobs.errors) > 0 diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index af5718503..c1effc4dc 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -561,7 +561,6 @@ def test_save_full_template(self, tmp_path): # Full template should have all settings groups assert "database" in content - assert "connection" in content assert "display" in content assert "stores" in content assert "loglevel" in content @@ -868,3 +867,83 @@ def test_backend_field_in_env_var_mapping(self): assert "database.backend" in ENV_VAR_MAPPING assert ENV_VAR_MAPPING["database.backend"] == "DJ_BACKEND" + + +class TestThreadSafeMode: + """Tests for thread-safe configuration mode.""" + + @pytest.fixture(autouse=True) + def reset_thread_safe(self): + """Reset thread_safe before and after each test.""" + from datajoint import settings + + object.__setattr__(settings.config, "thread_safe", False) + yield + object.__setattr__(settings.config, "thread_safe", False) + + def test_thread_safe_default_false(self): + """Thread-safe mode is disabled by default.""" + from datajoint.settings import Config + + cfg = Config() + assert cfg.thread_safe is False + + def test_thread_safe_cannot_be_set_programmatically(self): + """Thread-safe mode cannot be set programmatically.""" + from datajoint import settings + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError, match="cannot be set programmatically"): + settings.config.thread_safe = True + + def test_thread_safe_cannot_be_set_via_setitem(self): + """Thread-safe mode cannot be set via dict access.""" + from datajoint import settings + from datajoint.errors import ThreadSafetyError + + with pytest.raises(ThreadSafetyError, match="cannot be set programmatically"): + settings.config["thread_safe"] = True + + def test_getitem_blocked_in_thread_safe_mode(self): + """Dict-like config access raises ThreadSafetyError in thread-safe mode.""" + from datajoint import settings + from datajoint.errors import ThreadSafetyError + + object.__setattr__(settings.config, "thread_safe", True) + with pytest.raises(ThreadSafetyError, match="Global config is inaccessible"): + _ = settings.config["database.host"] + + def test_setitem_blocked_in_thread_safe_mode(self): + """Dict-like config modification raises ThreadSafetyError in thread-safe mode.""" + from datajoint import settings + from datajoint.errors import ThreadSafetyError + + object.__setattr__(settings.config, "thread_safe", True) + with pytest.raises(ThreadSafetyError, match="Global config is inaccessible"): + settings.config["database.host"] = "newhost" + + def test_attribute_access_blocked_in_thread_safe_mode(self): + """Attribute access is also blocked in thread-safe mode.""" + from datajoint import settings + from datajoint.errors import ThreadSafetyError + + object.__setattr__(settings.config, "thread_safe", True) + with pytest.raises(ThreadSafetyError, match="Global config is inaccessible"): + _ = settings.config.database + + def test_thread_safe_always_readable(self): + """The thread_safe setting itself is always readable.""" + from datajoint import settings + + object.__setattr__(settings.config, "thread_safe", True) + # Should not raise + assert settings.config.thread_safe is True + assert settings.config["thread_safe"] is True + + def test_private_attributes_accessible(self): + """Private attributes (starting with _) are accessible in thread-safe mode.""" + from datajoint import settings + + object.__setattr__(settings.config, "thread_safe", True) + # Private attributes should be accessible for internal operations + _ = settings.config._config_path # Should not raise diff --git a/tests/unit/test_thread_safe.py b/tests/unit/test_thread_safe.py new file mode 100644 index 000000000..a57b2e8d2 --- /dev/null +++ b/tests/unit/test_thread_safe.py @@ -0,0 +1,511 @@ +"""Tests for thread-safe mode in connection management.""" + +import pytest + +import datajoint as dj +from datajoint.connection import ConnectionSettings +from datajoint.errors import ThreadSafetyError + + +@pytest.fixture(autouse=True) +def reset_thread_safe_mode(): + """Reset thread_safe to False before and after each test.""" + # Use object.__setattr__ to bypass read-only restriction for test reset + object.__setattr__(dj.config, "thread_safe", False) + yield + object.__setattr__(dj.config, "thread_safe", False) + + +def enable_thread_safe(): + """Helper to enable thread-safe mode in tests (bypasses read-only).""" + object.__setattr__(dj.config, "thread_safe", True) + + +class TestThreadSafeModeSetting: + """Tests for thread_safe as a read-only setting.""" + + def test_thread_safe_default_false(self): + """Thread-safe mode is disabled by default.""" + assert dj.config.thread_safe is False + + def test_thread_safe_cannot_be_set_programmatically(self): + """Thread-safe mode cannot be set via attribute assignment.""" + with pytest.raises(ThreadSafetyError, match="cannot be set programmatically"): + dj.config.thread_safe = True + + def test_thread_safe_cannot_be_set_via_dict_access(self): + """Thread-safe mode cannot be set via dict-style access.""" + with pytest.raises(ThreadSafetyError, match="cannot be set programmatically"): + dj.config["thread_safe"] = True + + def test_thread_safe_from_env_var(self, monkeypatch): + """Thread-safe mode can be set via environment variable.""" + from datajoint.settings import Config + + monkeypatch.setenv("DJ_THREAD_SAFE", "true") + cfg = Config() + assert cfg.thread_safe is True + + def test_thread_safe_from_config_file(self, tmp_path): + """Thread-safe mode can be set via config file.""" + import json + + from datajoint.settings import Config + + config_file = tmp_path / "datajoint.json" + config_file.write_text(json.dumps({"thread_safe": True})) + cfg = Config() + cfg.load(config_file) + assert cfg.thread_safe is True + + +class TestConfigBlockedInThreadSafeMode: + """Tests for config access being blocked in thread-safe mode.""" + + def test_attribute_access_blocked(self): + """Attribute access raises ThreadSafetyError in thread-safe mode.""" + enable_thread_safe() + with pytest.raises(ThreadSafetyError, match="Global config is inaccessible"): + _ = dj.config.database + + def test_dict_access_blocked(self): + """Dict-style access raises ThreadSafetyError in thread-safe mode.""" + enable_thread_safe() + with pytest.raises(ThreadSafetyError, match="Global config is inaccessible"): + _ = dj.config["database.host"] + + def test_dict_set_blocked(self): + """Dict-style setting raises ThreadSafetyError in thread-safe mode.""" + enable_thread_safe() + with pytest.raises(ThreadSafetyError, match="Global config is inaccessible"): + dj.config["database.host"] = "newhost" + + def test_attribute_set_blocked(self): + """Attribute setting raises ThreadSafetyError in thread-safe mode.""" + enable_thread_safe() + with pytest.raises(ThreadSafetyError, match="Global config is inaccessible"): + dj.config.safemode = False + + def test_thread_safe_always_readable(self): + """The thread_safe setting itself is always readable.""" + enable_thread_safe() + # Should not raise + assert dj.config.thread_safe is True + assert dj.config["thread_safe"] is True + + +class TestConnBlockedInThreadSafeMode: + """Tests for dj.conn() being blocked in thread-safe mode.""" + + def test_conn_blocked(self): + """dj.conn() raises ThreadSafetyError in thread-safe mode.""" + enable_thread_safe() + with pytest.raises(ThreadSafetyError, match="dj.conn\\(\\) is disabled"): + dj.conn() + + +class TestConnectionFromConfig: + """Tests for Connection.from_config() method.""" + + def test_from_config_exists(self): + """Connection.from_config class method exists.""" + assert hasattr(dj.Connection, "from_config") + assert callable(dj.Connection.from_config) + + def test_from_config_requires_user(self): + """from_config raises error if user not provided.""" + with pytest.raises(dj.DataJointError, match="user is required"): + dj.Connection.from_config({"host": "localhost", "password": "test"}) + + def test_from_config_requires_password(self): + """from_config raises error if password not provided.""" + with pytest.raises(dj.DataJointError, match="password is required"): + dj.Connection.from_config({"host": "localhost", "user": "test"}) + + def test_from_config_with_explicit_params(self): + """from_config accepts explicit keyword parameters.""" + from unittest.mock import patch + + captured_args = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_args["host"] = host + captured_args["user"] = user + captured_args["port"] = port + + with patch.object(dj.Connection, "__init__", mock_init): + dj.Connection.from_config( + host="testhost", + user="testuser", + password="testpass", + port=3307, + ) + + assert captured_args["host"] == "testhost" + assert captured_args["user"] == "testuser" + assert captured_args["port"] == 3307 + + def test_from_config_with_dict(self): + """from_config accepts configuration dict.""" + from unittest.mock import patch + + cfg = { + "host": "dicthost", + "user": "dictuser", + "password": "dictpass", + "port": 3308, + } + + captured_args = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_args["host"] = host + captured_args["port"] = port + + with patch.object(dj.Connection, "__init__", mock_init): + dj.Connection.from_config(cfg) + + assert captured_args["host"] == "dicthost" + assert captured_args["port"] == 3308 + + def test_from_config_kwargs_override_dict(self): + """Keyword arguments override dict values.""" + from unittest.mock import patch + + cfg = {"host": "dicthost", "user": "dictuser", "password": "dictpass"} + captured_args = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_args["host"] = host + captured_args["user"] = user + + with patch.object(dj.Connection, "__init__", mock_init): + dj.Connection.from_config(cfg, host="overridehost") + + assert captured_args["host"] == "overridehost" + assert captured_args["user"] == "dictuser" + + def test_from_config_works_in_thread_safe_mode(self): + """from_config works in thread-safe mode (no global config access).""" + from unittest.mock import patch + + enable_thread_safe() + + captured_args = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_args["host"] = host + + with patch.object(dj.Connection, "__init__", mock_init): + # Should NOT raise ThreadSafetyError + dj.Connection.from_config( + host="testhost", + user="testuser", + password="testpass", + ) + + assert captured_args["host"] == "testhost" + + def test_from_config_default_port_mysql(self): + """from_config uses default port 3306 for MySQL.""" + from unittest.mock import patch + + captured_args = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_args["port"] = port + captured_args["backend"] = backend + + with patch.object(dj.Connection, "__init__", mock_init): + dj.Connection.from_config(host="h", user="u", password="p") + + assert captured_args["port"] == 3306 + assert captured_args["backend"] == "mysql" + + def test_from_config_default_port_postgresql(self): + """from_config uses default port 5432 for PostgreSQL.""" + from unittest.mock import patch + + captured_args = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_args["port"] = port + + with patch.object(dj.Connection, "__init__", mock_init): + dj.Connection.from_config(host="h", user="u", password="p", backend="postgresql") + + assert captured_args["port"] == 5432 + + +class TestThreadSafetyErrorExport: + """Tests for ThreadSafetyError availability.""" + + def test_error_exported(self): + """ThreadSafetyError is exported from datajoint module.""" + assert hasattr(dj, "ThreadSafetyError") + assert dj.ThreadSafetyError is ThreadSafetyError + + def test_error_is_subclass(self): + """ThreadSafetyError is a subclass of DataJointError.""" + assert issubclass(ThreadSafetyError, dj.DataJointError) + + +class TestConnectionSettings: + """Tests for ConnectionSettings class.""" + + def test_defaults(self): + """ConnectionSettings has correct defaults.""" + cfg = ConnectionSettings() + assert cfg.safemode is True + assert cfg.database_prefix == "" + assert cfg.stores == {} + assert cfg.cache is None + assert cfg.reconnect is True + assert cfg.display_limit == 12 + assert cfg.display_width == 14 + + def test_explicit_values(self): + """Explicit values override defaults.""" + cfg = ConnectionSettings(values={"safemode": False, "display_limit": 25, "stores": {"raw": {}}}) + assert cfg.safemode is False + assert cfg.display_limit == 25 + assert cfg.stores == {"raw": {}} + + def test_read_write(self): + """ConnectionSettings supports read/write access.""" + cfg = ConnectionSettings() + cfg.safemode = False + cfg.display_limit = 50 + assert cfg.safemode is False + assert cfg.display_limit == 50 + + def test_forwarding_to_global_with_legacy_api(self): + """Unset values forward to global config with legacy API (dj.conn()).""" + # Set a value in global config + original_safemode = dj.config.safemode + object.__setattr__(dj.config, "safemode", False) + + try: + # Legacy API uses use_global_fallback=True + cfg = ConnectionSettings(use_global_fallback=True) + # Should forward to global config + assert cfg.safemode is False + finally: + object.__setattr__(dj.config, "safemode", original_safemode) + + def test_uses_defaults_with_new_api(self): + """Unset values use defaults with new API (from_config()).""" + # New API uses use_global_fallback=False + cfg = ConnectionSettings(use_global_fallback=False) + # Should use default, not global config + assert cfg.safemode is True # default + assert cfg.display_limit == 12 # default + + def test_new_api_works_identically_regardless_of_thread_safe(self): + """New API (from_config) uses defaults, not global config, in both modes.""" + # Set different values in global config + original_safemode = dj.config.safemode + object.__setattr__(dj.config, "safemode", False) # Different from default (True) + + try: + # New API with thread_safe=False + cfg1 = ConnectionSettings(use_global_fallback=False) + + # Enable thread_safe mode + enable_thread_safe() + + # New API with thread_safe=True + cfg2 = ConnectionSettings(use_global_fallback=False) + + # Both should use defaults, not global config + assert cfg1.safemode is True # default, not global (False) + assert cfg2.safemode is True # default, not global (False) + assert cfg1.safemode == cfg2.safemode + finally: + object.__setattr__(dj.config, "safemode", original_safemode) + + def test_explicit_overrides_global_with_legacy_api(self): + """Explicit values override global config even with legacy API.""" + original_safemode = dj.config.safemode + object.__setattr__(dj.config, "safemode", True) + + try: + cfg = ConnectionSettings(values={"safemode": False}, use_global_fallback=True) + assert cfg.safemode is False # explicit value + finally: + object.__setattr__(dj.config, "safemode", original_safemode) + + def test_get_store_spec(self): + """get_store_spec returns store configuration.""" + cfg = ConnectionSettings(values={"stores": {"raw": {"protocol": "file", "location": "/data"}}}) + spec = cfg.get_store_spec("raw") + assert spec["protocol"] == "file" + assert spec["location"] == "/data" + + def test_get_store_spec_not_found(self): + """get_store_spec raises error for unknown store.""" + cfg = ConnectionSettings(values={"stores": {}}) + with pytest.raises(dj.DataJointError, match="not configured"): + cfg.get_store_spec("unknown") + + def test_repr(self): + """ConnectionSettings has informative repr.""" + cfg = ConnectionSettings(values={"safemode": False}) + r = repr(cfg) + assert "ConnectionSettings" in r + assert "safemode=False" in r + + def test_override_context_manager(self): + """override temporarily changes values and restores them.""" + cfg = ConnectionSettings(values={"safemode": True, "display_limit": 10}) + + with cfg.override(safemode=False, display_limit=50): + assert cfg.safemode is False + assert cfg.display_limit == 50 + + assert cfg.safemode is True + assert cfg.display_limit == 10 + + def test_override_restores_on_exception(self): + """override restores values even when exception is raised.""" + cfg = ConnectionSettings(values={"safemode": True}) + + try: + with cfg.override(safemode=False): + assert cfg.safemode is False + raise RuntimeError("test error") + except RuntimeError: + pass + + assert cfg.safemode is True + + def test_override_with_defaults(self): + """override works when value was not explicitly set.""" + cfg = ConnectionSettings() # Uses defaults + assert cfg.safemode is True # default + + with cfg.override(safemode=False): + assert cfg.safemode is False + + # Should restore to default (not be in _values) + assert cfg.safemode is True + assert "safemode" not in cfg._values + + +class TestConnectionSettingsAttribute: + """Tests for Connection.config attribute.""" + + def test_from_config_creates_connection_config(self): + """from_config creates ConnectionSettings on connection.""" + from unittest.mock import patch + + captured_config = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_config["config"] = _config + + with patch.object(dj.Connection, "__init__", mock_init): + dj.Connection.from_config(host="h", user="u", password="p", safemode=False) + + assert captured_config["config"] is not None + assert isinstance(captured_config["config"], ConnectionSettings) + assert captured_config["config"].safemode is False + + def test_from_config_passes_all_settings(self): + """from_config passes all connection-scoped settings.""" + from unittest.mock import patch + + captured_config = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_config["config"] = _config + + with patch.object(dj.Connection, "__init__", mock_init): + dj.Connection.from_config( + host="h", + user="u", + password="p", + safemode=False, + database_prefix="test_", + display_limit=100, + stores={"main": {"protocol": "file"}}, + ) + + cfg = captured_config["config"] + assert cfg.safemode is False + assert cfg.database_prefix == "test_" + assert cfg.display_limit == 100 + assert cfg.stores == {"main": {"protocol": "file"}} + + def test_from_config_extracts_settings_from_dict(self): + """from_config extracts connection-scoped settings from cfg dict.""" + from unittest.mock import patch + + captured_config = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_config["config"] = _config + + cfg_dict = { + "host": "h", + "user": "u", + "password": "p", + "safemode": False, + "display_limit": 50, + } + + with patch.object(dj.Connection, "__init__", mock_init): + dj.Connection.from_config(cfg_dict) + + cfg = captured_config["config"] + assert cfg.safemode is False + assert cfg.display_limit == 50 + + def test_from_config_does_not_use_global_fallback(self): + """from_config creates config that doesn't fall back to global config.""" + from unittest.mock import patch + + # Set a non-default value in global config + original_safemode = dj.config.safemode + object.__setattr__(dj.config, "safemode", False) # Different from default (True) + + captured_config = {} + + def mock_init(self, host, user, password, port=None, use_tls=None, backend=None, *, _config=None): + captured_config["config"] = _config + + try: + with patch.object(dj.Connection, "__init__", mock_init): + # Don't pass safemode - should use default, not global + dj.Connection.from_config(host="h", user="u", password="p") + + cfg = captured_config["config"] + # Should use default (True), not global (False) + assert cfg.safemode is True + finally: + object.__setattr__(dj.config, "safemode", original_safemode) + + +class TestSchemaThreadSafe: + """Tests for Schema behavior in thread-safe mode.""" + + def test_schema_without_connection_raises_in_thread_safe_mode(self): + """Schema without explicit connection raises ThreadSafetyError.""" + enable_thread_safe() + with pytest.raises(ThreadSafetyError, match="Schema requires explicit connection"): + dj.Schema("test_schema") + + def test_schema_with_connection_works_in_thread_safe_mode(self): + """Schema with explicit connection works in thread-safe mode.""" + from unittest.mock import MagicMock, patch + + enable_thread_safe() + + # Create a mock connection with new API config (no global fallback) + mock_conn = MagicMock(spec=dj.Connection) + mock_conn.config = ConnectionSettings(use_global_fallback=False) + + # Mock the schema activation to avoid database operations + with patch.object(dj.Schema, "activate"): + schema = dj.Schema("test_schema", connection=mock_conn) + assert schema.connection is mock_conn