Skip to content

Commit 945adff

Browse files
EazyRealclaude
andcommitted
fix(opd): score teacher logprobs at rollout temperature, not 0
The on-policy-distillation teacher reward_func scored teacher log-probs via SGLang with a hardcoded `temperature: 0`. SGLang computes input_token_logprobs WITH temperature scaling (compute_temp_top_p_normalized_logprobs), and the student log-probs are temperature-scaled by rollout_temperature (get_responses). So when rollout_temperature != 1 the OPD reverse-KL (student - teacher) compares log-probs at different effective temperatures and is biased. Score the teacher at rollout_temperature so both sides of the KL match. No change at the default rollout_temperature=1.0. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent e46ca0a commit 945adff

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

slime/rollout/on_policy_distillation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ async def reward_func(args, sample, **kwargs):
1010
# "text": sample.prompt + sample.response,
1111
"input_ids": sample.tokens,
1212
"sampling_params": {
13-
"temperature": 0,
13+
# Score teacher log-probs at rollout_temperature: SGLang scales
14+
# input_token_logprobs by the sampling temperature, and the student
15+
# log-probs are temperature-scaled too (get_responses), so the OPD KL is
16+
# only consistent when both are at the same temperature.
17+
"temperature": getattr(args, "rollout_temperature", 1.0),
1418
"max_new_tokens": 0,
1519
"skip_special_tokens": False,
1620
},

0 commit comments

Comments
 (0)