Conversation
📝 Walkthrough

Introduces TTT (training-time test) attention masking support for context parallelism in speculative decoding. Adds attention masking computation and ring attention patching utilities. Updates the training pipeline to conditionally apply CP-TTT patches when context parallelism is enabled. Refactors cache initialization and loss masking alignment.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Training as Training<br/>Pipeline
    participant CP as CP Config<br/>(cp_size > 1)
    participant Patch as Ring Attention<br/>Patcher
    participant RingAttn as Ring Attention<br/>Module
    participant TTT as TTT Mask<br/>Generator
    participant Loss as Loss<br/>Computation

    Training->>CP: check cp_size > 1
    activate CP
    alt cp_size > 1
        CP->>Patch: call patch_ring_attention_for_ttt()
        activate Patch
        Patch->>Patch: extract rank, size, query, key, ttt_step from frame
        Patch->>RingAttn: replace _templated_ring_attention with patched version
        Patch->>RingAttn: patch _SDPAMerger.step to skip benign shards
        Patch->>RingAttn: disable CP load balancing
        deactivate Patch
        Training->>RingAttn: forward pass (attention computation)
        activate RingAttn
        RingAttn->>TTT: inject TTT mask into attention bias
        activate TTT
        TTT->>TTT: compute composite mask with kv_idx and ttt_step
        TTT-->>RingAttn: return ttt_attention_mask
        deactivate TTT
        RingAttn-->>Training: attention output with TTT masking
        deactivate RingAttn
    else cp_size == 1
        CP->>Training: skip patching
    end
    deactivate CP
    Training->>Loss: forward pass (Eagle model)
    activate Loss
    Loss->>Loss: align loss_mask to eagle_logits shape
    Loss->>Loss: compute softmax cross-entropy with aligned masks
    Loss-->>Training: loss value
    deactivate Loss
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 7
🤖 Fix all issues with AI agents
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 661-662: The error message string in the conditional that checks
patch_enbabled and original_op against
torch.ops.aten._scaled_dot_product_cudnn_attention contains a typo ("cuddn");
update the ValueError text to use the correct spelling "cudnn" so the raised
message reads something like "CP TTT only supports cudnn attention now. Got:
{original_op}" while keeping the same condition and variables (patch_enbabled,
original_op, torch.ops.aten._scaled_dot_product_cudnn_attention).
- Around line 668-688: The patched_op wrapper uses
inspect.currentframe()/frame.f_back and populates rank,size,query,key,i,ttt_step
inside a try/except but then continues even if inspection failed, which can lead
to NameError; update patched_op to validate that inspect.currentframe() and
frame.f_back are not None and that f_back.f_locals contains the expected keys
before using them (inspect.currentframe, frame.f_back, f_back.f_locals, keys
"rank","size","query","key","i"); if any check fails, either re-raise the caught
exception or return/raise a clear error early so the function does not proceed
to call _get_sharded_ttt_msk or original_op with undefined variables; ensure
ttt_step is computed only after query/key are present and preserve the
original_op call path when inspection succeeds.
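A minimal sketch of the validation described above, written as a wrapper factory; the local-variable names (`rank`, `size`, `query`, `key`, `i`) come from the review text, and everything else about the wrapper is an assumption:

```python
import inspect

def make_patched_op(original_op):
    """Hedged sketch: wrap original_op with fail-fast frame inspection."""

    def patched_op(*args, **kwargs):
        # Inspect the caller (torch's _templated_ring_attention) and bail out
        # loudly if the expected locals are missing, instead of continuing
        # with undefined names and hitting a NameError later.
        frame = inspect.currentframe()
        if frame is None or frame.f_back is None:
            raise RuntimeError("CP TTT patch: cannot inspect the calling frame.")
        caller_locals = frame.f_back.f_locals
        expected = ("rank", "size", "query", "key", "i")
        missing = [name for name in expected if name not in caller_locals]
        if missing:
            raise RuntimeError(
                f"CP TTT patch: locals {missing} not found in the ring-attention "
                "frame; torch internals may have changed."
            )
        rank, size, i = caller_locals["rank"], caller_locals["size"], caller_locals["i"]
        query, key = caller_locals["query"], caller_locals["key"]
        # ttt_step and the sharded TTT mask would be computed here, using
        # query/key only after the checks above have passed (omitted).
        return original_op(*args, **kwargs)

    return patched_op
```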
- Around line 700-701: Fix the typo in the inline comment above the config
assignment: change "permenantly" to "permanently" in the comment that precedes
the
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance
= False line so the comment reads "So need to be done permanently before
accelerate/hf trainer init."
- Around line 580-598: Delete the dead _compute_ttt_attention_mask function in
eagle_utils.py: remove the entire function definition (def
_compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) ->
torch.Tensor:) because it is unused and its docstring is misleading; ensure
there are no remaining references to this symbol and rely on the existing
implementation in transformers.py (the plugin’s _compute_ttt_attention_mask)
which handles flex_attention correctly.
- Around line 720-725: The function patched_sdpa_merger_step has an incorrect
return annotation: change its signature from "-> torch.Tensor" to "-> None"
because it performs in-place mutations and returns None (match original
_SDPAMerger.step); update the annotation on patched_sdpa_merger_step and ensure
the call to original_sdpa_merger_step is used only for its side effects. Also
optionally review the lse.sum() <= 0 check in patched_sdpa_merger_step to
confirm it correctly identifies blank shards (log-sum-exp can be negative) and
adjust the condition if needed.
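A minimal sketch of both suggestions, assuming the wrapper closes over the original `_SDPAMerger.step`; the `-inf` test is one possible way to identify fully masked shards and is an assumption, not a verified replacement for the existing check:

```python
import torch

def make_patched_sdpa_merger_step(original_sdpa_merger_step):
    """Hedged sketch: skip blank shards, keep the original None return."""

    def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor) -> None:
        # A shard whose keys are entirely masked yields logsumexp of -inf at
        # every query position; plain negativity of lse is not a reliable
        # signal, since log-sum-exp is legitimately negative for small logits.
        if torch.isneginf(lse).all():
            return
        # The original step mutates merger state in place and returns None,
        # so the wrapper is annotated -> None and used only for side effects.
        original_sdpa_merger_step(self, out, lse)

    return patched_sdpa_merger_step
```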
In `@examples/speculative_decoding/requirements.txt`:
- Around line 1-4: The requirements file currently leaves the dependency "wandb"
unpinned; update the requirements.txt entry "wandb" to a specific, tested
version (e.g., replace the bare token "wandb" with a pinned version like
"wandb==<chosen_version>") to ensure reproducible installs; choose and document
a stable release you validated against the existing pinned packages and commit
the updated line.
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 65-66: CACHED_SHARD_TTT_MASKS is defined but unused; either delete
the constant or add a clarifying comment and intended usage so it isn't flagged
as dead code—update the module-level definition near ENABLE_CP_TTT_PATCH by
removing the CACHED_SHARD_TTT_MASKS = {} line if it's not needed, or replace it
with a short doc comment (e.g., explaining expected keys/values and when it
should be populated) and keep the name as-is so future readers know its purpose;
ensure no other code references are required by also grepping for
CACHED_SHARD_TTT_MASKS before removal.
🧹 Nitpick comments (6)
modelopt/torch/speculative/utils.py (1)
462-470: Consider exception safety in the context manager.

The context manager correctly restores the flag in the `finally` block. However, the `sdpa_kernel` context manager wraps the `try`/`yield`/`finally`, which means if `sdpa_kernel` raises during entry, the flag might already be set to `True` but won't be reset. Consider this safer ordering:
♻️ Suggested improvement
```diff
 @contextlib.contextmanager
 def enable_cp_ttt_patch():
     """Context manager to enable CP TTT patch."""
-    modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
-    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
-        try:
-            yield
-        finally:
-            modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False
+    try:
+        modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
+        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
+            yield
+    finally:
+        modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False
```

modelopt/torch/speculative/plugins/transformers.py (1)
905-918: Consider making `enable_cp_ttt_patch()` conditional on CP being active.

The `enable_cp_ttt_patch()` context manager is applied unconditionally during training, but the TODO on line 905 suggests the mask isn't used during CP training. This forces the CUDNN_ATTENTION backend even when `cp_size=1`, which may have unintended performance characteristics. Consider wrapping conditionally:
♻️ Suggested approach
```diff
+        ctx = enable_cp_ttt_patch() if ENABLE_CP_TTT_PATCH else contextlib.nullcontext()
-        with enable_cp_ttt_patch():
+        with ctx:
             _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward(
```

This would require importing `contextlib` and checking `ENABLE_CP_TTT_PATCH` status.

examples/speculative_decoding/launch_train.sh (2)
97-104: Inconsistent argument naming: `--dp_size` vs `dp_shard_size`.

The CLI argument is `--dp_size` but it sets `DP_SHARD_SIZE` and the Python code uses `dp_shard_size`. Consider aligning the CLI argument name for consistency:

♻️ Suggested fix
```diff
-    --dp_size*)
+    --dp_shard_size*)
         if [[ "$1" != *=* ]]; then shift; fi
         DP_SHARD_SIZE="${1#*=}"
         ;;
```
139-140: Add validation for CP_SIZE and GPU_COUNT relationship.

The calculation `DP_SHARD_SIZE=$((GPU_COUNT/CP_SIZE))` could result in 0 if `CP_SIZE > GPU_COUNT`, or leave resources unused if `GPU_COUNT` is not evenly divisible by `CP_SIZE`. Consider adding validation:

♻️ Suggested validation
```diff
 CP_SIZE=${CP_SIZE:-1}
+if [[ $CP_SIZE -gt $GPU_COUNT ]]; then
+  echo "Error: cp_size ($CP_SIZE) cannot exceed GPU count ($GPU_COUNT)"
+  exit 1
+fi
+if [[ $((GPU_COUNT % CP_SIZE)) -ne 0 ]]; then
+  echo "Warning: GPU count ($GPU_COUNT) is not evenly divisible by cp_size ($CP_SIZE)"
+fi
 DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}
```

examples/speculative_decoding/main.py (1)
141-147: ParallelismConfig API usage is correct; existing comment sufficiently tracks the workaround.

`ParallelismConfig` correctly accepts both `cp_size` and `dp_shard_size` parameters (verified against accelerate documentation). The `sp_backend = None` workaround for accelerate 1.12.0 is appropriate and already has a clear comment indicating removal after upgrade to 1.13.0. Optionally, consider adding a linked TODO (e.g., with a GitHub issue or PR reference) to the existing comment to streamline tracking of the deprecation.
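For reference, a minimal sketch of the configuration this comment refers to; only the `cp_size` and `dp_shard_size` parameters are confirmed above, while the import path and example values are assumptions that may differ by accelerate version:

```python
# Hedged sketch; the import path may vary across accelerate releases.
from accelerate.parallelism_config import ParallelismConfig

# Example values: 8 GPUs with cp_size=2 leaves dp_shard_size = 8 // 2 = 4,
# matching the DP_SHARD_SIZE computed in launch_train.sh.
parallelism_config = ParallelismConfig(cp_size=2, dp_shard_size=4)
```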
examples/speculative_decoding/eagle_utils.py (1)
652-692: Frame inspection is fragile and tightly coupled to torch internals.

The approach of using `inspect.currentframe().f_back.f_locals` to extract variables from PyTorch's internal `_templated_ring_attention` implementation is inherently fragile. Any change to variable names or control flow in PyTorch's implementation will silently break this code. Consider:

- Adding a comment documenting which PyTorch version this was tested against (torch 2.8.0 per PR).
- Adding a runtime assertion or version check to fail fast if the expected variables aren't found (see the sketch below).
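A minimal sketch of such a guard, pinned to torch 2.8.0 (the version the PR reports testing against); the exact policy and checked symbols are assumptions:

```python
import torch
from packaging import version

TESTED_TORCH = "2.8.0"  # per the PR description

def check_ring_attention_patchable() -> None:
    # Fail fast if torch drifted from the version the frame inspection was
    # written against, rather than silently computing a wrong attention mask.
    if version.parse(torch.__version__).base_version != TESTED_TORCH:
        raise RuntimeError(
            f"CP TTT ring-attention patch was only validated on torch "
            f"{TESTED_TORCH}; got {torch.__version__}."
        )
    # Also confirm the internals being monkey-patched still exist.
    from torch.distributed.tensor.experimental import _attention
    if not hasattr(_attention, "_templated_ring_attention"):
        raise RuntimeError(
            "_templated_ring_attention is no longer exposed; the CP TTT patch needs updating."
        )
```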
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #745      +/-   ##
==========================================
- Coverage   74.17%   74.13%   -0.05%
==========================================
  Files         192      192
  Lines       19246    19263      +17
==========================================
+ Hits        14276    14280       +4
- Misses       4970     4983      +13
```

☔ View full report in Codecov by Sentry.
kevalmorabia97 left a comment:
LGTM from CICD point of view
**Type of change:** New Feature
**Overview:**
- Supported Context Parallel by patching torch ring attention;
- Requires the following library versions for stable CP:
  - torch 2.8.0
  - transformers 5.0.0
  - accelerate 1.12.0
- Move to FSDP2
- Removed unused arguments in training script (`--multi_gpu`,
`fsdp_wrap_layer`)
- Bump CI container to `nvcr.io/nvidia/pytorch:25.08-py3`
**Usage:**
```bash
./launch_train.sh --model $MODEL \
--output_dir $OUTPUT_DIR \
--data $DATA \
--num_epochs 0.1 \
--train_bs 1 \
--eagle_config eagle_config.json \
--training_seq_len 1024 \
--cp_size 2 #newly added
```
**Testing:**
- SDPA-level correctness: tested TTT attention with/without CP, diff < 1%
```
=== Compare context-parallel (CP) outputs and grads with non-CP ===
Forward output comparison (CP vs Non-CP):
Absolute diff (adiff) cp_out vs out: 0.001953125
Relative diff (rdiff) cp_out vs out: 0.00182342529296875
WQ (query proj) grad comparison (CP vs Non-CP):
Absolute diff (adiff) cp_wq_grad vs wq_grad: 0.0078125
Relative diff (rdiff) cp_wq_grad vs wq_grad: 0.00347900390625
WK (key proj) grad comparison (CP vs Non-CP):
Absolute diff (adiff) cp_wk_grad vs wk_grad: 0.0078125
Relative diff (rdiff) cp_wk_grad vs wk_grad: 0.002471923828125
WV (value proj) grad comparison (CP vs Non-CP):
Absolute diff (adiff) cp_wv_grad vs wv_grad: 0.25
Relative diff (rdiff) cp_wv_grad vs wv_grad: 0.0069580078125
==============================================================
```
- E2E Training Acc
(Llama3.1-8B, Unsynthesized magpie)
<img width="911" height="630" alt="image"
src="https://github.com/user-attachments/assets/1ecacc7f-c720-494c-9c1b-b60e7ced7baa"
/>
- Peak Mem Reserved
(llama3.1-8B, 8xH100, train_length=4k)
| cp_size | max_memory_allocated (MB) | max_memory_reserved (MB) |
|---------|---------------------------|--------------------------|
| 1       | 65040.20                  | 79018.00                 |
| 2       | 50409.17                  | 73098.00                 |
| 4       | 45120.92                  | 72052.00                 |
| 8       | 38882.12                  | 66484.00                 |
- Max Training Length test
(llama3.1-8B, H100)
| cp_size | 6k  | 12k | 24k | 48k |
|---------|-----|-----|-----|-----|
| 1       | ✅  | OOM | OOM | OOM |
| 2       | ✅  | ✅  | OOM | OOM |
| 4       | ✅  | ✅  | ✅  | OOM |
| 8       | ✅  | ✅  | ✅  | ✅  |
**Before your PR is "Ready for review":**
- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes/No
**Summary by CodeRabbit**
* **New Features**
* Added context parallelism (CP) and data parallelism shard size
configuration parameters to training arguments.
* **Enhancements**
* Improved TTT attention masking support for speculative decoding
workflows.
* Enhanced training launch script with improved parallelism
configuration handling.
* **Chores**
* Updated core dependencies: torch, transformers, accelerate, and wandb.
* Added FSDP configuration file for distributed training setup.
---------
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>