diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 4ba16bcfb..500db3a3a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -16,7 +16,7 @@ _T = TypeVar("_T") -_SYNC_BRIDGE_TIMEOUT = 300 +SYNC_BRIDGE_TIMEOUT = 300 if TYPE_CHECKING: import pandas as pd @@ -42,11 +42,11 @@ def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T: future = pool.submit(asyncio.run, coro) timed_out = False try: - result = future.result(timeout=_SYNC_BRIDGE_TIMEOUT) + result = future.result(timeout=SYNC_BRIDGE_TIMEOUT) except concurrent.futures.TimeoutError as exc: timed_out = True - logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running") - raise TimeoutError(f"_run_coroutine_sync timed out after {_SYNC_BRIDGE_TIMEOUT}s") from exc + logger.warning(f"⚠️ Sync bridge timed out after {SYNC_BRIDGE_TIMEOUT}s; background thread still running") + raise TimeoutError(f"_run_coroutine_sync timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc finally: pool.shutdown(wait=not timed_out, cancel_futures=timed_out) return result diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index f420da3fe..14318791f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -6,13 +6,14 @@ from __future__ import annotations import asyncio +import concurrent.futures import inspect import logging from typing import TYPE_CHECKING, Any import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy -from data_designer.engine.column_generators.generators.base import ColumnGenerator +from data_designer.engine.column_generators.generators.base import SYNC_BRIDGE_TIMEOUT, ColumnGenerator from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError from data_designer.logging import LOG_INDENT @@ -22,6 +23,61 @@ logger = logging.getLogger(__name__) +class _AsyncBridgedModelFacade: + """Proxy that bridges ``model.generate()`` to ``model.agenerate()`` in async engine mode. + + When a sync custom column runs inside ``asyncio.to_thread`` under the async engine, + the sync HTTP client is unavailable. This proxy intercepts the resulting + ``SyncClientUnavailableError`` and schedules ``agenerate()`` on the engine's persistent + event loop via ``run_coroutine_threadsafe``. + + All other attributes are forwarded to the underlying facade unchanged. + """ + + __slots__ = ("_facade",) + + def __init__(self, facade: Any) -> None: + object.__setattr__(self, "_facade", facade) + + def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]: + from data_designer.engine.models.clients.errors import SyncClientUnavailableError + + facade = object.__getattribute__(self, "_facade") + try: + return facade.generate(*args, **kwargs) + except SyncClientUnavailableError: + pass # Fall through to async bridge + + # We're in a worker thread (asyncio.to_thread) with no running loop. + # Guard against accidental use from the event loop itself (would deadlock). + try: + asyncio.get_running_loop() + except RuntimeError: + pass # No running loop - safe to bridge + else: + raise RuntimeError( + "model.generate() is not available in async engine mode from the event loop. " + "Use 'await model.agenerate()' in async custom columns." + ) + + from data_designer.engine.dataset_builders.utils.async_concurrency import ensure_async_engine_loop + + loop = ensure_async_engine_loop() + future = asyncio.run_coroutine_threadsafe(facade.agenerate(*args, **kwargs), loop) + try: + return future.result(timeout=SYNC_BRIDGE_TIMEOUT) + except concurrent.futures.TimeoutError as exc: + future.cancel() + logger.warning("Async model bridge timed out after %ss; coroutine cancelled", SYNC_BRIDGE_TIMEOUT) + raise TimeoutError(f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc + + def __getattr__(self, name: str) -> Any: + return getattr(object.__getattribute__(self, "_facade"), name) + + def __repr__(self) -> str: + return f"_AsyncBridgedModelFacade({object.__getattribute__(self, '_facade')!r})" + + class CustomColumnGenerator(ColumnGenerator[CustomColumnConfig]): """Column generator that uses a user-provided callable function. @@ -273,7 +329,7 @@ def _invoke_generator_function(self, data: dict | pd.DataFrame) -> dict | pd.Dat elif len(params) == 2: return self.config.generator_function(data, self.config.generator_params) else: - models = self._build_models_dict() + models = {k: _AsyncBridgedModelFacade(v) for k, v in self._build_models_dict().items()} return self.config.generator_function(data, self.config.generator_params, models) def _build_models_dict(self) -> dict[str, Any]: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/http_model_client.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/http_model_client.py index a07305f58..7fe997f20 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/http_model_client.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/http_model_client.py @@ -14,7 +14,7 @@ resolve_timeout, wrap_transport_error, ) -from data_designer.engine.models.clients.errors import map_http_error_to_provider_error +from data_designer.engine.models.clients.errors import SyncClientUnavailableError, map_http_error_to_provider_error from data_designer.engine.models.clients.retry import RetryConfig, RetryTransport, create_retry_transport if TYPE_CHECKING: @@ -96,7 +96,7 @@ def _build_headers(self, extra_headers: dict[str, str]) -> dict[str, str]: def _get_sync_client(self) -> httpx.Client: if self._mode != ClientConcurrencyMode.SYNC: - raise RuntimeError("Sync methods are not available on an async-mode HttpModelClient.") + raise SyncClientUnavailableError("Sync methods are not available on an async-mode HttpModelClient.") with self._init_lock: if self._closed: raise RuntimeError("Model client is closed.") diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index d11f6444a..8355ce85a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -29,6 +29,10 @@ class ProviderErrorKind(str, Enum): UNSUPPORTED_CAPABILITY = "unsupported_capability" +class SyncClientUnavailableError(RuntimeError): + """Raised when sync methods are called on an async-mode HttpModelClient.""" + + class ProviderError(Exception): def __init__( self, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index 79d8a15da..95469f8bd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -11,7 +11,7 @@ from pydantic import BaseModel from data_designer.engine.errors import DataDesignerError -from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind, SyncClientUnavailableError logger = logging.getLogger(__name__) @@ -184,6 +184,10 @@ def handle_llm_exceptions( solution=f"Verify your API key for model provider and update it in your settings for model provider {model_provider_name!r}.", ) match exception: + # Let SyncClientUnavailableError propagate so the async bridge proxy can catch it + case SyncClientUnavailableError(): + raise + # Canonical ProviderError from the client adapter layer case ProviderError(): _raise_from_provider_error( diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py index 4ecfc6adc..2d73fc8dd 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py @@ -467,3 +467,109 @@ def df_func(df: pd.DataFrame) -> pd.DataFrame: gen = _create_test_generator(name="result", generator_function=df_func) with pytest.raises(CustomColumnGenerationError, match="first parameter must be 'row', got 'df'"): gen.generate({"input": 1}) + + +# Async model bridge tests + + +class TestAsyncBridgedModelFacade: + """Tests for _AsyncBridgedModelFacade proxy used by custom columns with model access.""" + + def test_proxy_transparent_in_sync_mode(self, stub_resource_provider, stub_model_facade) -> None: + """Proxy passes through generate(), forwards attributes; _build_models_dict returns raw facades.""" + from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade + + @custom_column_generator(required_columns=["input"], model_aliases=["test-model"]) + def gen_with_model(row: dict, generator_params: SampleParams, models: dict) -> dict: + row["result"] = "ok" + return row + + generator = _create_test_generator( + name="result", + generator_function=gen_with_model, + generator_params=SampleParams(), + resource_provider=stub_resource_provider, + ) + + # _build_models_dict returns raw facades (wrapping happens at the call site) + models = generator._build_models_dict() + assert not isinstance(models["test-model"], _AsyncBridgedModelFacade) + + # Proxy itself passes through generate() and forwards attributes + proxy = _AsyncBridgedModelFacade(stub_model_facade) + result, _ = proxy.generate("test", parser=str) + assert result == "Generated summary text" + stub_model_facade.generate.assert_called_once_with("test", parser=str) + assert proxy.model_alias == "test_model" + + def test_bridges_to_agenerate_on_sync_client_error(self) -> None: + """When sync generate() fails with an async/sync error, falls back to agenerate().""" + import asyncio + import threading + from unittest.mock import patch + + from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade + from data_designer.engine.models.clients.errors import SyncClientUnavailableError + + facade = Mock() + facade.generate.side_effect = SyncClientUnavailableError( + "Sync methods are not available on an async-mode HttpModelClient." + ) + + async def fake_agenerate(*args: Any, **kwargs: Any) -> tuple: + return ("async_result", list(args), kwargs) + + facade.agenerate = fake_agenerate + proxy = _AsyncBridgedModelFacade(facade) + + engine_loop = asyncio.new_event_loop() + engine_thread = threading.Thread(target=engine_loop.run_forever, daemon=True) + engine_thread.start() + + try: + with patch( + "data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop", + return_value=engine_loop, + ): + result = proxy.generate("hello", parser=str) + assert result == ("async_result", ["hello"], {"parser": str}) + finally: + engine_loop.call_soon_threadsafe(engine_loop.stop) + engine_thread.join(timeout=5) + + def test_non_client_mode_errors_propagate(self) -> None: + """Only SyncClientUnavailableError triggers bridging; other errors propagate.""" + from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade + + # ValueError - different type entirely + facade = Mock() + facade.generate.side_effect = ValueError("invalid prompt format") + proxy = _AsyncBridgedModelFacade(facade) + with pytest.raises(ValueError, match="invalid prompt format"): + proxy.generate(prompt="hello") + + # RuntimeError - same base type as SyncClientUnavailableError, but not caught + facade = Mock() + facade.generate.side_effect = RuntimeError("connection timed out for async request") + proxy = _AsyncBridgedModelFacade(facade) + with pytest.raises(RuntimeError, match="connection timed out"): + proxy.generate(prompt="hello") + + def test_deadlock_guard_on_event_loop(self) -> None: + """Raises a clear error instead of deadlocking when called from the event loop.""" + import asyncio + + from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade + from data_designer.engine.models.clients.errors import SyncClientUnavailableError + + facade = Mock() + facade.generate.side_effect = SyncClientUnavailableError( + "Sync methods are not available on an async-mode HttpModelClient." + ) + proxy = _AsyncBridgedModelFacade(facade) + + async def call_from_loop() -> None: + proxy.generate(prompt="hello") + + with pytest.raises(RuntimeError, match="Use 'await model.agenerate\\(\\)'"): + asyncio.run(call_from_loop())