|
| 1 | +# On-Policy Distillation |
| 2 | + |
| 3 | +## Overview |
| 4 | + |
| 5 | +On-policy distillation trains the student using teacher guidance on trajectories sampled from its own policy, reducing distribution mismatch and improving stability. Combined with reinforcement learning, it lets the student **imitate the teacher while exploring simultaneously**. |
| 6 | + |
| 7 | +**AReaL** previously supported RL for post-training. With this implementation, it now also supports **on-policy knowledge distillation** and the **combined KDRL framework**, enabling the student to learn from a teacher while exploring via RL on the same on-policy trajectories, improving both efficiency and stability. |
| 8 | + |
| 9 | +## The Core Concept |
| 10 | + |
| 11 | +Knowledge distillation aims to train the student policy $\pi_\theta$ to mimic the behavior of a more powerful teacher $\pi_T$. The choice of divergence measure and sampling distribution significantly impacts the student's final performance and exposure bias. |
| 12 | + |
| 13 | +### Supervised Fine-Tuning (Forward KL): |
| 14 | + |
| 15 | +A simple yet effective method is to maximize the log-likelihood on data generated by the teacher, known as supervised fine-tuning (SFT). This is equivalent to minimizing the Forward KL divergence between $\pi_T$ and $\pi_\theta$: |
| 16 | +$$\arg \min_{\theta} D_{KL}(\pi_T \parallel \pi_\theta) = \arg \max_{\theta} \mathbb{E}_{q \sim Q, o \sim \pi_T(\cdot|q)} [\log \pi_\theta(o|q)]$$ |
| 17 | + |
| 18 | + |
| 19 | +### On-Policy Distillation (Reverse KL): |
| 20 | + |
| 21 | +While SFT is efficient, training on off-policy data induces exposure bias: a mismatch between training on teacher-generated prefixes and inference on self-generated prefixes. This is especially severe for reasoning LLMs with long response chains. To alleviate this, we can train on self-generated trajectories, which is equivalent to minimizing the Reverse KL divergence (RKL) [1]: |
| 22 | +$$\arg \min_{\theta} D_{KL}(\pi_\theta \parallel \pi_T) = \arg \max_{\theta} \mathbb{E}_{q \sim Q, o \sim \pi_\theta(\cdot|q)} \left[ \log \frac{\pi_T(o|q)}{\pi_\theta(o|q)} \right]$$ |
| 23 | + |
| 24 | +Minimizing RKL is equivalent to REINFORCE where the "reward" is the log-ratio of teacher to student probabilities. By adopting the GRPO framework, we optimize [1]: |
| 25 | + |
| 26 | +$$J_{RKL}(\theta) = \mathbb{E}_{q, \{o_i\} \sim \pi_{\theta_{old}}} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \frac{\pi_\theta(o_{i,t})}{\pi_{\theta_{old}}(o_{i,t})} R_{i,t} \right]$$ |
| 27 | + |
| 28 | +where the reward $R_{i,t} = \log \pi_T(o_{i,t}) - \log \pi_\theta(o_{i,t})$. This encourages the student to increase the probability of tokens the teacher prefers and suppress those it deems unlikely. |
| 29 | + |
| 30 | +- Implementation Detail: |
| 31 | +During pure KD, we need to set `rl_loss_weight` to 0, so the implementation estimates the RKL gradient using importance sampling. The code calculates the reward as teacher_logp - logprobs ($R_{i,t}$) and applies a negative coefficient to the loss to perform minimization (check `areal/trainer/ppo/actor.py`). |
| 32 | + |
| 33 | + |
| 34 | +### Combination of GRPO and KD |
| 35 | +We implemented KD+RL approach using a Joint Loss strategy. |
| 36 | + |
| 37 | +#### Joint Loss: |
| 38 | +This strategy augments the GRPO objective with an auxiliary KL loss term. To maintain consistency with the on-policy nature of GRPO, it utilizes the Reverse KL (RKL) [1]: |
| 39 | +$$J_{KDRL}(\theta) = J_{GRPO}(\theta) - \beta D_{KL}(\pi_\theta \parallel \pi_T) \tag{8}$$ |
| 40 | + |
| 41 | +The gradient $\nabla_\theta J_{KDRL}(\theta)$ provides an unbiased estimate of $\nabla_\theta J_{GRPO}( \theta) + \beta \cdot \nabla_\theta J_{RKL}(\theta)$. |
| 42 | + |
| 43 | +- Implementation Detail: In the joint loss case (`rl_loss_weight` > 0), the RKL is treated as a direct penalty. Minimizing the term `logprobs - teacher_logp` is mathematically equivalent to minimizing the Reverse KL objective $D_{KL}(\pi_\theta \parallel \pi_T)$ when sampling from the student distribution $\pi_\theta$. In the code, this is implemented as: |
| 44 | +`loss = rl_loss_weight * loss + distill_loss_weight * rkl_penalty` |
| 45 | + |
| 46 | + |
| 47 | + |
| 48 | + |
| 49 | +## Running the example |
| 50 | + |
| 51 | +Need to add teacher configuration to your yaml: |
| 52 | + |
| 53 | +```yaml |
| 54 | +teacher: |
| 55 | + allocation_mode: d1p1t4 |
| 56 | + rl_loss_weight: 1.0 |
| 57 | + distill_loss_weight: 0.005 |
| 58 | + experiment_name: ${experiment_name} |
| 59 | + trial_name: ${trial_name} |
| 60 | + path: Qwen/Qwen3-32B |
| 61 | + init_from_scratch: false |
| 62 | + disable_dropout: true |
| 63 | + dtype: ${actor.dtype} |
| 64 | + mb_spec: |
| 65 | + max_tokens_per_mb: 10240 |
| 66 | + optimizer: null |
| 67 | + scheduling_spec: ${actor.scheduling_spec} |
| 68 | +``` |
| 69 | +
|
| 70 | +Example command using local scheduler: |
| 71 | +
|
| 72 | +```bash |
| 73 | +python3 examples/math/gsm8k_rl.py --config examples/distillation/gsm8k_grpo_distill.yaml scheduler.type=local experiment_name=gsm8k-grpo-distillation trial_name=trial0 |
| 74 | +``` |
| 75 | + |
| 76 | +## Result |
| 77 | + |
| 78 | +On-policy knowledge distillation + RL reward plot for Qwen2.5-14B-Instruct (teacher) and Qwen3-0.6B (student), trained using FSDP and vLLM. |
| 79 | + |
| 80 | + |
| 81 | + |
| 82 | +## References |
| 83 | + |
| 84 | +[1] Xu H, Zhu Q, Deng H, Li J, Hou L, Wang Y, Shang L, Xu R, Mi F. Kdrl: Post-training reasoning llms via unified knowledge distillation and reinforcement learning. [KDRL paper link](https://arxiv.org/pdf/2506.02208) |
0 commit comments