Skip to content

Commit f41ae14

Browse files
committed
feat: bridge model.generate() to agenerate() for custom columns in async engine
Custom column generators that call model.generate() fail under the async engine because the sync HTTP client is unavailable. Add an _AsyncBridgedModelFacade proxy in _build_models_dict() that intercepts the sync-client RuntimeError and schedules agenerate() on the engine's persistent event loop via run_coroutine_threadsafe. Includes a deadlock guard for async custom columns running on the event loop.
1 parent 64f31bc commit f41ae14

2 files changed

Lines changed: 160 additions & 3 deletions

File tree

  • packages/data-designer-engine
    • src/data_designer/engine/column_generators/generators
    • tests/engine/column_generators/generators

packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import data_designer.lazy_heavy_imports as lazy
1414
from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy
15-
from data_designer.engine.column_generators.generators.base import ColumnGenerator
15+
from data_designer.engine.column_generators.generators.base import _SYNC_BRIDGE_TIMEOUT, ColumnGenerator
1616
from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
1717
from data_designer.logging import LOG_INDENT
1818

@@ -22,6 +22,57 @@
2222
logger = logging.getLogger(__name__)
2323

2424

25+
class _AsyncBridgedModelFacade:
26+
"""Proxy that bridges ``model.generate()`` to ``model.agenerate()`` in async engine mode.
27+
28+
When a sync custom column runs inside ``asyncio.to_thread`` under the async engine,
29+
the sync HTTP client is unavailable. This proxy intercepts the resulting error and
30+
schedules ``agenerate()`` on the engine's persistent event loop via
31+
``run_coroutine_threadsafe``.
32+
33+
All other attributes are forwarded to the underlying facade unchanged.
34+
"""
35+
36+
_SYNC_CLIENT_ERROR = "Sync methods are not available on an async-mode HttpModelClient."
37+
38+
__slots__ = ("_facade",)
39+
40+
def __init__(self, facade: Any) -> None:
41+
object.__setattr__(self, "_facade", facade)
42+
43+
def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]:
44+
facade = object.__getattribute__(self, "_facade")
45+
try:
46+
return facade.generate(*args, **kwargs)
47+
except RuntimeError as exc:
48+
if str(exc) != self._SYNC_CLIENT_ERROR:
49+
raise
50+
51+
# We're in a worker thread (asyncio.to_thread) with no running loop.
52+
# Guard against accidental use from the event loop itself (would deadlock).
53+
try:
54+
asyncio.get_running_loop()
55+
except RuntimeError:
56+
pass # No running loop - safe to bridge
57+
else:
58+
raise RuntimeError(
59+
"model.generate() is not available in async engine mode from the event loop. "
60+
"Use 'await model.agenerate()' in async custom columns."
61+
)
62+
63+
from data_designer.engine.dataset_builders.utils.async_concurrency import ensure_async_engine_loop
64+
65+
loop = ensure_async_engine_loop()
66+
future = asyncio.run_coroutine_threadsafe(facade.agenerate(*args, **kwargs), loop)
67+
return future.result(timeout=_SYNC_BRIDGE_TIMEOUT)
68+
69+
def __getattr__(self, name: str) -> Any:
70+
return getattr(object.__getattribute__(self, "_facade"), name)
71+
72+
def __repr__(self) -> str:
73+
return f"_AsyncBridgedModelFacade({object.__getattribute__(self, '_facade')!r})"
74+
75+
2576
class CustomColumnGenerator(ColumnGenerator[CustomColumnConfig]):
2677
"""Column generator that uses a user-provided callable function.
2778
@@ -277,9 +328,13 @@ def _invoke_generator_function(self, data: dict | pd.DataFrame) -> dict | pd.Dat
277328
return self.config.generator_function(data, self.config.generator_params, models)
278329

279330
def _build_models_dict(self) -> dict[str, Any]:
280-
"""Build a dict of ModelFacade instances from model_aliases."""
331+
"""Build a dict of ModelFacade instances from model_aliases.
332+
333+
Facades are wrapped in ``_AsyncBridgedModelFacade`` so that sync custom
334+
columns can call ``model.generate()`` transparently under the async engine.
335+
"""
281336
return {
282-
alias: self.resource_provider.model_registry.get_model(model_alias=alias)
337+
alias: _AsyncBridgedModelFacade(self.resource_provider.model_registry.get_model(model_alias=alias))
283338
for alias in self.config.model_aliases
284339
}
285340

packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,3 +467,105 @@ def df_func(df: pd.DataFrame) -> pd.DataFrame:
467467
gen = _create_test_generator(name="result", generator_function=df_func)
468468
with pytest.raises(CustomColumnGenerationError, match="first parameter must be 'row', got 'df'"):
469469
gen.generate({"input": 1})
470+
471+
472+
# Async model bridge tests
473+
474+
475+
class TestAsyncBridgedModelFacade:
476+
"""Tests for _AsyncBridgedModelFacade proxy used by custom columns with model access."""
477+
478+
def test_proxy_transparent_in_sync_mode(self, stub_resource_provider, stub_model_facade) -> None:
479+
"""Proxy passes through generate(), forwards attributes, and is used by _build_models_dict."""
480+
from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade
481+
482+
@custom_column_generator(required_columns=["input"], model_aliases=["test-model"])
483+
def gen_with_model(row: dict, generator_params: SampleParams, models: dict) -> dict:
484+
row["result"] = "ok"
485+
return row
486+
487+
generator = _create_test_generator(
488+
name="result",
489+
generator_function=gen_with_model,
490+
generator_params=SampleParams(),
491+
resource_provider=stub_resource_provider,
492+
)
493+
494+
models = generator._build_models_dict()
495+
proxy = models["test-model"]
496+
assert isinstance(proxy, _AsyncBridgedModelFacade)
497+
498+
# generate() passes through to the underlying facade (positional and keyword args)
499+
result, _ = proxy.generate("test", parser=str)
500+
assert result == "Generated summary text"
501+
stub_model_facade.generate.assert_called_once_with("test", parser=str)
502+
503+
# Other attributes are forwarded
504+
assert proxy.model_alias == "test_model"
505+
506+
def test_bridges_to_agenerate_on_sync_client_error(self) -> None:
507+
"""When sync generate() fails with an async/sync error, falls back to agenerate()."""
508+
import asyncio
509+
import threading
510+
from unittest.mock import patch
511+
512+
from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade
513+
514+
facade = Mock()
515+
facade.generate.side_effect = RuntimeError("Sync methods are not available on an async-mode HttpModelClient.")
516+
517+
async def fake_agenerate(*args: Any, **kwargs: Any) -> tuple:
518+
return ("async_result", list(args))
519+
520+
facade.agenerate = fake_agenerate
521+
proxy = _AsyncBridgedModelFacade(facade)
522+
523+
engine_loop = asyncio.new_event_loop()
524+
engine_thread = threading.Thread(target=engine_loop.run_forever, daemon=True)
525+
engine_thread.start()
526+
527+
try:
528+
with patch(
529+
"data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop",
530+
return_value=engine_loop,
531+
):
532+
# Positional prompt arg is forwarded to agenerate
533+
result = proxy.generate("hello", parser=str)
534+
assert result == ("async_result", ["hello"])
535+
finally:
536+
engine_loop.call_soon_threadsafe(engine_loop.stop)
537+
engine_thread.join(timeout=5)
538+
539+
def test_non_client_mode_errors_propagate(self) -> None:
540+
"""Only the specific HttpModelClient sync-mode RuntimeError triggers bridging."""
541+
from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade
542+
543+
# ValueError - different type entirely
544+
facade = Mock()
545+
facade.generate.side_effect = ValueError("invalid prompt format")
546+
proxy = _AsyncBridgedModelFacade(facade)
547+
with pytest.raises(ValueError, match="invalid prompt format"):
548+
proxy.generate(prompt="hello")
549+
550+
# RuntimeError with a different message - same type, not caught
551+
facade = Mock()
552+
facade.generate.side_effect = RuntimeError("connection timed out for async request")
553+
proxy = _AsyncBridgedModelFacade(facade)
554+
with pytest.raises(RuntimeError, match="connection timed out"):
555+
proxy.generate(prompt="hello")
556+
557+
def test_deadlock_guard_on_event_loop(self) -> None:
558+
"""Raises a clear error instead of deadlocking when called from the event loop."""
559+
import asyncio
560+
561+
from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade
562+
563+
facade = Mock()
564+
facade.generate.side_effect = RuntimeError("Sync methods are not available on an async-mode HttpModelClient.")
565+
proxy = _AsyncBridgedModelFacade(facade)
566+
567+
async def call_from_loop() -> None:
568+
proxy.generate(prompt="hello")
569+
570+
with pytest.raises(RuntimeError, match="Use 'await model.agenerate\\(\\)'"):
571+
asyncio.run(call_from_loop())

0 commit comments

Comments
 (0)