Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

_T = TypeVar("_T")

_SYNC_BRIDGE_TIMEOUT = 300
SYNC_BRIDGE_TIMEOUT = 300

if TYPE_CHECKING:
import pandas as pd
Expand All @@ -42,11 +42,11 @@ def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T:
future = pool.submit(asyncio.run, coro)
timed_out = False
try:
result = future.result(timeout=_SYNC_BRIDGE_TIMEOUT)
result = future.result(timeout=SYNC_BRIDGE_TIMEOUT)
except concurrent.futures.TimeoutError as exc:
timed_out = True
logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running")
raise TimeoutError(f"_run_coroutine_sync timed out after {_SYNC_BRIDGE_TIMEOUT}s") from exc
logger.warning(f"⚠️ Sync bridge timed out after {SYNC_BRIDGE_TIMEOUT}s; background thread still running")
raise TimeoutError(f"_run_coroutine_sync timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc
finally:
pool.shutdown(wait=not timed_out, cancel_futures=timed_out)
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import inspect
import logging
from typing import TYPE_CHECKING, Any

import data_designer.lazy_heavy_imports as lazy
from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy
from data_designer.engine.column_generators.generators.base import ColumnGenerator
from data_designer.engine.column_generators.generators.base import SYNC_BRIDGE_TIMEOUT, ColumnGenerator
from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
from data_designer.logging import LOG_INDENT

Expand All @@ -22,6 +23,61 @@
logger = logging.getLogger(__name__)


class _AsyncBridgedModelFacade:
"""Proxy that bridges ``model.generate()`` to ``model.agenerate()`` in async engine mode.

When a sync custom column runs inside ``asyncio.to_thread`` under the async engine,
the sync HTTP client is unavailable. This proxy intercepts the resulting
``SyncClientUnavailableError`` and schedules ``agenerate()`` on the engine's persistent
event loop via ``run_coroutine_threadsafe``.

All other attributes are forwarded to the underlying facade unchanged.
"""

__slots__ = ("_facade",)

def __init__(self, facade: Any) -> None:
object.__setattr__(self, "_facade", facade)

def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]:
from data_designer.engine.models.clients.errors import SyncClientUnavailableError

facade = object.__getattribute__(self, "_facade")
try:
return facade.generate(*args, **kwargs)
except SyncClientUnavailableError:
pass # Fall through to async bridge

# We're in a worker thread (asyncio.to_thread) with no running loop.
# Guard against accidental use from the event loop itself (would deadlock).
try:
asyncio.get_running_loop()
except RuntimeError:
pass # No running loop - safe to bridge
else:
raise RuntimeError(
"model.generate() is not available in async engine mode from the event loop. "
"Use 'await model.agenerate()' in async custom columns."
)

from data_designer.engine.dataset_builders.utils.async_concurrency import ensure_async_engine_loop

loop = ensure_async_engine_loop()
future = asyncio.run_coroutine_threadsafe(facade.agenerate(*args, **kwargs), loop)
try:
return future.result(timeout=SYNC_BRIDGE_TIMEOUT)
except concurrent.futures.TimeoutError as exc:
future.cancel()
logger.warning("Async model bridge timed out after %ss; coroutine cancelled", SYNC_BRIDGE_TIMEOUT)
raise TimeoutError(f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc

def __getattr__(self, name: str) -> Any:
return getattr(object.__getattribute__(self, "_facade"), name)

def __repr__(self) -> str:
return f"_AsyncBridgedModelFacade({object.__getattribute__(self, '_facade')!r})"


class CustomColumnGenerator(ColumnGenerator[CustomColumnConfig]):
"""Column generator that uses a user-provided callable function.

Expand Down Expand Up @@ -273,7 +329,7 @@ def _invoke_generator_function(self, data: dict | pd.DataFrame) -> dict | pd.Dat
elif len(params) == 2:
return self.config.generator_function(data, self.config.generator_params)
else:
models = self._build_models_dict()
models = {k: _AsyncBridgedModelFacade(v) for k, v in self._build_models_dict().items()}
return self.config.generator_function(data, self.config.generator_params, models)

def _build_models_dict(self) -> dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
resolve_timeout,
wrap_transport_error,
)
from data_designer.engine.models.clients.errors import map_http_error_to_provider_error
from data_designer.engine.models.clients.errors import SyncClientUnavailableError, map_http_error_to_provider_error
from data_designer.engine.models.clients.retry import RetryConfig, RetryTransport, create_retry_transport

if TYPE_CHECKING:
Expand Down Expand Up @@ -96,7 +96,7 @@ def _build_headers(self, extra_headers: dict[str, str]) -> dict[str, str]:

def _get_sync_client(self) -> httpx.Client:
if self._mode != ClientConcurrencyMode.SYNC:
raise RuntimeError("Sync methods are not available on an async-mode HttpModelClient.")
raise SyncClientUnavailableError("Sync methods are not available on an async-mode HttpModelClient.")
with self._init_lock:
if self._closed:
raise RuntimeError("Model client is closed.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class ProviderErrorKind(str, Enum):
UNSUPPORTED_CAPABILITY = "unsupported_capability"


class SyncClientUnavailableError(RuntimeError):
    """Raised when sync methods are called on an async-mode HttpModelClient.

    Subclasses ``RuntimeError``, so handlers written as ``except RuntimeError`` still
    catch it, while callers that need to single out this condition can catch it by type.
    """


class ProviderError(Exception):
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import BaseModel

from data_designer.engine.errors import DataDesignerError
from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind
from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind, SyncClientUnavailableError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -184,6 +184,10 @@ def handle_llm_exceptions(
solution=f"Verify your API key for model provider and update it in your settings for model provider {model_provider_name!r}.",
)
match exception:
# Let SyncClientUnavailableError propagate so the async bridge proxy can catch it
case SyncClientUnavailableError():
raise

# Canonical ProviderError from the client adapter layer
case ProviderError():
_raise_from_provider_error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,109 @@ def df_func(df: pd.DataFrame) -> pd.DataFrame:
gen = _create_test_generator(name="result", generator_function=df_func)
with pytest.raises(CustomColumnGenerationError, match="first parameter must be 'row', got 'df'"):
gen.generate({"input": 1})


# Async model bridge tests


class TestAsyncBridgedModelFacade:
    """Tests for _AsyncBridgedModelFacade proxy used by custom columns with model access."""

    def test_proxy_transparent_in_sync_mode(self, stub_resource_provider, stub_model_facade) -> None:
        """Proxy passes through generate(), forwards attributes; _build_models_dict returns raw facades."""
        from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade

        @custom_column_generator(required_columns=["input"], model_aliases=["test-model"])
        def gen_with_model(row: dict, generator_params: SampleParams, models: dict) -> dict:
            row["result"] = "ok"
            return row

        generator = _create_test_generator(
            name="result",
            generator_function=gen_with_model,
            generator_params=SampleParams(),
            resource_provider=stub_resource_provider,
        )

        # _build_models_dict returns raw facades (wrapping happens at the call site)
        models = generator._build_models_dict()
        assert not isinstance(models["test-model"], _AsyncBridgedModelFacade)

        # Proxy itself passes through generate() and forwards attributes
        proxy = _AsyncBridgedModelFacade(stub_model_facade)
        result, _ = proxy.generate("test", parser=str)
        assert result == "Generated summary text"
        stub_model_facade.generate.assert_called_once_with("test", parser=str)
        assert proxy.model_alias == "test_model"

    def test_bridges_to_agenerate_on_sync_client_error(self) -> None:
        """When sync generate() fails with an async/sync error, falls back to agenerate()."""
        import asyncio
        import threading
        from unittest.mock import patch

        from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade
        from data_designer.engine.models.clients.errors import SyncClientUnavailableError

        facade = Mock()
        facade.generate.side_effect = SyncClientUnavailableError(
            "Sync methods are not available on an async-mode HttpModelClient."
        )

        async def fake_agenerate(*args: Any, **kwargs: Any) -> tuple:
            return ("async_result", list(args), kwargs)

        facade.agenerate = fake_agenerate
        proxy = _AsyncBridgedModelFacade(facade)

        # Stand-in for the engine loop: a fresh loop running forever on a daemon thread.
        engine_loop = asyncio.new_event_loop()
        engine_thread = threading.Thread(target=engine_loop.run_forever, daemon=True)
        engine_thread.start()

        try:
            with patch(
                "data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop",
                return_value=engine_loop,
            ):
                result = proxy.generate("hello", parser=str)
            assert result == ("async_result", ["hello"], {"parser": str})
        finally:
            engine_loop.call_soon_threadsafe(engine_loop.stop)
            engine_thread.join(timeout=5)
            # Release the loop's resources. close() raises on a running loop, so only
            # close once the runner thread has actually exited.
            if not engine_thread.is_alive():
                engine_loop.close()

    def test_non_client_mode_errors_propagate(self) -> None:
        """Only SyncClientUnavailableError triggers bridging; other errors propagate."""
        from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade

        # ValueError - different type entirely
        facade = Mock()
        facade.generate.side_effect = ValueError("invalid prompt format")
        proxy = _AsyncBridgedModelFacade(facade)
        with pytest.raises(ValueError, match="invalid prompt format"):
            proxy.generate(prompt="hello")

        # RuntimeError - same base type as SyncClientUnavailableError, but not caught
        facade = Mock()
        facade.generate.side_effect = RuntimeError("connection timed out for async request")
        proxy = _AsyncBridgedModelFacade(facade)
        with pytest.raises(RuntimeError, match="connection timed out"):
            proxy.generate(prompt="hello")

    def test_deadlock_guard_on_event_loop(self) -> None:
        """Raises a clear error instead of deadlocking when called from the event loop."""
        import asyncio

        from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade
        from data_designer.engine.models.clients.errors import SyncClientUnavailableError

        facade = Mock()
        facade.generate.side_effect = SyncClientUnavailableError(
            "Sync methods are not available on an async-mode HttpModelClient."
        )
        proxy = _AsyncBridgedModelFacade(facade)

        async def call_from_loop() -> None:
            proxy.generate(prompt="hello")

        with pytest.raises(RuntimeError, match="Use 'await model.agenerate\\(\\)'"):
            asyncio.run(call_from_loop())
Loading