@@ -785,6 +785,48 @@ def test_run_config_normalizes_error_rate_when_disabled(stub_artifact_path, stub
785785 assert data_designer .run_config .shutdown_error_rate == 1.0
786786
787787
788+ def test_create_forwards_on_batch_complete_callback (
789+ stub_artifact_path : Path ,
790+ stub_model_providers : list [ModelProvider ],
791+ stub_sampler_only_config_builder : DataDesignerConfigBuilder ,
792+ stub_managed_assets_path : Path ,
793+ ) -> None :
794+ data_designer = DataDesigner (
795+ artifact_path = stub_artifact_path ,
796+ model_providers = stub_model_providers ,
797+ secret_resolver = PlaintextResolver (),
798+ managed_assets_path = stub_managed_assets_path ,
799+ )
800+
801+ def on_batch_complete (path : Path ) -> None :
802+ del path
803+
804+ with (
805+ patch .object (data_designer , "_create_resource_provider" ) as mock_resource_provider_method ,
806+ patch .object (data_designer , "_create_dataset_builder" ) as mock_builder_method ,
807+ patch .object (data_designer , "_create_dataset_profiler" ) as mock_profiler_method ,
808+ ):
809+ mock_resource_provider = MagicMock ()
810+ mock_resource_provider .get_dataset_metadata .return_value = {}
811+ mock_resource_provider_method .return_value = mock_resource_provider
812+
813+ mock_builder = MagicMock ()
814+ mock_builder .build .return_value = None
815+ mock_builder .task_traces = []
816+ mock_builder .artifact_storage .load_dataset_with_dropped_columns .return_value = lazy .pd .DataFrame ({"col" : [1 ]})
817+ mock_builder_method .return_value = mock_builder
818+
819+ mock_profiler = MagicMock ()
820+ mock_profiler .profile_dataset .return_value = None
821+ mock_profiler_method .return_value = mock_profiler
822+
823+ data_designer .create (stub_sampler_only_config_builder , num_records = 1 , on_batch_complete = on_batch_complete )
824+
825+ _ , build_kwargs = mock_builder .build .call_args
826+ assert build_kwargs ["num_records" ] == 1
827+ assert build_kwargs ["on_batch_complete" ] is on_batch_complete
828+
829+
788830def test_run_config_rejects_invalid_buffer_size () -> None :
789831 with pytest .raises (ValidationError , match = "buffer_size" ):
790832 RunConfig (buffer_size = 0 )
0 commit comments