Commit 3036a9e
Feat: Context Parallel for Eagle3 Training (#745)
## What does this PR do?
**Type of change:** New Feature
**Overview:**
- Added Context Parallel support by patching torch ring attention.
- Requires the following library versions for stable CP:
  - torch 2.8.0
  - transformers 5.0.0
  - accelerate 1.12.0
- Moved to FSDP2.
- Removed unused arguments in the training script (`--multi_gpu`,
  `fsdp_wrap_layer`).
- Bump CI container to `nvcr.io/nvidia/pytorch:25.08-py3`
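The communication pattern behind ring attention can be illustrated without torch. The sketch below is a toy model of the idea only (the PR patches torch's own ring attention; `ring_attention_schedule` is a hypothetical helper, not code from this repo): each CP rank owns one shard of K/V, and shards rotate around the ring so every rank eventually attends over the full sequence.

```python
# Illustrative only: the KV-rotation schedule of ring attention.
# Each CP rank holds one shard of the sequence's K/V; at every step the
# shards shift one position around the ring.

def ring_attention_schedule(cp_size: int) -> list[list[int]]:
    """For each rank, the order in which it sees the KV shards."""
    schedule = []
    for rank in range(cp_size):
        # At step s, rank r holds the KV shard originally owned by (r - s) % cp_size.
        schedule.append([(rank - step) % cp_size for step in range(cp_size)])
    return schedule

for rank, order in enumerate(ring_attention_schedule(4)):
    print(f"rank {rank} sees KV shards in order {order}")
```

After `cp_size` steps, every rank has seen every shard exactly once, which is why the output matches non-CP attention up to numerical error.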
## Usage
```bash
./launch_train.sh --model $MODEL \
--output_dir $OUTPUT_DIR \
--data $DATA \
--num_epochs 0.1 \
--train_bs 1 \
--eagle_config eagle_config.json \
--training_seq_len 1024 \
    --cp_size 2  # newly added
```
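With CP enabled, the remaining GPUs form the data-parallel group. Assuming the usual convention that `cp_size` must divide the world size (an assumption about the launch script, not verified against it), the layout works out as:

```python
# Illustrative only: how CP typically partitions the GPUs
# (assumed convention, not taken from launch_train.sh itself).
def parallel_layout(world_size: int, cp_size: int) -> dict[str, int]:
    if world_size % cp_size != 0:
        raise ValueError("cp_size must divide world_size")
    return {"cp_size": cp_size, "dp_size": world_size // cp_size}

# e.g. 8 GPUs with --cp_size 2 -> 4 data-parallel replicas of 2-way CP
print(parallel_layout(8, 2))
```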
## Testing
- SDPA-level correctness: tested TTT attention with and without CP; diff < 1%
```
=== Compare context-parallel (CP) outputs and grads with non-CP ===
Forward output comparison (CP vs Non-CP):
Absolute diff (adiff) cp_out vs out: 0.001953125
Relative diff (rdiff) cp_out vs out: 0.00182342529296875
WQ (query proj) grad comparison (CP vs Non-CP):
Absolute diff (adiff) cp_wq_grad vs wq_grad: 0.0078125
Relative diff (rdiff) cp_wq_grad vs wq_grad: 0.00347900390625
WK (key proj) grad comparison (CP vs Non-CP):
Absolute diff (adiff) cp_wk_grad vs wk_grad: 0.0078125
Relative diff (rdiff) cp_wk_grad vs wk_grad: 0.002471923828125
WV (value proj) grad comparison (CP vs Non-CP):
Absolute diff (adiff) cp_wv_grad vs wv_grad: 0.25
Relative diff (rdiff) cp_wv_grad vs wv_grad: 0.0069580078125
==============================================================
```
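The adiff/rdiff numbers above are consistent with a max-absolute / relative difference; a minimal sketch of how such a comparison could be computed (the real test operates on tensors, and the exact formula used in the PR is an assumption):

```python
def adiff(a, b):
    """Max absolute elementwise difference between two sequences."""
    return max(abs(x - y) for x, y in zip(a, b))

def rdiff(a, b):
    """adiff normalized by the reference's max magnitude."""
    return adiff(a, b) / max(abs(y) for y in b)

# Toy comparison of a CP output against a non-CP reference.
out = [1.0, -2.0, 4.0]
cp_out = [1.001, -2.0, 3.999]
print(adiff(cp_out, out), rdiff(cp_out, out))
```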
- E2E training accuracy
  (Llama3.1-8B, unsynthesized Magpie)
<img width="911" height="630" alt="image"
src="https://github.com/user-attachments/assets/1ecacc7f-c720-494c-9c1b-b60e7ced7baa"
/>
- Peak memory
  (Llama3.1-8B, 8xH100, train_length=4k)

| cp_size | max_memory_allocated (MB) | max_memory_reserved (MB) |
|---------|---------------------------|--------------------------|
| 1       | 65040.20                  | 79018.00                 |
| 2       | 50409.17                  | 73098.00                 |
| 4       | 45120.92                  | 72052.00                 |
| 8       | 38882.12                  | 66484.00                 |
- Max training length test
  (Llama3.1-8B, H100)

| cp_size | 6k | 12k | 24k | 48k |
|---------|----|-----|-----|-----|
| 1       | ✅ | OOM | OOM | OOM |
| 2       | ✅ | ✅  | OOM | OOM |
| 4       | ✅ | ✅  | ✅  | OOM |
| 8       | ✅ | ✅  | ✅  | ✅  |
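The table matches the expected roughly linear scaling: each CP rank holds about `seq_len / cp_size` tokens, so doubling `cp_size` roughly doubles the trainable length. A back-of-envelope sketch (the 6k single-GPU limit comes from the table above; the linear model itself is an approximation, not a guarantee):

```python
def max_trainable_len(cp_size: int, single_gpu_limit: int = 6_000) -> int:
    # Activation memory for attention inputs scales ~linearly with the
    # per-rank shard length, so the trainable length grows ~linearly
    # with cp_size.
    return single_gpu_limit * cp_size

for cp in (1, 2, 4, 8):
    print(f"cp_size={cp}: ~{max_trainable_len(cp)} tokens")
```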
## Before your PR is "*Ready for review*"
- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes/No
## Additional Information
## Summary by CodeRabbit
* **New Features**
* Added context parallelism (CP) and data parallelism shard size
configuration parameters to training arguments.
* **Enhancements**
* Improved TTT attention masking support for speculative decoding
workflows.
* Enhanced training launch script with improved parallelism
configuration handling.
* **Chores**
* Updated core dependencies: torch, transformers, accelerate, and wandb.
* Added FSDP configuration file for distributed training setup.
---------
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
File tree (12 files changed, +305 −80 lines):
- .github/workflows
- examples/speculative_decoding
- modelopt/torch/speculative
- plugins
- tests/examples/speculative_decoding