Commit 01b767f

h-guo18 and danielkorzekwa authored and committed
Feat: Context Parallel for Eagle3 Training (#745)
**Type of change:** New Feature <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** - Supported Context Parallel by patching torch ring attention; - Require following libirary version for stable cp: - torch2.8.0 - transformers5.0.0 - accelrate1.12.0 - Move to FSDP2 - Removed unused arguments in training script (`--multi_gpu`, `fsdp_wrap_layer`) - Bump CI container to `nvcr.io/nvidia/pytorch:25.08-py3` <!-- You can potentially add a usage example below. --> ```bash ./launch_train.sh --model $MODEL \ --output_dir $OUTPUT_DIR \ --data $DATA \ --num_epochs 0.1 \ --train_bs 1 \ --eagle_config eagle_config.json \ --training_seq_len 1024 \ --cp_size 2 #newly added ``` - SDPA level correctness: tested TTT attention with/without CP, diff < 1% ``` === Compare context-parallel (CP) outputs and grads with non-CP === Forward output comparison (CP vs Non-CP): Absolute diff (adiff) cp_out vs out: 0.001953125 Relative diff (rdiff) cp_out vs out: 0.00182342529296875 WQ (query proj) grad comparison (CP vs Non-CP): Absolute diff (adiff) cp_wq_grad vs wq_grad: 0.0078125 Relative diff (rdiff) cp_wq_grad vs wq_grad: 0.00347900390625 WK (key proj) grad comparison (CP vs Non-CP): Absolute diff (adiff) cp_wk_grad vs wk_grad: 0.0078125 Relative diff (rdiff) cp_wk_grad vs wk_grad: 0.002471923828125 WV (value proj) grad comparison (CP vs Non-CP): Absolute diff (adiff) cp_wv_grad vs wv_grad: 0.25 Relative diff (rdiff) cp_wv_grad vs wv_grad: 0.0069580078125 ============================================================== ``` - E2E Training Acc (Llama3.1-8B, Unsynthesized magpie) <img width="911" height="630" alt="image" src="https://github.com/user-attachments/assets/1ecacc7f-c720-494c-9c1b-b60e7ced7baa" /> - Peak Mem Reserved (llama3.1-8B, 8xH100, train_length=4k) | cp_size | max_memory_allocated(MB) |max_memory_reserved (MB) | |----|--------------------------|--------------------------| | 1 | 65040.20 |79018.00 | 2 | 50409.17 |73098.00 | 
4 | 45120.92 |72052.00 | 8 | 38882.12 |66484.00 - Max Training Length test (llama3.1-8B, H100) | cp_size | 6k | 12k | 24k | 48k | |--------------------|-----|-----|-----|-----| | 1 | ✅ | OOM | OOM | OOM | |2 | ✅ | ✅ | OOM | OOM | | 4 | ✅ | ✅ | ✅ | OOM | | 8 | ✅ | ✅ | ✅ | ✅ | <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> * **New Features** * Added context parallelism (CP) and data parallelism shard size configuration parameters to training arguments. * **Enhancements** * Improved TTT attention masking support for speculative decoding workflows. * Enhanced training launch script with improved parallelism configuration handling. * **Chores** * Updated core dependencies: torch, transformers, accelerate, and wandb. * Added FSDP configuration file for distributed training setup. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
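For reference, the adiff/rdiff figures in the correctness log can be computed with a helper along these lines (a sketch only; defining `adiff` as the maximum elementwise absolute difference and `rdiff` as `adiff` normalized by the largest reference magnitude is an assumption about the test script, which is not shown in this commit):

```python
def adiff(a, b):
    """Maximum elementwise absolute difference between two same-shaped tensors (flat lists here)."""
    return max(abs(x - y) for x, y in zip(a, b))

def rdiff(a, b):
    """adiff normalized by the largest magnitude in the reference tensor b."""
    ref = max(abs(x) for x in b)
    return adiff(a, b) / ref if ref else 0.0

# Compare a CP output against the non-CP reference (toy values).
cp_out = [0.5, -1.25, 2.0]
out = [0.5009765625, -1.25, 2.0]
print(adiff(cp_out, out), rdiff(cp_out, out))  # → 0.0009765625 0.00048828125
```

Values like 0.001953125 in the log are exact binary fractions (2^-9), consistent with bf16/fp16 rounding at this tolerance.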
1 parent 0e706a8 commit 01b767f

12 files changed

Lines changed: 466 additions & 77 deletions

Lines changed: 187 additions & 0 deletions
```yaml
name: Example tests

on:
  push:
    branches: ["pull-request/[0-9]+"]
    # NOTE: paths cannot be used since push happens to copied PR and only latest commit to PR is used
  schedule:
    - cron: "0 0 * * *" # Nightly
  workflow_dispatch: # On-demand

# Cancel previous runs if new commit is pushed to the same PR
concurrency:
  group: ${{ github.workflow }}-${{ startsWith(github.ref, 'refs/heads/pull-request/') && github.ref || github.sha }}
  cancel-in-progress: true

jobs:
  check-file-changes:
    if: startsWith(github.ref, 'refs/heads/pull-request/')
    runs-on: ubuntu-latest
    outputs:
      any_changed: ${{ steps.changed-tests.outputs.any_changed }}
    steps:
      - uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - id: get-pr-info
        uses: nv-gha-runners/get-pr-info@main
      # Get commit from main branch that is present in the PR to use as base for changed files
      - id: calculate-merge-base
        env:
          PR_SHA: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).head.sha }}
          BASE_SHA: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).base.sha }}
        run: |
          (echo -n "merge-base="; git merge-base "$BASE_SHA" "$PR_SHA") | tee --append "${GITHUB_OUTPUT}"
      - name: Check for changes in test-relevant directories
        id: changed-tests
        uses: step-security/changed-files@v46.0.5
        with:
          base_sha: ${{ steps.calculate-merge-base.outputs.merge-base }}
          sha: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).head.sha }}
          files: |
            .github/workflows/example_tests.yml
            examples/**
            modelopt/**
            setup.py
            tests/examples/**
          fail_on_initial_diff_error: true

  wait-checks:
    needs: [check-file-changes]
    if: needs.check-file-changes.outputs.any_changed == 'true'
    uses: ./.github/workflows/_wait_for_checks.yml
    permissions:
      checks: read
    secrets: inherit
    with:
      match_pattern: "^DCO$|^linux$" # Wait for DCO and Unit tests / linux to pass
      delay: 300s

  ##### PyTorch Example Tests #####
  torch-pr:
    needs: [check-file-changes, wait-checks]
    if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
    strategy:
      fail-fast: false
      matrix:
        example: [llm_distill, llm_qat, llm_sparsity]
    uses: ./.github/workflows/_example_tests_runner.yml
    secrets: inherit
    with:
      docker_image: "nvcr.io/nvidia/pytorch:25.06-py3"
      example: ${{ matrix.example }}
      pip_install_extras: "[hf,dev-test]"
      runner: linux-amd64-gpu-l4-latest-1

  torch-non-pr:
    if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }}
    strategy:
      fail-fast: false
      matrix:
        example: [llm_distill, llm_qat, llm_sparsity]
    uses: ./.github/workflows/_example_tests_runner.yml
    secrets: inherit
    with:
      docker_image: "nvcr.io/nvidia/pytorch:25.06-py3"
      example: ${{ matrix.example }}
      pip_install_extras: "[hf,dev-test]"
      runner: linux-amd64-gpu-h100-latest-2

  ##### Speculative Decoding Example Tests (requires 25.08 image) #####
  speculative-decoding-pr:
    needs: [check-file-changes, wait-checks]
    if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
    uses: ./.github/workflows/_example_tests_runner.yml
    secrets: inherit
    with:
      docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
      example: speculative_decoding
      pip_install_extras: "[hf,dev-test]"
      runner: linux-amd64-gpu-l4-latest-1

  speculative-decoding-non-pr:
    if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }}
    uses: ./.github/workflows/_example_tests_runner.yml
    secrets: inherit
    with:
      docker_image: "nvcr.io/nvidia/pytorch:25.08-py3"
      example: speculative_decoding
      pip_install_extras: "[hf,dev-test]"
      runner: linux-amd64-gpu-h100-latest-2

  ##### TensorRT-LLM Example Tests #####
  trtllm-pr:
    needs: [check-file-changes, wait-checks]
    if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
    strategy:
      fail-fast: false
      matrix:
        example: [llm_ptq] # vlm_ptq temporarily disabled due to pipeline error
    uses: ./.github/workflows/_example_tests_runner.yml
    secrets: inherit
    with:
      docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc4"
      example: ${{ matrix.example }}
      pip_install_extras: "[hf,dev-test]"
      runner: linux-amd64-gpu-h100-latest-1

  trtllm-non-pr:
    if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }}
    strategy:
      fail-fast: false
      matrix:
        example: [llm_autodeploy, llm_eval, llm_ptq, vlm_ptq]
    uses: ./.github/workflows/_example_tests_runner.yml
    secrets: inherit
    with:
      docker_image: "nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc4"
      example: ${{ matrix.example }}
      pip_install_extras: "[hf,dev-test]"
      runner: linux-amd64-gpu-h100-latest-2

  ##### ONNX/TensorRT Example Tests #####
  onnx-pr:
    needs: [check-file-changes, wait-checks]
    if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true'
    strategy:
      fail-fast: false
      matrix:
        example: [diffusers, torch_onnx]
    uses: ./.github/workflows/_example_tests_runner.yml
    secrets: inherit
    with:
      docker_image: "nvcr.io/nvidia/tensorrt:25.08-py3"
      example: ${{ matrix.example }}
      pip_install_extras: "[all,dev-test]"
      runner: linux-amd64-gpu-l4-latest-1

  onnx-non-pr:
    if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }}
    strategy:
      fail-fast: false
      matrix:
        example: [diffusers, torch_onnx]
    uses: ./.github/workflows/_example_tests_runner.yml
    secrets: inherit
    with:
      docker_image: "nvcr.io/nvidia/tensorrt:25.08-py3"
      example: ${{ matrix.example }}
      pip_install_extras: "[all,dev-test]"
      runner: linux-amd64-gpu-l4-latest-1

  ##### Required Check for PR #####
  example-pr-required-check:
    # Run even if example tests are skipped
    if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }}
    needs: [check-file-changes, torch-pr, speculative-decoding-pr, trtllm-pr, onnx-pr]
    runs-on: ubuntu-latest
    steps:
      - name: Required GPU tests did not succeed
        if: |
          needs.check-file-changes.result != 'success' ||
          (needs.check-file-changes.outputs.any_changed == 'true' && (
            needs.torch-pr.result != 'success' ||
            needs.speculative-decoding-pr.result != 'success' ||
            needs.trtllm-pr.result != 'success' ||
            needs.onnx-pr.result != 'success'
          ))
        run: exit 1
```
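The `calculate-merge-base` step above resolves the commit where the PR branch forked from the base branch, so only files changed in the PR itself are counted. Its behavior can be checked in a throwaway repository (a sketch with plain git; the user name/email are placeholders, not the runner environment):

```shell
set -eu
tmp=$(mktemp -d) && cd "$tmp"
git init -q -b main
git -c user.email=ci@example.com -c user.name=ci commit -q --allow-empty -m "base"
BASE_SHA=$(git rev-parse HEAD)   # tip of the base branch
git checkout -q -b pull-request/1
git -c user.email=ci@example.com -c user.name=ci commit -q --allow-empty -m "pr change"
PR_SHA=$(git rev-parse HEAD)     # head of the PR
# Same command the workflow appends to GITHUB_OUTPUT:
echo -n "merge-base="; git merge-base "$BASE_SHA" "$PR_SHA"
```

Here the base branch has no commits after the fork point, so the printed merge base equals `$BASE_SHA`.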

examples/speculative_decoding/README.md

Lines changed: 3 additions & 5 deletions
````diff
@@ -30,7 +30,7 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM,

 ### Docker

-Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.06-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.
+Please use the PyTorch docker image (e.g., `nvcr.io/nvidia/pytorch:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information.

 Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies.

@@ -56,7 +56,7 @@ See [other-datasets](#other-datasets) section for other dataset options and inst
 ## Getting Started: Simplified Workflow

 ```bash
-bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct --num_gpu 4
+bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct
 ```

 This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it
@@ -74,12 +74,11 @@ For small base models that fit in GPU memory, we can collocate them with draft m
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data input_conversations/daring-anteater.jsonl \
-    --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json
 ```

-This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details.
+FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`.
 The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.

 ## Training Draft Model with Offline Base Model
@@ -118,7 +117,6 @@ Once we finish dumping hidden states, launch offline training with an extra `--o
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data $DATA \
-    --num_gpu $NUM_GPU \
     --num_epochs $NUM_EPOCH \
     --eagle_config eagle_config.json \
     --offline-data $HIDDEN_STATES_DIR
````
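The `--cp_size` option documented above shards each training sequence across GPUs along the sequence dimension, which is why peak memory drops and the maximum trainable length grows with CP size. A minimal sketch of the idea (contiguous chunks; the actual torch ring-attention schedule may use a load-balanced split, and `shard_sequence` is a hypothetical helper, not a ModelOpt API):

```python
def shard_sequence(tokens, cp_size, rank):
    """Return the contiguous slice of the sequence owned by one CP rank."""
    assert len(tokens) % cp_size == 0, "training_seq_len must be divisible by cp_size"
    chunk = len(tokens) // cp_size
    return tokens[rank * chunk : (rank + 1) * chunk]

seq = list(range(1024))  # one 1k-token training sequence
shards = [shard_sequence(seq, cp_size=2, rank=r) for r in range(2)]
print(len(shards[0]), shards[0][0], shards[1][0])  # → 512 0 512
```

Each rank then holds only its slice of activations; attention over the full sequence is recovered by passing key/value blocks around the ring of CP ranks.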
