Commit ca2ea7f

fix(opd): use rollout temperature directly
1 parent efbb291 commit ca2ea7f

1 file changed

Lines changed: 1 addition & 5 deletions

slime/rollout/on_policy_distillation.py

@@ -10,11 +10,7 @@ async def reward_func(args, sample, **kwargs):
         # "text": sample.prompt + sample.response,
         "input_ids": sample.tokens,
         "sampling_params": {
-            # Score teacher log-probs at rollout_temperature: SGLang scales
-            # input_token_logprobs by the sampling temperature, and the student
-            # log-probs are temperature-scaled too (get_responses), so the OPD KL is
-            # only consistent when both are at the same temperature.
-            "temperature": getattr(args, "rollout_temperature", 1.0),
+            "temperature": args.rollout_temperature,
             "max_new_tokens": 0,
             "skip_special_tokens": False,
         },
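
Note on the behavioral difference in this hunk: a minimal sketch, assuming `args` is an `argparse.Namespace`-like object. The helper names `sampling_params_old` and `sampling_params_new` and the value `0.7` are hypothetical and only mirror the two versions of the `"temperature"` line; they are not part of the repository. When `rollout_temperature` is set, both versions produce the same payload. When it is missing, the old version silently falls back to 1.0, while the new version raises `AttributeError`.

```python
from argparse import Namespace


def sampling_params_old(args):
    # Pre-commit form: fall back to 1.0 when the attribute is missing.
    return {
        "temperature": getattr(args, "rollout_temperature", 1.0),
        "max_new_tokens": 0,
        "skip_special_tokens": False,
    }


def sampling_params_new(args):
    # Post-commit form: read the attribute directly. A missing attribute
    # raises AttributeError instead of silently using 1.0.
    return {
        "temperature": args.rollout_temperature,
        "max_new_tokens": 0,
        "skip_special_tokens": False,
    }


configured = Namespace(rollout_temperature=0.7)  # hypothetical value
unconfigured = Namespace()                       # attribute never defined

# Identical payloads whenever the attribute exists.
same_when_set = sampling_params_old(configured) == sampling_params_new(configured)

# The old form hides a missing attribute behind the default.
fallback_temperature = sampling_params_old(unconfigured)["temperature"]

# The new form surfaces the missing attribute as an error.
try:
    sampling_params_new(unconfigured)
    missing_attr_raises = False
except AttributeError:
    missing_attr_raises = True
```

If the argument parser always defines `rollout_temperature`, the two forms are equivalent in practice, and the direct access simply drops a fallback that could never trigger.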
