diff --git a/slime/rollout/on_policy_distillation.py b/slime/rollout/on_policy_distillation.py index 9190974345..eb52a0821d 100644 --- a/slime/rollout/on_policy_distillation.py +++ b/slime/rollout/on_policy_distillation.py @@ -10,7 +10,7 @@ async def reward_func(args, sample, **kwargs): # "text": sample.prompt + sample.response, "input_ids": sample.tokens, "sampling_params": { - "temperature": 0, + "temperature": args.rollout_temperature, "max_new_tokens": 0, "skip_special_tokens": False, },