Skip to content

Commit 1dd4169

Browse files
committed
add
1 parent c9344c3 commit 1dd4169

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

eval_protocol/pytest/priority_scheduler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
output_buffer: Optional[MicroBatchDataBuffer] = None,
7171
rollout_n: int = 0,
7272
mode: str = "pointwise",
73-
in_group_minibatch_size: int = 0, # for one sample, how many runs to execute at the same time
73+
in_group_minibatch_size: Optional[int] = None, # for one sample, how many runs to execute at the same time
7474
evaluation_test_kwargs: Dict[str, Any] = {},
7575
):
7676
self.rollout_processor = rollout_processor
@@ -94,6 +94,11 @@ def __init__(
9494
self.background_tasks = set() # run evaluations in the background asynchronously
9595

9696
self.rollout_n = rollout_n
97+
if in_group_minibatch_size is None:
98+
if ENABLE_SPECULATION:
99+
in_group_minibatch_size = rollout_n // 2
100+
else:
101+
in_group_minibatch_size = rollout_n
97102
self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n
98103
self.evaluation_test_kwargs = evaluation_test_kwargs
99104

@@ -108,8 +113,7 @@ def __init__(
108113
# Track active evaluations
109114
self.active_evals: int = 0
110115
self.active_evals_lock = asyncio.Lock()
111-
112-
# Per-sample state for streaming scheduling
116+
113117
self.sample_states: Dict[int, SampleState] = {}
114118

115119
async def schedule_dataset(
@@ -504,7 +508,6 @@ async def execute_priority_rollouts(
504508
max_concurrent_evaluations=max_concurrent_evaluations,
505509
rollout_n=num_runs,
506510
mode=mode,
507-
in_group_minibatch_size=(num_runs // 2),
508511
evaluation_test_kwargs=evaluation_test_kwargs,
509512
)
510513
return await scheduler.run(dataset, num_runs, config)

0 commit comments

Comments
 (0)