Skip to content

Commit 498e627

Browse files
authored
feat: Expose on_batch_complete via create method (#663)
1 parent 6055290 commit 498e627

2 files changed

Lines changed: 52 additions & 1 deletion

File tree

packages/data-designer/src/data_designer/interface/data_designer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
import logging
88
import warnings
9+
from collections.abc import Callable
910
from pathlib import Path
1011
from typing import TYPE_CHECKING
1112

@@ -247,6 +248,7 @@ def create(
247248
dataset_name: str = "dataset",
248249
resume: ResumeMode = ResumeMode.NEVER,
249250
artifact_path: Path | str | None = None,
251+
on_batch_complete: Callable[[Path], None] | None = None,
250252
) -> DatasetCreationResults:
251253
"""Create dataset and save results to the local artifact storage.
252254
@@ -280,6 +282,13 @@ def create(
280282
discarded before generation continues.
281283
artifact_path: Optional artifact root for this create call. Defaults
282284
to the path configured on this DataDesigner instance.
285+
on_batch_complete: Optional callback called with the completed batch artifact path after
286+
each batch is written. Useful for incremental workflows such as uploading each batch
287+
to remote storage, updating an external run monitor, or triggering downstream processing
288+
before the full dataset has finished. The callback runs synchronously in the generation
289+
path, so it is recommended to keep it lightweight or delegate slow work to a queue, e.g.
290+
``on_batch_complete=lambda path: queue_upload(path)``. Callback exceptions abort the run
291+
and are wrapped as ``DataDesignerGenerationError``.
283292
284293
Returns:
285294
DatasetCreationResults object with methods for loading the generated dataset,
@@ -314,7 +323,7 @@ def create(
314323
raise DataDesignerGenerationError(f"🛑 Error generating dataset: {e}") from e
315324

316325
try:
317-
builder.build(num_records=num_records, resume=resume)
326+
builder.build(num_records=num_records, on_batch_complete=on_batch_complete, resume=resume)
318327
except DeprecationWarning:
319328
raise
320329
except Exception as e:

packages/data-designer/tests/interface/test_data_designer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,48 @@ def test_run_config_normalizes_error_rate_when_disabled(stub_artifact_path, stub
785785
assert data_designer.run_config.shutdown_error_rate == 1.0
786786

787787

788+
def test_create_forwards_on_batch_complete_callback(
789+
stub_artifact_path: Path,
790+
stub_model_providers: list[ModelProvider],
791+
stub_sampler_only_config_builder: DataDesignerConfigBuilder,
792+
stub_managed_assets_path: Path,
793+
) -> None:
794+
data_designer = DataDesigner(
795+
artifact_path=stub_artifact_path,
796+
model_providers=stub_model_providers,
797+
secret_resolver=PlaintextResolver(),
798+
managed_assets_path=stub_managed_assets_path,
799+
)
800+
801+
def on_batch_complete(path: Path) -> None:
802+
del path
803+
804+
with (
805+
patch.object(data_designer, "_create_resource_provider") as mock_resource_provider_method,
806+
patch.object(data_designer, "_create_dataset_builder") as mock_builder_method,
807+
patch.object(data_designer, "_create_dataset_profiler") as mock_profiler_method,
808+
):
809+
mock_resource_provider = MagicMock()
810+
mock_resource_provider.get_dataset_metadata.return_value = {}
811+
mock_resource_provider_method.return_value = mock_resource_provider
812+
813+
mock_builder = MagicMock()
814+
mock_builder.build.return_value = None
815+
mock_builder.task_traces = []
816+
mock_builder.artifact_storage.load_dataset_with_dropped_columns.return_value = lazy.pd.DataFrame({"col": [1]})
817+
mock_builder_method.return_value = mock_builder
818+
819+
mock_profiler = MagicMock()
820+
mock_profiler.profile_dataset.return_value = None
821+
mock_profiler_method.return_value = mock_profiler
822+
823+
data_designer.create(stub_sampler_only_config_builder, num_records=1, on_batch_complete=on_batch_complete)
824+
825+
_, build_kwargs = mock_builder.build.call_args
826+
assert build_kwargs["num_records"] == 1
827+
assert build_kwargs["on_batch_complete"] is on_batch_complete
828+
829+
788830
def test_run_config_rejects_invalid_buffer_size() -> None:
789831
with pytest.raises(ValidationError, match="buffer_size"):
790832
RunConfig(buffer_size=0)

0 commit comments

Comments
 (0)