
Commit 3fb99a0

HwVanICI and gemini-code-assist[bot] authored and committed
[Feat] Add on-policy distillation support (#964)
* [feat] Add on-policy knowledge distillation support
* Correct file extension
* Delete comment on distill_loss_weight
* explain RKL and Joint loss
* Add reference to global README
* teacher config refactoring
* refactor: remove redundant stats_tracker denominator logging for KD loss

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent bf5fa67 commit 3fb99a0

7 files changed

Lines changed: 403 additions & 0 deletions


README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ All RL algorithms support both asynchronous and synchronous versions by setting
178178
| **M2PO** | [📖 Docs](docs/algorithms/m2po.md) | [📄 Paper](https://arxiv.org/abs/2510.01161) | [🔗 GSM8K Example](examples/math/gsm8k_m2po.yaml) |
179179
| **RLHF Reward Modeling** | - | - | [🔗 RLHF Example](examples/alignment/) |
180180
| **SFT** | - | - | [🔗 GSM8K Example](examples/math/gsm8k_sft.py) |
181+
| **Distillation** | [📖 Docs](docs/en/algorithms/distillation.md) | [📄 Paper](https://arxiv.org/pdf/2506.02208) | [🔗 GSM8K Example](examples/distillation/gsm8k_grpo_distill.yaml) |
181182

182183
### Models
183184

areal/api/cli_args.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,6 +2145,23 @@ class RWConfig(BaseExperimentConfig):
21452145
actor: TrainEngineConfig = field(default_factory=TrainEngineConfig)
21462146

21472147

2148+
@dataclass
2149+
class TeacherConfig(PPOActorConfig):
2150+
allocation_mode: str = field(
2151+
default="",
2152+
metadata={"help": "Pattern-based GPU parallel strategy allocation mode. "},
2153+
)
2154+
rl_loss_weight: float = field(
2155+
default=1.0,
2156+
metadata={"help": "RL loss weight"},
2157+
)
2158+
2159+
distill_loss_weight: float = field(
2160+
default=0.005,
2161+
metadata={"help": "Distillation loss weight"},
2162+
)
2163+
2164+
21482165
@dataclass
21492166
class PPOConfig(BaseExperimentConfig):
21502167
"""Configuration for Proximal Policy Optimization (PPO) reinforcement learning experiments."""
@@ -2162,6 +2179,17 @@ class PPOConfig(BaseExperimentConfig):
21622179
actor: PPOActorConfig = field(default_factory=PPOActorConfig)
21632180
ref: PPOActorConfig | None = field(default=None)
21642181
critic: PPOCriticConfig | None = field(default=None)
2182+
teacher: TeacherConfig | None = field(
2183+
default=None,
2184+
metadata={
2185+
"help": (
2186+
"Optional teacher model configuration used for on-policy "
2187+
"distillation during PPO training. If provided, the actor "
2188+
"may be trained to match the teacher in addition to the "
2189+
"standard PPO objective."
2190+
)
2191+
},
2192+
)
21652193
dynamic_bs: bool = field(
21662194
default=False,
21672195
metadata={
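To illustrate how the new `TeacherConfig` fields relate to the two training modes used elsewhere in this commit, here is a minimal, self-contained sketch. `PPOActorConfigStub` and `distillation_mode` are hypothetical names introduced only for this example; the real `PPOActorConfig` carries many more fields, and AReaL does not necessarily expose such a helper.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PPOActorConfigStub:
    """Hypothetical stand-in for PPOActorConfig (the real class has many more fields)."""

    path: str = ""


@dataclass
class TeacherConfig(PPOActorConfigStub):
    # Field names and defaults are copied from the diff above.
    allocation_mode: str = field(default="")
    rl_loss_weight: float = field(default=1.0)
    distill_loss_weight: float = field(default=0.005)


def distillation_mode(cfg: Optional[TeacherConfig]) -> str:
    """Name the training mode implied by a teacher config (illustrative helper only)."""
    if cfg is None:
        return "rl_only"  # no teacher configured: plain RL
    # The loss function in this commit branches on rl_loss_weight == 0.
    return "pure_kd" if cfg.rl_loss_weight == 0 else "joint_kdrl"


joint = TeacherConfig(path="Qwen/Qwen3-32B", allocation_mode="d1p1t4")
pure = TeacherConfig(path="Qwen/Qwen3-32B", allocation_mode="d1p1t4", rl_loss_weight=0.0)
print(distillation_mode(None), distillation_mode(joint), distillation_mode(pure))
```

Leaving `teacher` unset keeps the existing PPO/GRPO behavior, which is why the new field on `PPOConfig` defaults to `None`.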

areal/trainer/ppo/actor.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,38 @@ def grpo_loss_fn(
430430
behave_imp_weight_mode=behave_imp_weight_mode,
431431
)
432432

433+
# Joint Distillation KL Loss
434+
teacher_logp = input_data.get("teacher_logp")
435+
rkl_stat = None
436+
if teacher_logp is not None:
437+
# Coefficients for RL and Knowledge Distillation
438+
rl_loss_weight = input_data.get("rl_loss_weight", 1.0)
439+
distill_loss_weight = input_data.get("distill_loss_weight", 0.005)
440+
441+
teacher_logp = (
442+
teacher_logp.detach()
443+
) # detach to prevent gradient backprop to teacher
444+
445+
if rl_loss_weight == 0:
446+
# Pure KD using reverse KL (importance-sampling)
447+
rkl_reward = teacher_logp - logprobs.detach()
448+
importance_weight = torch.exp(logprobs - old_logp)
449+
450+
rkl_weighted_term = importance_weight * rkl_reward * loss_mask
451+
452+
kd_coef = -1 * distill_loss_weight
453+
loss = kd_coef * rkl_weighted_term.sum() / loss_mask.sum().clamp(min=1)
454+
455+
rkl_stat = -1 * rkl_weighted_term
456+
else:
457+
# KDRL: Knowledge Distillation + Reinforcement Learning (joint loss)
458+
rkl_penalty_per_token = (logprobs - teacher_logp) * loss_mask
459+
rkl_penalty = rkl_penalty_per_token.sum() / loss_mask.sum().clamp(min=1)
460+
461+
loss = rl_loss_weight * loss + distill_loss_weight * rkl_penalty
462+
463+
rkl_stat = rkl_penalty_per_token
464+
433465
# Log training statistics
434466
stats_tracker.denominator(
435467
# NOTE: n_tokens must have shape [batch, seq] to match vocab stats.
@@ -442,6 +474,12 @@ def grpo_loss_fn(
442474
dual_clipped_tokens=stat["dual_clip_mask"],
443475
)
444476

477+
if rkl_stat is not None:
478+
stats_tracker.stat(
479+
rkl_loss=rkl_stat,
480+
denominator="n_valid_tokens",
481+
)
482+
445483
stats_tracker.stat(
446484
importance_weight=stat["importance_weight"],
447485
approx_kl=stat["approx_kl"],
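The two branches added to `grpo_loss_fn` can be traced by hand on a tiny example. The sketch below mirrors only the loss *values* using plain Python lists; it is not the AReaL implementation, it ignores autograd (so the `.detach()` calls have no counterpart here), and the function names `masked_mean` and `distill_loss` are assumptions made for illustration.

```python
import math


def masked_mean(values, mask):
    """Sum of masked values divided by the number of valid tokens (at least 1)."""
    total = sum(v * m for v, m in zip(values, mask))
    return total / max(sum(mask), 1)


def distill_loss(logprobs, old_logp, teacher_logp, loss_mask, rl_loss,
                 rl_loss_weight=1.0, distill_loss_weight=0.005):
    """Loss value for the pure-KD branch (rl_loss_weight == 0) or the joint branch."""
    if rl_loss_weight == 0:
        # Pure KD: importance weight times the reward (teacher_logp - logprobs).
        weighted = [
            math.exp(lp - olp) * (tlp - lp)
            for lp, olp, tlp in zip(logprobs, old_logp, teacher_logp)
        ]
        # Negative coefficient: minimizing the loss maximizes the weighted reward.
        return -distill_loss_weight * masked_mean(weighted, loss_mask)
    # Joint loss: RL loss plus a per-token (logprobs - teacher_logp) penalty.
    penalty = masked_mean(
        [lp - tlp for lp, tlp in zip(logprobs, teacher_logp)], loss_mask
    )
    return rl_loss_weight * rl_loss + distill_loss_weight * penalty


logprobs = [-1.0, -2.0, -0.5]
old_logp = [-1.0, -2.0, -0.5]      # on-policy, so every importance weight is 1
teacher_logp = [-0.5, -3.0, -0.1]
loss_mask = [1, 1, 0]              # the last token is padding and is ignored

joint_loss = distill_loss(logprobs, old_logp, teacher_logp, loss_mask, rl_loss=0.2)
pure_kd_loss = distill_loss(logprobs, old_logp, teacher_logp, loss_mask,
                            rl_loss=0.2, rl_loss_weight=0.0)
print(joint_loss, pure_kd_loss)
```

In this example the per-token penalties on the two valid tokens are -0.5 and 1.0, so the masked mean is 0.25 and the pure-KD loss reduces to `0.005 * 0.25`.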

areal/trainer/rl_trainer.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,19 @@ def __init__(
184184
if self.ref is not None:
185185
self.ref.initialize(**engine_init_kwargs, role="ref")
186186

187+
self.teacher = None
188+
if config.teacher is not None:
189+
self.teacher = self._create_teacher(config.teacher)
190+
teacher_allocation_mode = AllocationMode.from_str(
191+
config.teacher.allocation_mode
192+
)
193+
teacher_init_kwargs = {
194+
"addr": None,
195+
"ft_spec": ft_spec,
196+
"alloc_mode": teacher_allocation_mode,
197+
}
198+
self.teacher.initialize(**teacher_init_kwargs, role="teacher")
199+
187200
# Save initial LoRA weights if enabled (for inference server pre-loading)
188201
initial_lora_path = self._save_initial_lora_weights()
189202

@@ -372,6 +385,24 @@ def train(
372385
rollout_batch["ref_logp"] = self.ref.compute_logp(rollout_batch)
373386
self.ref.get_device_stats().log("ref logp")
374387

388+
if self.teacher is not None:
389+
with (
390+
stats_tracker.record_timing("teacher_logp"),
391+
perf_tracer.trace_scope(
392+
"train.teacher_logp",
393+
category=Category.COMPUTE,
394+
args={"global_step": global_step},
395+
),
396+
):
397+
rollout_batch["teacher_logp"] = self.teacher.compute_logp(
398+
rollout_batch
399+
)
400+
rollout_batch["rl_loss_weight"] = self.config.teacher.rl_loss_weight
401+
rollout_batch["distill_loss_weight"] = (
402+
self.config.teacher.distill_loss_weight
403+
)
404+
self.teacher.get_device_stats().log("teacher logp")
405+
375406
with (
376407
stats_tracker.record_timing("compute_advantage"),
377408
perf_tracer.trace_scope(
@@ -664,6 +695,34 @@ def _create_critic(
664695
critic.create_process_group(parallel_strategy=self.allocation_mode.train)
665696
return critic
666697

698+
def _create_teacher(self, teacher_config):
699+
allocation_mode = AllocationMode.from_str(teacher_config.allocation_mode)
700+
701+
if allocation_mode.train_backend == "fsdp":
702+
from areal.engine.fsdp_engine import FSDPPPOActor
703+
704+
actor_cls = FSDPPPOActor
705+
elif allocation_mode.train_backend == "megatron":
706+
from areal.engine.megatron_engine import MegatronPPOActor
707+
708+
actor_cls = MegatronPPOActor
709+
elif allocation_mode.train_backend == "archon":
710+
from areal.experimental.engine.archon_engine import ArchonPPOActor
711+
712+
actor_cls = ArchonPPOActor
713+
else:
714+
raise ValueError(
715+
f"Invalid backend: {allocation_mode.train_backend}, expected fsdp, megatron, or archon"
716+
)
717+
718+
if is_single_controller():
719+
teacher = actor_cls.as_controller(teacher_config, self.scheduler)
720+
else:
721+
teacher = actor_cls(config=teacher_config)
722+
723+
teacher.create_process_group(parallel_strategy=allocation_mode.train)
724+
return teacher
725+
667726
def _init_rollout(
668727
self,
669728
rollout_config: InferenceEngineConfig,
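The training-loop change boils down to attaching three extra entries to the rollout batch before the loss is computed. A rough sketch of that data flow follows, assuming a dictionary batch and a stub teacher; `StubTeacher` and `attach_teacher_signals` are hypothetical names that do not exist in AReaL, and a real engine would run a forward pass instead of returning a constant.

```python
class StubTeacher:
    """Hypothetical stand-in for a teacher engine that exposes compute_logp."""

    def __init__(self, logp_per_token):
        self.logp_per_token = logp_per_token

    def compute_logp(self, batch):
        # Return one constant log-probability per token, just to show the shape.
        return [[self.logp_per_token] * len(seq) for seq in batch["input_ids"]]


def attach_teacher_signals(batch, teacher, rl_loss_weight, distill_loss_weight):
    """Add teacher log-probs and loss weights to the batch when a teacher exists."""
    if teacher is None:
        return batch  # no teacher configured: the batch is left untouched
    batch["teacher_logp"] = teacher.compute_logp(batch)
    batch["rl_loss_weight"] = rl_loss_weight
    batch["distill_loss_weight"] = distill_loss_weight
    return batch


batch = {"input_ids": [[11, 12, 13], [21, 22]]}
out = attach_teacher_signals(batch, StubTeacher(-0.7), 1.0, 0.005)
print(sorted(out))
```

Because the loss function reads `teacher_logp` with `input_data.get(...)`, a batch without these keys falls through to the unchanged RL path.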

docs/en/algorithms/distillation.md

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# On-Policy Distillation
2+
3+
## Overview
4+
5+
On-policy distillation trains the student using teacher guidance on trajectories sampled from its own policy, reducing distribution mismatch and improving stability. Combined with reinforcement learning, it lets the student **imitate the teacher while exploring simultaneously**.
6+
7+
**AReaL** previously supported RL for post-training. With this implementation, it now also supports **on-policy knowledge distillation** and the **combined KDRL framework**, enabling the student to learn from a teacher while exploring via RL on the same on-policy trajectories, improving both efficiency and stability.
8+
9+
## The Core Concept
10+
11+
Knowledge distillation aims to train the student policy $\pi_\theta$ to mimic the behavior of a more powerful teacher $\pi_T$. The choice of divergence measure and sampling distribution significantly impacts the student's final performance and exposure bias.
12+
13+
### Supervised Fine-Tuning (Forward KL)
14+
15+
A simple yet effective method is to maximize the log-likelihood on data generated by the teacher, known as supervised fine-tuning (SFT). This is equivalent to minimizing the Forward KL divergence between $\pi_T$ and $\pi_\theta$:
16+
$$\arg \min_{\theta} D_{KL}(\pi_T \parallel \pi_\theta) = \arg \max_{\theta} \mathbb{E}_{q \sim Q, o \sim \pi_T(\cdot|q)} [\log \pi_\theta(o|q)]$$
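As a quick numerical check of this equivalence, the sketch below uses two made-up three-token distributions (the numbers are assumptions chosen for illustration). It shows that the forward KL splits into a teacher-entropy term that does not depend on $\theta$ and the expected student log-likelihood under teacher samples, and that forward and reverse KL generally differ.

```python
import math


def kl(p, q):
    """KL divergence D_KL(p || q) between two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher = [0.7, 0.2, 0.1]   # toy next-token distribution of the teacher
student = [0.4, 0.4, 0.2]   # toy next-token distribution of the student

forward_kl = kl(teacher, student)   # D_KL(pi_T || pi_theta), the SFT objective
reverse_kl = kl(student, teacher)   # D_KL(pi_theta || pi_T), the on-policy objective

# Forward KL = (negative teacher entropy) - E_{o ~ teacher}[log student(o)].
neg_teacher_entropy = sum(p * math.log(p) for p in teacher)
expected_student_loglik = sum(p * math.log(q) for p, q in zip(teacher, student))

print(forward_kl, reverse_kl)
print(neg_teacher_entropy - expected_student_loglik)
```

Since the entropy term is constant in $\theta$, lowering the forward KL is the same as raising the expected log-likelihood, which is exactly the SFT objective.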
17+
18+
19+
### On-Policy Distillation (Reverse KL)
20+
21+
While SFT is efficient, training on off-policy data induces exposure bias: a mismatch between training on teacher-generated prefixes and inference on self-generated prefixes. This is especially severe for reasoning LLMs with long response chains. To alleviate this, we can train on self-generated trajectories, which is equivalent to minimizing the Reverse KL divergence (RKL) [1]:
22+
$$\arg \min_{\theta} D_{KL}(\pi_\theta \parallel \pi_T) = \arg \max_{\theta} \mathbb{E}_{q \sim Q, o \sim \pi_\theta(\cdot|q)} \left[ \log \frac{\pi_T(o|q)}{\pi_\theta(o|q)} \right]$$
23+
24+
Minimizing RKL is equivalent to REINFORCE where the "reward" is the log-ratio of teacher to student probabilities. By adopting the GRPO framework, we optimize [1]:
25+
26+
$$J_{RKL}(\theta) = \mathbb{E}_{q, \{o_i\} \sim \pi_{\theta_{old}}} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \frac{\pi_\theta(o_{i,t})}{\pi_{\theta_{old}}(o_{i,t})} R_{i,t} \right]$$
27+
28+
where the reward $R_{i,t} = \log \pi_T(o_{i,t}) - \log \pi_\theta(o_{i,t})$. This encourages the student to increase the probability of tokens the teacher prefers and suppress those it deems unlikely.
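The objective above can be evaluated by hand for a small group. The following sketch follows the formula's normalization (mean over tokens within each response, then mean over the $G$ responses); the commit's code instead divides by the total number of valid tokens in the batch it receives. The function name `rkl_objective` and the input layout are assumptions made for this example.

```python
import math


def rkl_objective(group):
    """Estimate J_RKL for one prompt.

    `group` is a list of responses; each response is a list of
    (logp, old_logp, teacher_logp) tuples, one per token.
    """
    total = 0.0
    for response in group:
        per_token = [
            # importance ratio pi_theta / pi_theta_old times reward R = log pi_T - log pi_theta
            math.exp(logp - old_logp) * (teacher_logp - logp)
            for logp, old_logp, teacher_logp in response
        ]
        total += sum(per_token) / len(per_token)   # 1 / |o_i| sum over tokens
    return total / len(group)                      # 1 / G sum over responses


group = [
    [(-1.0, -1.0, -0.5), (-2.0, -2.0, -3.0)],  # rewards 0.5 and -1.0
    [(-0.5, -0.5, -0.5)],                      # student already matches the teacher
]
objective = rkl_objective(group)
print(objective)
```

A positive reward marks a token the teacher finds more likely than the student does; in the gradient computation the reward is treated as a constant, so only the importance ratio carries gradient.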
29+
30+
- Implementation Detail: For pure KD, set `rl_loss_weight` to 0. The implementation then estimates the RKL gradient with importance sampling: it computes the reward as `teacher_logp - logprobs` ($R_{i,t}$), weights it by the importance ratio `exp(logprobs - old_logp)`, and scales the masked mean by `-distill_loss_weight`, so that minimizing the loss maximizes $J_{RKL}$ (see `areal/trainer/ppo/actor.py`).
32+
33+
34+
### Combination of GRPO and KD
35+
We implement the KD+RL approach using a joint loss strategy.
36+
37+
#### Joint Loss
38+
This strategy augments the GRPO objective with an auxiliary KL loss term. To maintain consistency with the on-policy nature of GRPO, it utilizes the Reverse KL (RKL) [1]:
39+
$$J_{KDRL}(\theta) = J_{GRPO}(\theta) - \beta D_{KL}(\pi_\theta \parallel \pi_T)$$
40+
41+
The gradient $\nabla_\theta J_{KDRL}(\theta)$ provides an unbiased estimate of $\nabla_\theta J_{GRPO}( \theta) + \beta \cdot \nabla_\theta J_{RKL}(\theta)$.
42+
43+
- Implementation Detail: In the joint loss case (`rl_loss_weight` > 0), the RKL is added as a direct penalty. Over tokens sampled from the student policy $\pi_\theta$, the masked mean of `logprobs - teacher_logp` is a Monte Carlo estimate of $D_{KL}(\pi_\theta \parallel \pi_T)$, so lowering this term pulls the student toward the teacher. In the code, this is implemented as:
44+
`loss = rl_loss_weight * loss + distill_loss_weight * rkl_penalty`
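The relationship between the per-token term `logprobs - teacher_logp` and the reverse KL can be checked numerically. The sketch below draws tokens from a toy student distribution and compares the sample mean of the log-ratio with the exact divergence; the distributions, sample count, and function names are assumptions chosen for illustration, not values from AReaL.

```python
import math
import random


def exact_reverse_kl(student, teacher):
    """Closed-form D_KL(student || teacher) for discrete distributions."""
    return sum(s * math.log(s / t) for s, t in zip(student, teacher) if s > 0)


def sampled_reverse_kl(student, teacher, n_samples, seed=0):
    """Average of log student - log teacher over tokens sampled from the student."""
    rng = random.Random(seed)
    tokens = rng.choices(range(len(student)), weights=student, k=n_samples)
    log_ratios = [math.log(student[t]) - math.log(teacher[t]) for t in tokens]
    return sum(log_ratios) / n_samples


student = [0.4, 0.4, 0.2]
teacher = [0.7, 0.2, 0.1]

exact = exact_reverse_kl(student, teacher)
estimate = sampled_reverse_kl(student, teacher, n_samples=20000)
print(exact, estimate)
```

The estimate only matches the reverse KL when the tokens come from the student itself, which is why the penalty is computed on on-policy rollouts rather than on teacher-generated data.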
45+
46+
47+
48+
49+
## Running the example
50+
51+
Add a teacher configuration to your YAML file:
52+
53+
```yaml
54+
teacher:
55+
allocation_mode: d1p1t4
56+
rl_loss_weight: 1.0
57+
distill_loss_weight: 0.005
58+
experiment_name: ${experiment_name}
59+
trial_name: ${trial_name}
60+
path: Qwen/Qwen3-32B
61+
init_from_scratch: false
62+
disable_dropout: true
63+
dtype: ${actor.dtype}
64+
mb_spec:
65+
max_tokens_per_mb: 10240
66+
optimizer: null
67+
scheduling_spec: ${actor.scheduling_spec}
68+
```
69+
70+
Example command using local scheduler:
71+
72+
```bash
73+
python3 examples/math/gsm8k_rl.py --config examples/distillation/gsm8k_grpo_distill.yaml scheduler.type=local experiment_name=gsm8k-grpo-distillation trial_name=trial0
74+
```
75+
76+
## Result
77+
78+
Reward curve for on-policy knowledge distillation combined with RL, using Qwen2.5-14B-Instruct as the teacher and Qwen3-0.6B as the student, trained with FSDP and vLLM.
79+
80+
![Reward curve](reward_curve.png)
81+
82+
## References
83+
84+
[1] Xu H, Zhu Q, Deng H, Li J, Hou L, Wang Y, Shang L, Xu R, Mi F. KDRL: Post-training reasoning LLMs via unified knowledge distillation and reinforcement learning. [KDRL paper link](https://arxiv.org/pdf/2506.02208)