Skip to content

Commit 8fff7c0

Browse files
feat: add async generator migration with symmetric bridging and statefulness (#378)
* feat: add async generator migration with symmetric bridging and statefulness
  - Symmetric generate/agenerate bridging in base ColumnGenerator
  - is_stateful property; SeedDatasetColumnGenerator declares True
  - Async wrappers for FromScratchColumnGenerator and ColumnGeneratorFullColumn
  - Native async paths for ImageCellGenerator and EmbeddingCellGenerator
  - CustomColumnGenerator.agenerate with full validation parity
  - Extract _postprocess_result for shared sync/async output validation
* fix: avoid blocking caller on sync bridge timeout
  Use explicit pool lifecycle instead of context manager so that a TimeoutError releases the caller immediately via shutdown(wait=False) rather than blocking on pool.__exit__.
* fix: widen agenerate type signature to match generate
  Add @overload declarations so the base agenerate accepts both dict and pd.DataFrame, mirroring the existing generate pattern.
* fix: ensure pool shutdown on sync bridge success path
  The else clause after return was unreachable, leaking the ThreadPoolExecutor on every successful call. Capture the result first, shut down the pool, then return.
* fix: use try/finally for pool shutdown in sync bridge
  Ensures ThreadPoolExecutor is shut down on all exit paths, including non-TimeoutError exceptions from the coroutine.
* refactor: extract shared validation in ImageCellGenerator
  Move duplicated input validation and prompt rendering into _prepare_image_inputs, shared by generate and agenerate.
* refactor: extract shared input prep in EmbeddingCellGenerator
* address PR review feedback
  - add _is_overridden helper for symmetric generate/agenerate guards
  - move defensive .copy() into base agenerate, remove subclass overrides
  - re-raise as builtin TimeoutError for Python 3.10 compat
  - rename is_stateful to is_order_dependent with improved docstring
  - replace brittle .fget test with object.__new__
  - add async tests for ImageCellGenerator and EmbeddingCellGenerator
1 parent 340087f commit 8fff7c0

7 files changed

Lines changed: 735 additions & 121 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@
44
from __future__ import annotations
55

66
import asyncio
7+
import concurrent.futures
78
import functools
89
import logging
910
from abc import ABC, abstractmethod
10-
from typing import TYPE_CHECKING, Any, overload
11+
from typing import TYPE_CHECKING, Any, Coroutine, TypeVar, overload
1112

1213
from data_designer.config.column_configs import GenerationStrategy
1314
from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
1415
from data_designer.logging import LOG_DOUBLE_INDENT, LOG_INDENT
1516

17+
_T = TypeVar("_T")
18+
19+
_SYNC_BRIDGE_TIMEOUT = 300
20+
1621
if TYPE_CHECKING:
1722
import pandas as pd
1823

@@ -23,33 +28,84 @@
2328
logger = logging.getLogger(__name__)
2429

2530

31+
def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T:
    """Run an async coroutine to completion from sync context and return its result.

    - No running event loop in this thread → ``asyncio.run(coro)``
    - Running event loop (e.g. notebook/service) → ``asyncio.run(coro)`` on a
      single-worker background thread, waiting at most ``_SYNC_BRIDGE_TIMEOUT``
      seconds for it to finish

    Raises:
        TimeoutError: The background thread did not finish within
            ``_SYNC_BRIDGE_TIMEOUT`` seconds. The caller is released without
            waiting for the thread, which keeps running.
        Exception: Anything raised by the coroutine itself propagates
            unchanged — including a ``TimeoutError`` of its own.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises RuntimeError when this thread has no loop.
        return asyncio.run(coro)

    # Explicit pool lifecycle (not a ``with`` block) so that a bridge timeout
    # can shut the pool down without blocking on the still-running worker.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = pool.submit(asyncio.run, coro)
    timed_out = False
    try:
        return future.result(timeout=_SYNC_BRIDGE_TIMEOUT)
    except concurrent.futures.TimeoutError as exc:
        if future.done():
            # On Python 3.11+ concurrent.futures.TimeoutError *is* the builtin
            # TimeoutError, so this handler also catches a TimeoutError raised
            # by the coroutine. A finished future means the wait did not
            # expire: surface the coroutine's own outcome instead of
            # mislabelling it as a bridge timeout.
            return future.result()
        timed_out = True
        logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running")
        # Re-raise as the builtin so callers on Python 3.10 can catch TimeoutError.
        raise TimeoutError(f"_run_coroutine_sync timed out after {_SYNC_BRIDGE_TIMEOUT}s") from exc
    finally:
        # Wait for the worker on every path except a genuine bridge timeout.
        pool.shutdown(wait=not timed_out, cancel_futures=timed_out)
53+
54+
2655
class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC):
2756
@property
def can_generate_from_scratch(self) -> bool:
    """Report whether this generator can generate from scratch; the base answer is ``False``."""
    return False
3059

60+
@property
def is_order_dependent(self) -> bool:
    """Tell whether output depends on earlier row-group calls.

    The base answer is ``False``. A generator such as
    SeedDatasetColumnGenerator, which tracks its position in the seed
    dataset, needs row group N to complete before N+1 starts and therefore
    reports ``True`` instead.
    """
    return False
68+
69+
def _is_overridden(self, method_name: str) -> bool:
    """Return ``True`` when ``type(self)`` resolves ``method_name`` to something other than ColumnGenerator's own attribute."""
    subclass_impl = getattr(type(self), method_name)
    base_impl = getattr(ColumnGenerator, method_name)
    return subclass_impl is not base_impl
72+
3173
@staticmethod
@abstractmethod
def get_generation_strategy() -> GenerationStrategy:
    """Return the ``GenerationStrategy`` of this generator (abstract; concrete subclasses implement it)."""
3476

3577
@overload
def generate(self, data: dict) -> dict: ...

@overload
def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...

def generate(self, data: DataT) -> DataT:
    """Synchronous entry point; most concrete generators override it.

    When a subclass only implements ``agenerate()``, this default runs that
    coroutine to completion through ``_run_coroutine_sync``.

    Raises:
        NotImplementedError: The subclass overrides neither ``generate()``
            nor ``agenerate()``.
    """
    if self._is_overridden("agenerate"):
        return _run_coroutine_sync(self.agenerate(data))
    raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()")
4593

46-
async def agenerate(self, data: dict) -> dict:
47-
"""Async fallback — delegates to sync generate via thread pool.
94+
@overload
async def agenerate(self, data: dict) -> dict: ...

@overload
async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: ...

async def agenerate(self, data: DataT) -> DataT:
    """Asynchronous entry point; by default runs sync ``generate()`` in a worker thread.

    ``generate()`` receives ``data.copy()``, not the caller's object.
    Subclasses with native async support (e.g.
    ColumnGeneratorWithModelChatCompletion) should override this with a
    direct async implementation.

    Raises:
        NotImplementedError: The subclass overrides neither ``generate()``
            nor ``agenerate()``.
    """
    if self._is_overridden("generate"):
        snapshot = data.copy()
        return await asyncio.to_thread(self.generate, snapshot)
    raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()")
53109

54110
def log_pre_generation(self) -> None:
55111
"""A shared method to log info before the generator's `generate` method is called.
@@ -68,6 +124,10 @@ def can_generate_from_scratch(self) -> bool:
68124
@abstractmethod
def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
    """Produce a DataFrame for ``num_records`` records (abstract; concrete subclasses implement it)."""
70126

127+
async def agenerate_from_scratch(self, num_records: int) -> pd.DataFrame:
    """Await sync ``generate_from_scratch(num_records)`` executed in a worker thread."""
    frame = await asyncio.to_thread(self.generate_from_scratch, num_records)
    return frame
130+
71131

72132
class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC):
73133
@property

packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import asyncio
89
import inspect
910
import logging
1011
from typing import TYPE_CHECKING, Any
@@ -65,12 +66,57 @@ def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict
6566

6667
return self._generate(data, is_dataframe)
6768

69+
async def agenerate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict]:
    """Async generate — branches on strategy and detects coroutine functions.

    - FULL_COLUMN strategy, or a synchronous user function: run the sync
      ``generate()`` in a worker thread on a copy of ``data``.
    - Cell-by-cell with an ``async def`` user function: check required
      columns, await the function, then validate the result with
      ``_postprocess_result``.

    Raises:
        CustomColumnGenerationError: A required column is missing or the
            user function raised.
    """
    if self.config.generation_strategy == GenerationStrategy.FULL_COLUMN:
        return await asyncio.to_thread(self.generate, data.copy())

    # The @custom_column_generator decorator wraps the user function in a sync
    # wrapper, so we must unwrap to detect async functions.
    fn_unwrapped = inspect.unwrap(self.config.generator_function)
    # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction, which
    # is deprecated since Python 3.14.
    if not inspect.iscoroutinefunction(fn_unwrapped):
        # Copy before handing off to a worker thread, matching the FULL_COLUMN
        # branch above and the base-class agenerate().
        return await asyncio.to_thread(self.generate, data.copy())

    missing = set(self.config.required_columns) - set(data.keys())
    if missing:
        raise CustomColumnGenerationError(
            f"Missing required columns for custom generator '{self.config.name}': {sorted(missing)}"
        )
    # Snapshot of the input keys, handed to _postprocess_result.
    keys_before = set(data.keys())

    try:
        result = await self._ainvoke_generator_function(data)
    except CustomColumnGenerationError:
        raise
    except Exception as e:
        logger.warning(
            f"⚠️ Custom generator function {self.config.generator_function.__name__!r} "
            f"failed for column '{self.config.name}'. This record will be skipped.\n{e}"
        )
        raise CustomColumnGenerationError(
            f"Custom generator function failed for column '{self.config.name}': {e}"
        ) from e

    return self._postprocess_result(result, is_dataframe=False, keys_before=keys_before)
100+
101+
async def _ainvoke_generator_function(self, data: dict) -> dict | pd.DataFrame:
102+
"""Invoke an async user generator function with appropriate arguments.
103+
104+
The @custom_column_generator decorator's sync wrapper returns a coroutine
105+
when the original function is async, so we await the wrapper's return value.
106+
"""
107+
params = self._get_validated_params()
108+
fn = self.config.generator_function
109+
if len(params) == 1:
110+
return await fn(data)
111+
elif len(params) == 2:
112+
return await fn(data, self.config.generator_params)
113+
else:
114+
models = self._build_models_dict()
115+
return await fn(data, self.config.generator_params, models)
116+
68117
def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame | list[dict]:
69118
"""Unified generation logic for both strategies."""
70-
# Get columns/keys using unified accessor
71119
get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys()))
72-
expected_type = lazy.pd.DataFrame if is_dataframe else dict
73-
type_name = "DataFrame" if is_dataframe else "dict"
74120

75121
# Check required columns
76122
missing = set(self.config.required_columns) - get_keys(data)
@@ -96,6 +142,15 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.
96142
f"Custom generator function failed for column '{self.config.name}': {e}"
97143
) from e
98144

145+
return self._postprocess_result(result, is_dataframe, keys_before)
146+
147+
def _postprocess_result(
148+
self,
149+
result: dict | pd.DataFrame | list[dict],
150+
is_dataframe: bool,
151+
keys_before: set[str],
152+
) -> dict | pd.DataFrame | list[dict]:
153+
"""Validate type and output columns of a generation result."""
99154
# Cell-by-cell with allow_resize: accept dict or list[dict]
100155
if not is_dataframe and self.config.allow_resize:
101156
if isinstance(result, dict):
@@ -113,6 +168,8 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.
113168
)
114169

115170
# Validate return type for non-resize paths
171+
expected_type = lazy.pd.DataFrame if is_dataframe else dict
172+
type_name = "DataFrame" if is_dataframe else "dict"
116173
if not isinstance(result, expected_type):
117174
raise CustomColumnGenerationError(
118175
f"Custom generator for column '{self.config.name}' must return a {type_name}, "

packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,19 @@ class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]):
2727
def get_generation_strategy() -> GenerationStrategy:
2828
return GenerationStrategy.CELL_BY_CELL
2929

30-
def generate(self, data: dict) -> dict:
30+
def _prepare_embedding_inputs(self, data: dict) -> list[str]:
    """Deserialize JSON values in ``data`` and parse the configured target column into the list of input texts."""
    record = deserialize_json_values(data)
    target_value = record[self.config.target_column]
    return parse_list_string(target_value)
33+
34+
def generate(self, data: dict) -> dict:
    """Embed the target column's texts and store the JSON-dumped result under this column's name in ``data``."""
    texts = self._prepare_embedding_inputs(data)
    vectors = self.model.generate_text_embeddings(input_texts=texts)
    payload = EmbeddingGenerationResult(embeddings=vectors).model_dump(mode="json")
    data[self.config.name] = payload
    return data
39+
40+
async def agenerate(self, data: dict) -> dict:
    """Async counterpart of ``generate`` that awaits ``model.agenerate_text_embeddings``."""
    texts = self._prepare_embedding_inputs(data)
    vectors = await self.model.agenerate_text_embeddings(input_texts=texts)
    payload = EmbeddingGenerationResult(embeddings=vectors).model_dump(mode="json")
    data[self.config.name] = payload
    return data

packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
import asyncio
67
from typing import TYPE_CHECKING
78

89
from data_designer.config.column_configs import ImageColumnConfig
@@ -31,46 +32,42 @@ def media_storage(self) -> MediaStorage:
3132
def get_generation_strategy() -> GenerationStrategy:
3233
return GenerationStrategy.CELL_BY_CELL
3334

34-
def generate(self, data: dict) -> dict:
35-
"""Generate image(s) and optionally save to disk.
36-
37-
Args:
38-
data: Record data
39-
40-
Returns:
41-
Record with image path(s) (create mode) or base64 data (preview mode) added
42-
"""
35+
def _prepare_image_inputs(self, data: dict) -> tuple[str, list[dict] | None]:
    """Check required columns, render the prompt, and build the multi-modal context.

    Returns:
        ``(prompt, multi_modal_context)`` for the image model call.

    Raises:
        ValueError: A required column is absent from ``data``, or the
            rendered prompt is empty or whitespace-only.
    """
    record = deserialize_json_values(data)

    absent = list(set(self.config.required_columns) - set(data.keys()))
    if absent:
        raise ValueError(
            f"There was an error preparing the Jinja2 expression template. "
            f"The following columns {absent} are missing!"
        )

    self.prepare_jinja2_template_renderer(self.config.prompt, list(record.keys()))
    rendered = self.render_template(record)
    if not (rendered and rendered.strip()):
        raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty")

    context = self._build_multi_modal_context(record)
    return rendered, context
6450

65-
# Generate images (returns list of base64 strings)
51+
def generate(self, data: dict) -> dict:
    """Generate image(s), pass each to media storage, and record the results.

    Every base64 image returned by the model is saved through
    ``media_storage.save_base64_image`` with the column name as subfolder;
    the list of save results is written to ``data`` under the column name.
    """
    prompt, context = self._prepare_image_inputs(data)
    encoded_images = self.model.generate_image(prompt=prompt, multi_modal_context=context)
    stored = []
    for encoded in encoded_images:
        stored.append(self.media_storage.save_base64_image(encoded, subfolder_name=self.config.name))
    data[self.config.name] = stored
    return data
7561

62+
async def agenerate(self, data: dict) -> dict:
    """Async counterpart of ``generate`` that awaits ``model.agenerate_image``.

    Saving the images through media storage is done in a worker thread via
    ``asyncio.to_thread``.
    """
    prompt, context = self._prepare_image_inputs(data)
    encoded_images = await self.model.agenerate_image(prompt=prompt, multi_modal_context=context)

    def _store_all() -> list:
        return [
            self.media_storage.save_base64_image(encoded, subfolder_name=self.config.name)
            for encoded in encoded_images
        ]

    stored = await asyncio.to_thread(_store_all)
    data[self.config.name] = stored
    return data

packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu
2929
def get_generation_strategy() -> GenerationStrategy:
3030
return GenerationStrategy.FULL_COLUMN
3131

32+
@property
def is_order_dependent(self) -> bool:
    """Declare this generator order-dependent: always ``True``."""
    return True
35+
3236
@property
def num_records_sampled(self) -> int:
    """Expose the value held in the private ``_num_records_sampled`` attribute."""
    sampled = self._num_records_sampled
    return sampled

0 commit comments

Comments
 (0)