-
Notifications
You must be signed in to change notification settings - Fork 173
Expand file tree
/
Copy pathtest_async_builder_integration.py
More file actions
373 lines (292 loc) · 13.8 KB
/
test_async_builder_integration.py
File metadata and controls
373 lines (292 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import math
import warnings
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock
import pytest
import data_designer.engine.dataset_builders.dataset_builder as builder_mod
import data_designer.lazy_heavy_imports as lazy
from data_designer.config.column_configs import (
ExpressionColumnConfig,
GenerationStrategy,
LLMTextColumnConfig,
SamplerColumnConfig,
)
from data_designer.config.sampler_params import SamplerType
from data_designer.engine.column_generators.generators.base import (
ColumnGenerator,
ColumnGeneratorFullColumn,
FromScratchColumnGenerator,
)
from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler
from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder
from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta
from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph
from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager
from data_designer.engine.resources.resource_provider import ResourceProvider
# Model alias handed to the LLM text column configs built in these tests.
MODEL_ALIAS = "stub"
# -- Mock generators for integration tests ------------------------------------
def _expr_config(name: str = "test") -> ExpressionColumnConfig:
    """Return a string-typed expression column config named ``name``.

    The expression is always ``{{ x }}``; only the column name varies.
    """
    expression_config = ExpressionColumnConfig(
        name=name,
        expr="{{ x }}",
        dtype="str",
    )
    return expression_config
def _mock_provider() -> MagicMock:
    """Return a ``MagicMock`` whose attribute surface is restricted to ``ResourceProvider``."""
    provider_stub = MagicMock(spec=ResourceProvider)
    return provider_stub
class MockSeed(FromScratchColumnGenerator[ExpressionColumnConfig]):
    """From-scratch generator stub that emits a ``seed`` column of ``0..n-1``."""

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        """Report that this generator works on a full column at a time."""
        return GenerationStrategy.FULL_COLUMN

    def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame:
        """Pass the incoming frame through untouched."""
        return data

    def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame:
        """Build a one-column frame whose ``seed`` values are ``0`` to ``num_records - 1``."""
        seed_values = list(range(num_records))
        return lazy.pd.DataFrame({"seed": seed_values})
class MockCell(ColumnGenerator[ExpressionColumnConfig]):
    """Cell-by-cell generator stub that derives ``cell_out`` from the row's ``seed``."""

    @staticmethod
    def get_generation_strategy() -> GenerationStrategy:
        """Report that this generator works on one cell (row dict) at a time."""
        return GenerationStrategy.CELL_BY_CELL

    def generate(self, data: dict) -> dict:
        """Write ``cell_out`` into ``data`` in place and return the same dict.

        A row without a ``seed`` key gets ``"?"`` as the placeholder value.
        """
        seed_value = data.get("seed", "?")
        data["cell_out"] = f"val_{seed_value}"
        return data
class MockFullCol(ColumnGeneratorFullColumn[ExpressionColumnConfig]):
    """Full-column generator stub that fills ``expr_out`` with a constant."""

    def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame:
        """Set every row's ``expr_out`` to ``"computed"`` in place and return the frame."""
        constant_value = "computed"
        data["expr_out"] = constant_value
        return data
# -- allow_resize validation test ---------------------------------------------
def _resize_config(name: str, *, allow_resize: bool) -> Mock:
    """Build a stand-in column config exposing ``name`` and ``allow_resize``.

    ``Mock(name=...)`` only labels the mock's repr; it does not create a
    ``name`` attribute. The attribute is therefore assigned after
    construction so the fake config carries a real string name.
    """
    config = Mock(allow_resize=allow_resize)
    config.name = name
    return config


@pytest.mark.parametrize(
    "configs,expected",
    [
        pytest.param(
            [_resize_config("col_a", allow_resize=True), _resize_config("col_b", allow_resize=False)],
            False,
            id="fallback_on_allow_resize",
        ),
        pytest.param(
            [_resize_config("col_a", allow_resize=False), _resize_config("col_b", allow_resize=False)],
            True,
            id="async_without_allow_resize",
        ),
    ],
)
def test_resolve_async_compatibility(configs: list[Mock], expected: bool) -> None:
    """allow_resize=True triggers auto-fallback to sync with a deprecation warning.

    Args:
        configs: Fake single-column configs installed on a mocked builder.
        expected: The value ``_resolve_async_compatibility`` must return;
            ``False`` means the fallback (and its warning) is expected.
    """
    builder = Mock(spec=DatasetBuilder)
    builder.single_column_configs = configs

    # Record every warning, bypassing any filter that would hide or dedupe it.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        result = DatasetBuilder._resolve_async_compatibility(builder)

    assert result is expected
    if not expected:
        assert len(w) == 1
        assert issubclass(w[0].category, DeprecationWarning)
        assert "allow_resize" in str(w[0].message)
        # Regression for PR #594 review: the warning must attribute to the
        # caller's frame (this test file), not to a ``data_designer.*`` library
        # frame. Library-attributed ``DeprecationWarning`` entries fall under
        # Python's default ``ignore::DeprecationWarning`` filter and are
        # silenced. A regression to ``warnings.warn(..., stacklevel=N)`` would
        # land somewhere inside the engine package and silently break the
        # user-facing nudge.
        assert w[0].filename == __file__
    else:
        assert len(w) == 0
# -- _build_async integration test with mock generators -----------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_build_async_end_to_end() -> None:
    """Run the scheduler over a seed -> cell -> expression pipeline and check the buffer.

    Four records are split into row groups of two; every row group must be
    finalized and every column reported complete.
    """
    num_records = 4
    buffer_size = 2
    all_cols = ["seed", "cell_out", "expr_out"]

    provider = _mock_provider()
    gen_map = {
        "seed": MockSeed(config=_expr_config("seed"), resource_provider=provider),
        "cell_out": MockCell(config=_expr_config("cell_out"), resource_provider=provider),
        "expr_out": MockFullCol(config=_expr_config("expr_out"), resource_provider=provider),
    }

    configs = [
        SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
        LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
        ExpressionColumnConfig(name="expr_out", expr="{{ cell_out }}"),
    ]
    strategies = {
        "seed": GenerationStrategy.FULL_COLUMN,
        "cell_out": GenerationStrategy.CELL_BY_CELL,
        "expr_out": GenerationStrategy.FULL_COLUMN,
    }
    graph = ExecutionGraph.create(configs, strategies)

    # Chunk the records into (row_group_id, size) pairs; the last chunk may be short.
    row_groups: list[tuple[int, int]] = [
        (group_index, min(buffer_size, num_records - offset))
        for group_index, offset in enumerate(range(0, num_records, buffer_size))
    ]
    tracker = CompletionTracker.with_graph(graph, row_groups)

    storage = MagicMock()
    storage.dataset_name = "test"
    storage.get_file_paths.return_value = {}
    storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
    storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
    buffer_manager = RowGroupBufferManager(storage)

    finalized: list[int] = []

    def finalize_row_group(rg_id: int) -> None:
        # Checkpoint first, then remember which row group was finalized.
        buffer_manager.checkpoint_row_group(rg_id)
        finalized.append(rg_id)

    scheduler = AsyncTaskScheduler(
        generators=gen_map,
        graph=graph,
        tracker=tracker,
        row_groups=row_groups,
        buffer_manager=buffer_manager,
        on_finalize_row_group=finalize_row_group,
    )
    await scheduler.run()

    # Both row groups should be finalized
    assert sorted(finalized) == [0, 1]
    assert buffer_manager.actual_num_records == num_records

    # All columns should be complete
    for group_id, group_size in row_groups:
        assert tracker.is_row_group_complete(group_id, group_size, all_cols)
def test_prepare_async_run_enables_request_pressure_advisory(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check the keyword arguments ``_prepare_async_run`` hands to the scheduler.

    The real scheduler class is swapped for a spy that records its
    constructor kwargs; the test then inspects the request-pressure and
    admission-limit settings.
    """
    captured_kwargs: dict[str, object] = {}

    class _SpyScheduler:
        """Stand-in scheduler that only records how it was constructed."""

        def __init__(self, **kwargs: object) -> None:
            captured_kwargs.update(kwargs)

    monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler)

    request_admission = object()
    model_registry = MagicMock()
    model_registry.request_admission = request_admission
    # Any call to the aggregate-parallelism lookup fails the test outright.
    model_registry.get_aggregate_max_parallel_requests.side_effect = AssertionError(
        "model task admission should follow max_in_flight_tasks directly"
    )

    run_config = SimpleNamespace(max_in_flight_tasks=64, progress_interval=5.0, progress_bar=False)
    provider = SimpleNamespace(model_registry=model_registry, run_config=run_config)

    processor_runner = MagicMock()
    processor_runner.has_processors_for.return_value = False

    seed_config = SamplerColumnConfig(
        name="seed",
        sampler_type=SamplerType.CATEGORY,
        params={"values": ["A"]},
    )
    builder = SimpleNamespace(
        _column_configs=[seed_config],
        _processor_runner=processor_runner,
        artifact_storage=MagicMock(),
        _resource_provider=provider,
    )
    generator = MockSeed(config=_expr_config("seed"), resource_provider=provider)

    DatasetBuilder._prepare_async_run(builder, [generator], num_records=1, buffer_size=1)

    assert captured_kwargs["request_pressure_provider"] is request_admission
    assert captured_kwargs["request_pressure_advisory"] is True
    assert captured_kwargs["max_in_flight_tasks"] == 64
    assert captured_kwargs["max_model_task_admission"] == 64
# -- Test that existing sync path is unaffected --------------------------------
def test_sync_path_unaffected_by_async_engine_flag() -> None:
    """The builder module exposes ``DATA_DESIGNER_ASYNC_ENGINE`` as a boolean flag.

    This is a smoke check on the flag's presence and type only; it does not
    execute either the sync or the async build path.
    """
    # ``builder_mod`` is the module-level import of the dataset builder module;
    # re-importing it inside the test would only shadow the same object.
    assert hasattr(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE")
    assert isinstance(builder_mod.DATA_DESIGNER_ASYNC_ENGINE, bool)
# -- Test execution graph integration with real column configs -----------------
def test_execution_graph_from_real_configs() -> None:
    """Build an execution graph from real column configs and check order and task counts."""
    num_records = 10
    buffer_size = 3

    strategies = {
        "id": GenerationStrategy.FULL_COLUMN,
        "question": GenerationStrategy.CELL_BY_CELL,
        "answer": GenerationStrategy.CELL_BY_CELL,
        "combined": GenerationStrategy.FULL_COLUMN,
    }
    configs = [
        SamplerColumnConfig(name="id", sampler_type=SamplerType.UUID, params={}),
        LLMTextColumnConfig(name="question", prompt="{{ id }}", model_alias=MODEL_ALIAS),
        LLMTextColumnConfig(name="answer", prompt="{{ question }}", model_alias=MODEL_ALIAS),
        ExpressionColumnConfig(name="combined", expr="{{ question }} {{ answer }}"),
    ]
    graph = ExecutionGraph.create(configs, strategies)

    # Map each column to its position in the topological order.
    position: dict[str, int] = {}
    for rank, column in enumerate(graph.get_topological_order()):
        position[column] = rank

    # Each column must come after the column its template references.
    assert position["id"] < position["question"]
    assert position["question"] < position["answer"]
    assert position["answer"] < position["combined"]

    # Task counts
    counts = graph.compute_task_count(num_records=num_records, buffer_size=buffer_size)
    batches = math.ceil(num_records / buffer_size)
    assert counts["id"] == batches
    assert counts["question"] == num_records
    assert counts["answer"] == num_records
    assert counts["combined"] == batches
# -- Test checkpoint correctness -----------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_checkpoint_produces_correct_parquet_calls() -> None:
    """Checkpointing each row group writes and finalizes one parquet file per group."""
    row_groups = [(0, 3), (1, 2)]

    gen_map = {"seed": MockSeed(config=_expr_config("seed"), resource_provider=_mock_provider())}
    strategies = {"seed": GenerationStrategy.FULL_COLUMN}
    configs = [
        SamplerColumnConfig(
            name="seed",
            sampler_type=SamplerType.CATEGORY,
            params={"values": ["X"]},
        )
    ]
    graph = ExecutionGraph.create(configs, strategies)
    tracker = CompletionTracker.with_graph(graph, row_groups)

    # Fake artifact storage: attributes and return values the buffer manager may read.
    storage = MagicMock()
    storage.dataset_name = "test"
    storage.get_file_paths.return_value = {}
    storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
    storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
    buffer_manager = RowGroupBufferManager(storage)

    def checkpoint(rg_id: int):
        return buffer_manager.checkpoint_row_group(rg_id)

    scheduler = AsyncTaskScheduler(
        generators=gen_map,
        graph=graph,
        tracker=tracker,
        row_groups=row_groups,
        buffer_manager=buffer_manager,
        on_finalize_row_group=checkpoint,
    )
    await scheduler.run()

    # Two row groups → two write_batch_to_parquet_file calls
    write_calls = storage.write_batch_to_parquet_file.call_count
    move_calls = storage.move_partial_result_to_final_file_path.call_count
    assert write_calls == 2
    assert move_calls == 2
    assert buffer_manager.actual_num_records == 5
# -- Partial completion warning ------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_dropped_rows_reduce_actual_record_count() -> None:
    """Dropping every row of one row group lowers ``actual_num_records`` below the target.

    The metadata written afterwards must carry both the reduced actual
    count and the original target count.
    """
    num_records = 6
    row_groups = [(0, 3), (1, 3)]

    gen_map = {"seed": MockSeed(config=_expr_config("seed"), resource_provider=_mock_provider())}
    strategies = {"seed": GenerationStrategy.FULL_COLUMN}
    configs = [
        SamplerColumnConfig(
            name="seed",
            sampler_type=SamplerType.CATEGORY,
            params={"values": ["X"]},
        )
    ]
    graph = ExecutionGraph.create(configs, strategies)
    tracker = CompletionTracker.with_graph(graph, row_groups)

    storage = MagicMock()
    storage.dataset_name = "test"
    storage.get_file_paths.return_value = {}
    storage.write_batch_to_parquet_file.return_value = "/fake.parquet"
    storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
    buffer_manager = RowGroupBufferManager(storage)

    def drop_all_in_rg1(rg_id: int, rg_size: int) -> FrontierDelta:
        """Drop every row of row group 1 and merge the resulting frontier deltas."""
        added_tasks: list = []
        removed_tasks: list = []
        if rg_id == 1:
            for row_index in range(rg_size):
                # Drop from the tracker first, then from the buffer, row by row.
                delta = tracker.drop_row(rg_id, row_index)
                buffer_manager.drop_row(rg_id, row_index)
                added_tasks.extend(delta.added)
                removed_tasks.extend(delta.removed)
        return FrontierDelta(added=tuple(added_tasks), removed=tuple(removed_tasks))

    def checkpoint(rg_id: int):
        return buffer_manager.checkpoint_row_group(rg_id)

    scheduler = AsyncTaskScheduler(
        generators=gen_map,
        graph=graph,
        tracker=tracker,
        row_groups=row_groups,
        buffer_manager=buffer_manager,
        on_finalize_row_group=checkpoint,
        on_seeds_complete=drop_all_in_rg1,
    )
    await scheduler.run()

    actual = buffer_manager.actual_num_records
    assert actual < num_records

    buffer_manager.write_metadata(target_num_records=num_records, buffer_size=3)
    # First positional argument of the most recent write_metadata call on storage.
    written = storage.write_metadata.call_args.args[0]
    assert written["actual_num_records"] == buffer_manager.actual_num_records
    assert written["target_num_records"] == num_records