
Commit 4f80f5d

shanjiaz, MeganEFlynn, claude, and fynnsu authored
P-eagle training support (#480)
## Purpose

P-EAGLE keeps EAGLE-3's lightweight decoder architecture but generates multiple token predictions in parallel through Conditional-On-Distribution (COD) sampling, rather than through sequential test-time training steps. This PR implements the P-EAGLE algorithm on top of the existing EAGLE-3 infrastructure.

## Description

- Model definition: `speculators/models/peagle/core.py` contains the P-EAGLE draft model, with parallel multi-token prediction via COD sampling, flex attention for cross-depth masking, and a learnable `mask_hidden` parameter for padding unsampled positions.
- Metrics: Extracted loss and accuracy computation into `peagle/metrics.py`. Includes count-based normalization for correct distributed averaging when COD causes different ranks to sample different depths.
- Trainer: Added `normalize_counted_metrics` to `train/utils.py` to handle per-position accuracy averaging across ranks, with minimal changes to the shared trainer.
- Training script: Added P-EAGLE config/args to `scripts/train.py` and an example training script for Qwen3-8B.

## Related Issue

#292

## Tests

- Added a training example for Qwen3-8B with ShareGPT 5k samples.
- Will post more validation results.

I have filled in:

- [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- [x] The test plan/results, such as providing test command and pasting the results.
- [ ] (Optional) The necessary documentation update.
- [x] I (a human) have written or reviewed the code in this PR to the best of my ability.

---------

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Co-authored-by: Megan Flynn <mflynn@redhat.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
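As a rough illustration of why COD sampling saves memory, the sketch below computes how many positions each prediction depth might keep. It assumes that depth `k` retains a fraction `max(ratio**k, ratio_min)` of the sequence, which is one plausible reading of the "geometric decay ratio" and "minimum retention ratio" help texts for `--down-sample-ratio` and `--down-sample-ratio-min`. The real schedule lives in `speculators/models/peagle/core.py`, which is not shown here and may differ. The helper names are hypothetical.

```python
def retention_ratios(
    num_depths: int, ratio: float = 0.7, ratio_min: float = 0.2
) -> list[float]:
    """Fraction of positions kept at each depth.

    ASSUMPTION: geometric decay with a floor, i.e. max(ratio**k, ratio_min).
    This is an illustrative guess, not the repository's implementation.
    """
    return [max(ratio**k, ratio_min) for k in range(num_depths)]


def tokens_per_depth(
    seq_len: int, num_depths: int, ratio: float = 0.7, ratio_min: float = 0.2
) -> list[int]:
    """Number of positions kept at each depth for one sequence (rounded down)."""
    return [int(seq_len * r) for r in retention_ratios(num_depths, ratio, ratio_min)]


if __name__ == "__main__":
    # Values taken from the example script: 4 depths, sequence length 4096.
    kept = tokens_per_depth(seq_len=4096, num_depths=4)
    dense = 4096 * 4  # cost if every depth kept every position
    print(f"kept per depth: {kept}")
    print(f"total kept: {sum(kept)} of {dense} dense positions")
```

Under this assumed schedule the floor only matters at larger depth counts: with the default of 8 depths, the ratio reaches `ratio_min` from depth 5 onward.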
1 parent 3c811bc commit 4f80f5d

19 files changed

Lines changed: 848 additions & 109 deletions

examples/train/peagle_qwen3_8b_sharegpt_online_5k.sh

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+#!/bin/bash
+# Online P-EAGLE Training Script
+#
+# Runs the full online P-EAGLE training pipeline: data preparation, vLLM server launch,
+# and training (with hidden states generated on-the-fly from the live server).
+#
+# Usage: Copy this script, modify the configuration variables below, then run:
+# bash examples/train/peagle_qwen3_8b_sharegpt_online_5k.sh
+
+### Example E2E run for P-EAGLE Qwen3-8B on 5k samples from ShareGPT ###
+
+# P-EAGLE (Parallel EAGLE) extends EAGLE-3 with parallel multi-token prediction using
+# Conditional-On-Distribution (COD) sampling for memory-efficient training.
+
+# Note: With just 5k samples, the model performance will not be very good, however there
+# are enough samples to verify that the pipeline is working correctly and that the model
+# is learning something. This is a good sanity check when creating a drafter for a new
+# target model.
+
+# Timing (on 4x NVIDIA H100 80GB GPUs, DP=2)
+# Data Preprocessing: 26 seconds
+# vLLM Server Startup: 82 seconds (1 min 22 secs)
+# Training (5 epochs): 2793 seconds (46 mins 33 secs)
+# Total: 2901 seconds (48 mins 21 secs)
+
+# Results on SpecBench (80 prompts, 256 output tokens):
+# acceptance rate: 13.35%
+# acceptance length: 1.53
+# per-position acceptance:
+# position 0: 40.84%
+# position 1: 10.84%
+# position 2: 1.58%
+# position 3: 0.15%
+
+set -euo pipefail
+
+# ============ Configuration ============
+MODEL="Qwen/Qwen3-8B"
+DATASET="sharegpt" # sharegpt, ultrachat, or path to custom data
+OUTPUT_DIR="./output/peagle_qwen3_8b_sharegpt"
+VLLM_PORT=8108
+MAX_SAMPLES=5000
+SEQ_LENGTH=4096
+EPOCHS=5
+LR=6e-4
+
+# P-EAGLE-specific parameters
+SPECULATOR_TYPE="peagle"
+NUM_LAYERS=4
+NUM_DEPTHS=4
+DOWN_SAMPLE_RATIO=0.7
+DOWN_SAMPLE_RATIO_MIN=0.2
+# GPU assignments (online training needs separate GPUs for vLLM and training)
+VLLM_GPUS="2,3"
+TRAIN_GPUS="4,5"
+NUM_TRAIN_GPUS=2
+# =======================================
+
+# Step 1: Prepare data
+echo "=== Step 1: Preparing data ==="
+python scripts/prepare_data.py \
+    --model "$MODEL" \
+    --data "$DATASET" \
+    --output "$OUTPUT_DIR" \
+    --max-samples "$MAX_SAMPLES" \
+    --seq-length "$SEQ_LENGTH"
+
+# Step 2: Launch vLLM server in the background
+echo "=== Step 2: Launching vLLM server ==="
+CUDA_VISIBLE_DEVICES="$VLLM_GPUS" python scripts/launch_vllm.py "$MODEL" \
+    --hidden-states-path "$OUTPUT_DIR/hidden_states" \
+    -- --data-parallel-size 2 --port "$VLLM_PORT" &
+VLLM_PID=$!
+
+# Ensure vLLM is cleaned up on exit
+cleanup() {
+    echo "Stopping vLLM server..."
+    kill "$VLLM_PID" 2>/dev/null || true
+    wait "$VLLM_PID" 2>/dev/null || true
+}
+trap cleanup EXIT
+
+echo "Waiting for vLLM server to be ready..."
+until curl -sf "http://localhost:${VLLM_PORT}/health" > /dev/null 2>&1; do
+    sleep 2
+done
+echo "vLLM server ready."
+
+# Step 3: Train against the live vLLM server
+echo "=== Step 3: Training ==="
+CUDA_VISIBLE_DEVICES="$TRAIN_GPUS" torchrun \
+    --standalone --nproc_per_node "$NUM_TRAIN_GPUS" \
+    scripts/train.py \
+    --verifier-name-or-path "$MODEL" \
+    --data-path "$OUTPUT_DIR" \
+    --vllm-endpoint "http://localhost:${VLLM_PORT}/v1" \
+    --hidden-states-path "$OUTPUT_DIR/hidden_states" \
+    --save-path "$OUTPUT_DIR/checkpoints" \
+    --epochs "$EPOCHS" \
+    --lr "$LR" \
+    --total-seq-len "$SEQ_LENGTH" \
+    --speculator-type "$SPECULATOR_TYPE" \
+    --num-layers "$NUM_LAYERS" \
+    --num-depths "$NUM_DEPTHS" \
+    --down-sample-ratio "$DOWN_SAMPLE_RATIO" \
+    --down-sample-ratio-min "$DOWN_SAMPLE_RATIO_MIN" \
+    --no-norm-before-residual \
+    --scheduler-type cosine \
+    --on-missing generate \
+    --on-generate delete
+
+echo "Done. Checkpoints saved to $OUTPUT_DIR/checkpoints/"
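The script's readiness loop polls `/health` every two seconds with no upper bound, so it keeps spinning if the backgrounded vLLM process dies during startup. A minimal Python sketch of the same check with a deadline is shown below. `wait_until_ready` and `http_health_probe` are hypothetical helpers written for illustration and are not part of the repository; the probe, sleep, and clock are injectable so the loop can be exercised without a live server.

```python
import time
import urllib.error
import urllib.request
from typing import Callable


def http_health_probe(url: str, timeout_s: float = 2.0) -> bool:
    """Return True if a GET on `url` answers with HTTP 200, False otherwise."""
    try:
        with urllib.request.urlopen(url, timeout=timeout_s) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or non-2xx response: not ready yet.
        return False


def wait_until_ready(
    probe: Callable[[], bool],
    timeout_s: float = 600.0,
    interval_s: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> bool:
    """Poll `probe` until it succeeds or `timeout_s` elapses.

    Returns True on success and False once the deadline has passed.
    """
    deadline = clock() + timeout_s
    while True:
        if probe():
            return True
        if clock() >= deadline:
            return False
        sleep(interval_s)


# Example call against the port used in the script (requires a running server):
# ok = wait_until_ready(lambda: http_health_probe("http://localhost:8108/health"))
```

Returning a boolean rather than raising leaves the caller free to decide whether a timeout should abort the run or trigger a retry.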

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -143,7 +143,7 @@ ignore_missing_imports=true
 [tool.ruff]
 line-length = 88
 indent-width = 4
-exclude = ["build", "dist", "env", ".venv"]
+exclude = ["build", "dist", "env", ".venv", "output"]

 [tool.ruff.format]
 quote-style = "double"
@@ -249,13 +249,13 @@ select = [
     "INP001", # allow implicit namespace packages in scripts
     "PTH", # os.path is acceptable in scripts
     "T201", # print statements are acceptable in scripts
+    "SLF001", # allow private member access for model configuration
 ]

 "examples/**/*.py" = [
     "INP001", # allow implicit namespace packages in examples
 ]

-
 [tool.ruff.lint.isort]
 known-first-party = ["speculators", "tests"]

scripts/train.py

Lines changed: 24 additions & 4 deletions
@@ -269,22 +269,22 @@ def main(args: argparse.Namespace):
     )

     model_class = registry[args.speculator_type]
+
     if args.from_pretrained:
         draft_model = model_class.from_pretrained(
             args.from_pretrained, t2d=t2d, d2t=d2t
         )
     else:
-        args_dict = vars(args)
-        args_dict["draft_vocab_size"] = draft_vocab_size
+        args.draft_vocab_size = draft_vocab_size
         draft_model = model_class.from_training_args(
             verifier_config=transformer_layer_config,
             t2d=t2d,
             d2t=d2t,
-            **args_dict,
+            **vars(args),
         )

     # Setup dataloaders
-    preprocess = shift_batch if args.speculator_type == "eagle3" else None
+    preprocess = shift_batch if args.speculator_type in ("eagle3", "peagle") else None

     noise_transform = AddUniformNoise(std=args.noise_std)
     if args.legacy_data:
@@ -608,6 +608,7 @@ def parse_args():
         help="Use RMSNorm before fc in Eagle3 draft path "
         "(e.g. for gpt-oss). Omit for other models.",
     )
+    # D-Flash specific parameters
     parser.add_argument(
         "--block-size",
         type=int,
@@ -620,6 +621,25 @@ def parse_args():
         default=256,
         help="Maximum anchor positions for DFlash training (default: 256)",
     )
+    # P-EAGLE specific parameters
+    parser.add_argument(
+        "--num-depths",
+        type=int,
+        default=8,
+        help="Number of parallel prediction depths for P-EAGLE (default: 8)",
+    )
+    parser.add_argument(
+        "--down-sample-ratio",
+        type=float,
+        default=0.7,
+        help="Geometric decay ratio for COD sampling in P-EAGLE (default: 0.7)",
+    )
+    parser.add_argument(
+        "--down-sample-ratio-min",
+        type=float,
+        default=0.2,
+        help="Minimum retention ratio for COD sampling in P-EAGLE (default: 0.2)",
+    )
     # Dataloader parameters
     parser.add_argument(
         "--num-workers", type=int, default=12, help="Number of dataloader workers"

src/speculators/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -2,10 +2,13 @@

 from .dflash import DFlashDraftModel, DFlashSpeculatorConfig
 from .eagle3 import Eagle3DraftModel, Eagle3SpeculatorConfig
+from .peagle import PEagleDraftModel, PEagleSpeculatorConfig

 __all__ = [
     "DFlashDraftModel",
     "DFlashSpeculatorConfig",
     "Eagle3DraftModel",
     "Eagle3SpeculatorConfig",
+    "PEagleDraftModel",
+    "PEagleSpeculatorConfig",
 ]

src/speculators/models/dflash/core.py

Lines changed: 2 additions & 12 deletions
@@ -106,10 +106,8 @@ def from_training_args(
             verifier_config: Verifier model configuration. This should be a config
                 with num_hidden_layers set to the number of DRAFT layers (created
                 by create_transformer_layer_config in train.py).
-            t2d: Target-to-draft vocabulary mapping tensor (optional, creates
-                identity mapping if None)
-            d2t: Draft-to-target vocabulary mapping tensor (optional, creates
-                identity mapping if None)
+            t2d: Target-to-draft vocabulary mapping tensor (optional)
+            d2t: Draft-to-target vocabulary mapping tensor (optional)
             **kwargs: Training arguments with DFlash-specific params
                 - draft_vocab_size: Size of draft vocabulary
                 - block_size: Block size for draft predictions (default: 8)
@@ -158,14 +156,6 @@ def from_training_args(
             ),
         )

-        # Create identity mappings if t2d/d2t not provided (no vocab reduction)
-        if t2d is None or d2t is None:
-            vocab_size = kwargs["draft_vocab_size"]
-            # t2d: all tokens in target vocab are in draft vocab
-            t2d = torch.ones(vocab_size, dtype=torch.bool)
-            # d2t: identity mapping (zero offset for all tokens)
-            d2t = torch.zeros(vocab_size, dtype=torch.long)
-
         model = cls(config=config)
         model.load_vocab_mappings(t2d, d2t)
         model.load_verifier_weights()
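The removed block describes `t2d` as a boolean membership mask over the target vocabulary and `d2t` as a per-token offset, with all-ones and all-zeros giving the identity mapping. The pure-Python sketch below illustrates that convention on a toy vocabulary. It assumes that a draft token ID maps back to the target vocabulary as `draft_id + d2t[draft_id]`, which is consistent with the "zero offset" comment but is not confirmed by this diff. `build_vocab_mappings` and `draft_to_target` are hypothetical helpers.

```python
def build_vocab_mappings(
    kept_target_ids: list[int], target_vocab_size: int
) -> tuple[list[bool], list[int]]:
    """Build (t2d, d2t) for a draft vocabulary made of `kept_target_ids`.

    t2d[i] is True when target token i is present in the draft vocabulary.
    d2t[j] is the offset to add to draft token j to recover its target ID
    (ASSUMPTION: offset convention, inferred from the removed comments).
    """
    kept = sorted(set(kept_target_ids))
    kept_set = set(kept)
    t2d = [i in kept_set for i in range(target_vocab_size)]
    d2t = [target_id - draft_id for draft_id, target_id in enumerate(kept)]
    return t2d, d2t


def draft_to_target(draft_id: int, d2t: list[int]) -> int:
    """Map a draft-vocabulary token ID back to the target vocabulary."""
    return draft_id + d2t[draft_id]


if __name__ == "__main__":
    # No vocabulary reduction: every target token is kept, so offsets are zero.
    t2d, d2t = build_vocab_mappings(list(range(6)), target_vocab_size=6)
    print(t2d, d2t)

    # Reduced vocabulary: only target tokens 0, 2 and 5 are kept.
    t2d, d2t = build_vocab_mappings([0, 2, 5], target_vocab_size=6)
    print(t2d, d2t, [draft_to_target(j, d2t) for j in range(3)])
```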

src/speculators/models/dflash/metrics.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,18 @@ def compute_metrics(
     pred_ids = torch.argmax(logits, dim=-1)
     target_ids = torch.argmax(targets, dim=-1)

-    full_acc, per_position_acc = compute_accuracy_multi_step(
+    correct_per_pos, total_per_pos = compute_accuracy_multi_step(
         pred_ids, target_ids, loss_mask, pos_idx, block_size
     )

     metrics: dict[str, Any] = {}
-    metrics["loss"] = loss.detach().clone()
-    metrics["full_acc"] = full_acc
+    metrics["loss_sum"] = loss.detach().clone()
+    metrics["loss_total"] = torch.tensor(1.0, device=logits.device)
+    # Position 0 is the anchor — intentionally excluded from accuracy
+    metrics["full_acc_sum"] = correct_per_pos[1:].sum()
+    metrics["full_acc_total"] = total_per_pos[1:].sum()

-    # Intentionally drop position 0
-    for pos in range(1, len(per_position_acc)):
-        metrics[f"position {pos} acc"] = per_position_acc[pos]
+    for pos in range(1, block_size):
+        metrics[f"position_{pos}_acc_sum"] = correct_per_pos[pos]
+        metrics[f"position_{pos}_acc_total"] = total_per_pos[pos]
     return loss, metrics
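Reporting `_sum` and `_total` pairs instead of ratios matters once ranks see different numbers of tokens: averaging per-rank accuracies weights a rank with 2 tokens the same as one with 100. The sketch below shows the reduce-then-divide pattern on plain dictionaries. `sum_across_ranks` and `ratios_from_counts` are hypothetical stand-ins written for illustration; the repository's `normalize_counted_metrics` is not shown in this diff and may behave differently.

```python
def sum_across_ranks(per_rank: list[dict[str, float]]) -> dict[str, float]:
    """Element-wise sum of metric dicts, standing in for a distributed all-reduce."""
    totals: dict[str, float] = {}
    for metrics in per_rank:
        for key, value in metrics.items():
            totals[key] = totals.get(key, 0.0) + value
    return totals


def ratios_from_counts(summed: dict[str, float], eps: float = 1e-8) -> dict[str, float]:
    """Turn every `<name>_sum` / `<name>_total` pair into `<name>` = sum / total."""
    out: dict[str, float] = {}
    for key, value in summed.items():
        if key.endswith("_sum"):
            base = key[: -len("_sum")]
            total = summed.get(base + "_total")
            if total is not None:
                out[base] = value / (total + eps)
    return out


if __name__ == "__main__":
    rank0 = {"loss_sum": 2.0, "loss_total": 1.0,
             "position_3_acc_sum": 1.0, "position_3_acc_total": 2.0}
    rank1 = {"loss_sum": 1.0, "loss_total": 1.0,
             "position_3_acc_sum": 30.0, "position_3_acc_total": 100.0}
    pooled = ratios_from_counts(sum_across_ranks([rank0, rank1]))
    naive = (1.0 / 2.0 + 30.0 / 100.0) / 2  # mean of per-rank ratios
    print(pooled, naive)
```

With `loss_total` fixed at 1.0 per rank, the same division reduces to a plain mean of the per-rank losses, so loss and accuracy can share one code path.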

src/speculators/models/eagle3/core.py

Lines changed: 3 additions & 2 deletions
@@ -28,7 +28,7 @@ def conditional_torch_compile(func):
 @SpeculatorModel.register("eagle3")
 class Eagle3DraftModel(DraftVocabMixin, SpeculatorModel):
     config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig # type: ignore[misc]
-    _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ # type: ignore[misc,assignment]
+    _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ # type: ignore[misc]
         "embed_tokens.weight",
         "verifier_norm.weight",
         "verifier_lm_head.weight",
@@ -255,7 +255,8 @@ def forward( # noqa: C901
         # shape: [1, total_seq_len]

         if return_loss:
-            metrics["loss"] = loss.detach().clone()
+            metrics["loss_sum"] = loss.detach().clone()
+            metrics["loss_total"] = torch.tensor(1.0, device=device)
             return draft_tokens, loss, metrics
         else:
             return draft_tokens

src/speculators/models/eagle3/metrics.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,16 @@ def compute_metrics(
     pred_ids = torch.argmax(s_logits, dim=-1)
     target_ids = torch.argmax(s_targets, dim=-1)

-    s_full_acc, s_cond_acc = compute_accuracy_single_step(
+    full_correct, full_total, cond_correct, cond_total = compute_accuracy_single_step(
         pred_ids, target_ids, s_loss_mask, s_prev_correct
     )

     s_metrics = {}
-    s_metrics[f"loss_{ttt_step}"] = s_loss.detach().clone()
-    s_metrics[f"full_acc_{ttt_step}"] = s_full_acc
-    s_metrics[f"cond_acc_{ttt_step}"] = s_cond_acc
+    s_metrics[f"loss_{ttt_step}_sum"] = s_loss.detach().clone()
+    s_metrics[f"loss_{ttt_step}_total"] = torch.tensor(1.0, device=s_loss.device)
+    s_metrics[f"full_acc_{ttt_step}_sum"] = full_correct
+    s_metrics[f"full_acc_{ttt_step}_total"] = full_total
+    s_metrics[f"cond_acc_{ttt_step}_sum"] = cond_correct
+    s_metrics[f"cond_acc_{ttt_step}_total"] = cond_total

     return s_loss, s_metrics
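To make the four returned counts concrete, here is a list-based sketch that follows the same steps as the updated `compute_accuracy_single_step` in `src/speculators/models/metrics.py`, using plain Python lists in place of tensors. `accuracy_counts_single_step` is a hypothetical name and the code is an illustration rather than a drop-in replacement. As in the diff, `cond_total` is taken from `prev_correct` before the loss mask is applied, and `prev_correct` is updated in place.

```python
from typing import Optional


def accuracy_counts_single_step(
    pred_ids: list[int],
    target_ids: list[int],
    loss_mask: Optional[list[int]] = None,
    prev_correct: Optional[list[bool]] = None,
) -> tuple[float, float, float, float]:
    """Return (full_correct, full_total, cond_correct, cond_total) as raw counts."""
    correct = [p == t for p, t in zip(pred_ids, target_ids)]
    cond_total = float(len(correct))
    if prev_correct is not None:
        # Denominator: positions where every earlier step was correct.
        cond_total = float(sum(prev_correct))
        # Chain correctness across steps, writing the result back in place.
        for i, is_correct in enumerate(correct):
            prev_correct[i] = prev_correct[i] and is_correct
        correct = prev_correct
    if loss_mask is not None:
        correct = [c for c, m in zip(correct, loss_mask) if m]
    correct_sum = float(sum(correct))
    full_total = float(len(correct))
    return correct_sum, full_total, correct_sum, cond_total


if __name__ == "__main__":
    prev = [True, False, True, True]
    counts = accuracy_counts_single_step(
        pred_ids=[1, 2, 3, 4],
        target_ids=[1, 2, 9, 4],
        loss_mask=[1, 1, 1, 0],
        prev_correct=prev,
    )
    print(counts, prev)
```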

src/speculators/models/metrics.py

Lines changed: 15 additions & 21 deletions
@@ -11,7 +11,7 @@ def compute_accuracy_single_step(
     loss_mask: torch.Tensor | None, # shape: [1, seq_len]
     prev_correct: torch.Tensor | None, # shape: [1, seq_len]
 ):
-    """Compute full and conditional accuracy for a single speculative step.
+    """Compute full and conditional accuracy counts for a single speculative step.

     Args:
         pred_ids: Predicted token IDs.
@@ -21,22 +21,21 @@ def compute_accuracy_single_step(
             via logical AND with the current step's correctness.

     Returns:
-        Tuple of (full_accuracy, conditional_accuracy) where conditional accuracy
-        is accuracy given all previous steps were also correct.
+        Tuple of (full_correct, full_total, cond_correct, cond_total) as raw
+        counts suitable for distributed reduction before computing ratios.
     """
     correct = pred_ids == target_ids
-    cond_denom: torch.Tensor | int = correct.numel()
+    cond_total = torch.tensor(correct.numel(), dtype=torch.float, device=correct.device)
     if prev_correct is not None:
-        cond_denom = prev_correct.sum()
-        # Update prev_correct in place
+        cond_total = prev_correct.sum().float()
         correct = torch.logical_and(prev_correct, correct, out=prev_correct)
     if loss_mask is not None:
         correct = torch.masked_select(correct, loss_mask.to(torch.bool))

     correct_sum = correct.float().sum()
-    full_denom = correct.numel()
+    full_total = torch.tensor(correct.numel(), dtype=torch.float, device=correct.device)

-    return correct_sum / (full_denom + _EPS), correct_sum / (cond_denom + _EPS)
+    return correct_sum, full_total, correct_sum, cond_total


 @torch.no_grad()
@@ -47,7 +46,7 @@ def compute_accuracy_multi_step(
     pos_idx: torch.Tensor, # shape: [1, seq_len]
     num_pos: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """Compute overall and per-position accuracy across multiple speculative steps.
+    """Compute per-position correct/total counts across multiple speculative steps.

     Args:
         pred_ids: Predicted token IDs.
@@ -57,24 +56,19 @@ def compute_accuracy_multi_step(
         num_pos: Number of distinct positions (i.e. block size).

     Returns:
-        Tuple of (overall_accuracy, per_position_accuracy) where per_position_accuracy
-        has shape [num_pos].
+        Tuple of (correct_per_pos, total_per_pos) both with shape [num_pos].
+        Overall counts can be derived by summing these.
     """
     correct = pred_ids == target_ids
     correct = torch.masked_select(correct, loss_mask.to(torch.bool))
     pos_idx = torch.masked_select(pos_idx, loss_mask.to(torch.bool))

-    correct_sum = correct.float().sum()
-    full_denom = correct.numel()
-    overall_acc = correct_sum / (full_denom + _EPS)
-
-    sums = torch.zeros(num_pos, dtype=torch.long, device=correct.device)
-    counts = torch.zeros(num_pos, dtype=torch.long, device=correct.device)
-    sums.scatter_add_(0, pos_idx, correct.long())
-    counts.scatter_add_(0, pos_idx, torch.ones_like(correct, dtype=torch.long))
-    per_pos_idx_acc = sums.float() / (counts.float() + _EPS)
+    correct_per_pos = torch.zeros(num_pos, dtype=torch.float, device=correct.device)
+    total_per_pos = torch.zeros(num_pos, dtype=torch.float, device=correct.device)
+    correct_per_pos.scatter_add_(0, pos_idx, correct.float())
+    total_per_pos.scatter_add_(0, pos_idx, torch.ones_like(correct, dtype=torch.float))

-    return overall_acc, per_pos_idx_acc # shape: [], [block_size]
+    return correct_per_pos, total_per_pos # shape: [num_pos], [num_pos]


 def kl_div_loss(
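The two `scatter_add_` calls in `compute_accuracy_multi_step` amount to bucketing each unmasked token by its position index and counting hits and totals per bucket. The loop below spells that out with plain lists, then derives the anchor-excluded overall counts the way the DFlash metrics do with `[1:]`. `per_position_counts` is a hypothetical name used only for this illustration.

```python
def per_position_counts(
    pred_ids: list[int],
    target_ids: list[int],
    loss_mask: list[int],
    pos_idx: list[int],
    num_pos: int,
) -> tuple[list[float], list[float]]:
    """Return (correct_per_pos, total_per_pos), each of length `num_pos`."""
    correct_per_pos = [0.0] * num_pos
    total_per_pos = [0.0] * num_pos
    for pred, target, keep, pos in zip(pred_ids, target_ids, loss_mask, pos_idx):
        if not keep:
            continue  # masked-out tokens contribute to neither count
        correct_per_pos[pos] += float(pred == target)
        total_per_pos[pos] += 1.0
    return correct_per_pos, total_per_pos


if __name__ == "__main__":
    # Two blocks of size 3; the last token is masked out of the loss.
    correct, total = per_position_counts(
        pred_ids=[5, 6, 7, 5, 6, 7],
        target_ids=[5, 6, 0, 5, 0, 7],
        loss_mask=[1, 1, 1, 1, 1, 0],
        pos_idx=[0, 1, 2, 0, 1, 2],
        num_pos=3,
    )
    # Overall counts with position 0 (the anchor) excluded.
    full_sum, full_total = sum(correct[1:]), sum(total[1:])
    print(correct, total, full_sum, full_total)
```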

src/speculators/models/peagle/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from speculators.models.peagle.config import PEagleSpeculatorConfig
+from speculators.models.peagle.core import PEagleDraftModel
+
+__all__ = [
+    "PEagleDraftModel",
+    "PEagleSpeculatorConfig",
+]
