@@ -70,7 +70,7 @@ def __init__(
7070 output_buffer : Optional [MicroBatchDataBuffer ] = None ,
7171 rollout_n : int = 0 ,
7272 mode : str = "pointwise" ,
73- in_group_minibatch_size : int = 0 , # for one sample, how many runs to execute at the same time
73+ in_group_minibatch_size : Optional [ int ] = None , # for one sample, how many runs to execute at the same time
7474 evaluation_test_kwargs : Dict [str , Any ] = {},
7575 ):
7676 self .rollout_processor = rollout_processor
@@ -94,6 +94,11 @@ def __init__(
9494 self .background_tasks = set () # run evaluations in the background asynchronously
9595
9696 self .rollout_n = rollout_n
97+ if in_group_minibatch_size is None :
98+ if ENABLE_SPECULATION :
99+ in_group_minibatch_size = rollout_n // 2
100+ else :
101+ in_group_minibatch_size = rollout_n
97102 self .in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n
98103 self .evaluation_test_kwargs = evaluation_test_kwargs
99104
@@ -108,8 +113,7 @@ def __init__(
108113 # Track active evaluations
109114 self .active_evals : int = 0
110115 self .active_evals_lock = asyncio .Lock ()
111-
112- # Per-sample state for streaming scheduling
116+
113117 self .sample_states : Dict [int , SampleState ] = {}
114118
115119 async def schedule_dataset (
@@ -504,7 +508,6 @@ async def execute_priority_rollouts(
504508 max_concurrent_evaluations = max_concurrent_evaluations ,
505509 rollout_n = num_runs ,
506510 mode = mode ,
507- in_group_minibatch_size = (num_runs // 2 ),
508511 evaluation_test_kwargs = evaluation_test_kwargs ,
509512 )
510513 return await scheduler .run (dataset , num_runs , config )
0 commit comments