-
Notifications
You must be signed in to change notification settings - Fork 730
docs: expand comm gemm overlap guidance #3043
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,81 @@ | |
| - Devices older than compute capability 9.0 require `UB_SKIPMC=1` in the environment in order to fall | ||
| back on a less performant implementation based on CUDA Inter-Process Communication (IPC) handles. | ||
|
|
||
| ## Enabling overlap in your own module | ||
|
|
||
| The example follows the same setup sequence that user code should use: | ||
|
|
||
| 1. Set `CUDA_DEVICE_MAX_CONNECTIONS=1` before creating the layer. | ||
| 2. Initialize `torch.distributed` and create the tensor-parallel process group. | ||
| 3. Call `te.module.base.initialize_ub(...)` with the local activation shape and tensor-parallel | ||
| size before constructing TE layers with userbuffer overlap enabled. | ||
| 4. Pass the tensor-parallel group, tensor-parallel size, and overlap flags to the TE layer. | ||
| 5. Call `te.module.base.destroy_ub()` before shutting down the process group. | ||
|
|
||
| Minimal setup sketch: | ||
|
|
||
| ```python | ||
| import os | ||
| import torch | ||
| import torch.distributed as dist | ||
| import transformer_engine.pytorch as te | ||
|
|
||
| os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" | ||
|
|
||
| dist.init_process_group(backend="nccl") | ||
| tp_group = dist.group.WORLD | ||
| tp_size = dist.get_world_size(tp_group) | ||
|
|
||
| num_heads = 16 | ||
| head_dim = 128 | ||
| seq_length = 2048 | ||
| micro_batch_size = 4 | ||
|
|
||
| hidden_size = num_heads * head_dim | ||
| batched_size = seq_length * micro_batch_size | ||
|
|
||
| te.module.base.initialize_ub( | ||
| [batched_size, hidden_size], | ||
| tp_size, | ||
| quantization_modes=[te.module.base.UserBufferQuantizationMode.NONE], | ||
| dtype=torch.bfloat16, | ||
| bootstrap_backend="nccl", | ||
| ) | ||
|
|
||
| layer = te.TransformerLayer( | ||
| hidden_size, | ||
| 4 * hidden_size, | ||
| num_heads, | ||
| tp_group=tp_group, | ||
| tp_size=tp_size, | ||
| sequence_parallel=True, | ||
| fuse_qkv_params=True, | ||
| ub_tp_comm_overlap=True, | ||
| ub_overlap_ag=True, | ||
| ub_overlap_rs=True, | ||
| ub_bulk_wgrad=True, | ||
| ub_bulk_dgrad=True, | ||
| seq_length=seq_length, | ||
| ) | ||
|
Comment on lines
+42
to
+67
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The sketch uses |
||
|
|
||
| # ... run forward/backward/optimizer steps ... | ||
|
|
||
| te.module.base.destroy_ub() | ||
| ``` | ||
|
|
||
| `ub_tp_comm_overlap` is the top-level gate on `TransformerLayer`: when it is `False`, the | ||
| layer disables the individual userbuffer overlap paths even if the per-path flags are `True`. | ||
| For lower-level layers such as `Linear`, `LayerNormLinear`, `LayerNormMLP`, or | ||
| `MultiheadAttention`, enable the relevant per-path flags directly (for example | ||
| `ub_overlap_ag`, `ub_overlap_rs`, `ub_bulk_wgrad`, and `ub_bulk_dgrad`) and set the `ub_name` | ||
| where the layer requires one. | ||
|
|
||
| When replacing modules in a Hugging Face model, run the userbuffer initialization once before | ||
| constructing the replacement TE modules. The replacement modules need the same tensor-parallel | ||
| group, tensor-parallel size, sequence-parallel setting, and overlap flags shown above; the | ||
| activation shape passed to `initialize_ub` should match the sequence length, micro-batch size, | ||
| and hidden size used by the replaced blocks. | ||
|
|
||
| ## Examples | ||
|
|
||
| ### Single node, tensor-parallel LayerNormMLP: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
initialize_ubcall in the sketch omits thequantization_modesargument that the actual example passes. For FP8 workloads the default (None) differs from the FP8 mode used inte_layer_with_overlap.py. A brief comment noting thatquantization_modesshould be set for FP8 training would prevent silent misconfiguration.