Commit 50eefcf

houqiKnowingNothing authored and committed
refactor: allreduce refactor to support more
* fix ci
* fix ci
* fix format
* fix ci amd
* fix ci
* fix ci
* fix e2e
* fix ci
* fix format
* fix ci
* fix test_e2e part
* use sleep async
* use dummy dtype for all-reduce workspace
* launch with cooperative grids if use barrier_on_this_grid
* fuse barrier all into triton kernel
* fix allreduce_one_shot_push_intra_node
* allreduce straggler with <10ms
* remove allreduce with signal_stages
* remove some not used layer buffer
* refactor: some rename

See merge request: !191
1 parent 316eabd, commit 50eefcf

12 files changed: 127 additions & 198 deletions
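The center of the refactor is the allreduce context API: the workspace is now sized in raw bytes instead of element count plus dtype ("use dummy dtype for all-reduce workspace"), the reduction method moves from context creation to the all_reduce call, signal_stages is removed, and the persistent ar_output buffer gives way to a per-call output tensor. A minimal before/after sketch reconstructed from the call sites in the diffs below; the module that exports create_allreduce_ctx and all_reduce is not shown in this commit, so that import is an assumption:

import torch
from triton_dist.kernels.allreduce import AllReduceMethod
# import path below is assumed; this commit only shows the call sites
from triton_dist.kernels.nvidia import create_allreduce_ctx, all_reduce

max_M, N = 128, 4096            # illustrative decode-batch rows and hidden size
rank, world_size = 0, 8         # per-process values in a real torchrun launch
x = torch.randn(max_M, N, dtype=torch.bfloat16, device="cuda")

# Before: workspace sized by element count + dtype, method fixed at creation.
#   ctx = create_allreduce_ctx(numel=max_M * N, dtype=torch.bfloat16, rank=rank,
#                              world_size=world_size, local_world_size=world_size,
#                              method=method, signal_stages=1)
# After: dtype-agnostic byte-sized workspace, method chosen per call,
# output allocated by the caller on every invocation.
ctx = create_allreduce_ctx(workspace_nbytes=max_M * N * torch.bfloat16.itemsize,
                           rank=rank, world_size=world_size, local_world_size=world_size)
out = all_reduce(x=x.contiguous(), output=torch.empty_like(x),
                 method=AllReduceMethod.TwoShot_Multimem, ctx=ctx)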

python/triton_dist/layers/amd/tp_attn.py

Lines changed: 4 additions & 4 deletions
@@ -207,16 +207,16 @@ def _init_parameters(self, self_attn: nn.Module, verbose=False):
         if verbose:
             print(f"[RANK {self.rank}] Attn initialized with parameters: qkv ({self.wqkv.shape}, o ({self.wo.shape}))")

-    def _init_ctx(self, max_M, gemm_stream, ag_intranode_stream, BLOCK_M, BLOCK_N, BLOCK_K, stages, serial=False,
+    def _init_ctx(self, max_M, ag_intranode_stream, BLOCK_M, BLOCK_N, BLOCK_K, stages, serial=False,
                   ag_internode_stream=None):
         if serial:
             print(f"[RANK {self.rank}] Using serial mode for AG-GEMM.")
         self.ag_ctx = create_ag_gemm_intra_node_context(max_M=max_M, N=self.ag_N_per_rank, K=self.K, rank=self.rank,
                                                         num_ranks=self.world_size, input_dtype=self.dtype,
                                                         output_dtype=self.dtype, tp_group=self.group,
-                                                        ag_streams=ag_intranode_stream, gemm_stream=gemm_stream,
-                                                        serial=serial, autotune=True, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
-                                                        BLOCK_K=BLOCK_K, stages=stages, M_PER_CHUNK=256)
+                                                        ag_streams=ag_intranode_stream, serial=serial, autotune=True,
+                                                        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+                                                        stages=stages, M_PER_CHUNK=256)
         self.rs_ctx = create_gemm_rs_intra_node_context(
             max_M=max_M,
             N=self.K,

python/triton_dist/layers/amd/tp_mlp.py

Lines changed: 4 additions & 4 deletions
@@ -86,17 +86,17 @@ def _init_parameters(self, mlp: nn.Module, verbose=False):
             f"[RANK {self.rank}] MLP initialized with parameters: gate_up_proj shape: {self.gate_up_proj.shape}, down_proj shape: {self.down_proj.shape}"
         )

-    def _init_ctx(self, max_M, gemm_stream, ag_intranode_stream, BLOCK_M, BLOCK_N, BLOCK_K, stages, serial=False,
+    def _init_ctx(self, max_M, ag_intranode_stream, BLOCK_M, BLOCK_N, BLOCK_K, stages, serial=False,
                   ag_internode_stream=None):
         """Initializes contexts for triton_dist AllGather-GEMM and GEMM-ReduceScatter operations."""
         if serial:
             print(f"[RANK {self.rank}] Using serial mode for AG-GEMM.")
         self.ag_ctx = create_ag_gemm_intra_node_context(max_M=max_M, N=self.ag_N_per_rank, K=self.K, rank=self.rank,
                                                         num_ranks=self.world_size, input_dtype=self.dtype,
                                                         output_dtype=self.dtype, tp_group=self.group,
-                                                        ag_streams=ag_intranode_stream, gemm_stream=gemm_stream,
-                                                        serial=serial, autotune=True, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
-                                                        BLOCK_K=BLOCK_K, stages=stages, M_PER_CHUNK=256)
+                                                        ag_streams=ag_intranode_stream, serial=serial, autotune=True,
+                                                        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+                                                        stages=stages, M_PER_CHUNK=256)
         self.rs_ctx = create_gemm_rs_intra_node_context(
             max_M=max_M,
             N=self.K,
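The only change in both AMD layers is dropping gemm_stream from _init_ctx and from the create_ag_gemm_intra_node_context call; the GEMM presumably now runs on the caller's current CUDA stream. A hypothetical call site after the change (stream setup and block sizes are illustrative, not from this commit):

ag_stream = torch.cuda.Stream(priority=-1)   # intra-node all-gather stream
mlp._init_ctx(max_M=4096, ag_intranode_stream=ag_stream,
              BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, stages=3, serial=False)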

python/triton_dist/layers/nvidia/tp_attn.py

Lines changed: 12 additions & 24 deletions
@@ -28,6 +28,7 @@
 import torch.distributed
 import flashinfer

+from triton_dist.kernels.allreduce import AllReduceMethod
 from triton_dist.kernels.nvidia.allgather_gemm import AllGatherGEMMTensorParallelContext, get_auto_all_gather_method, ag_gemm
 from triton_dist.kernels.nvidia import create_gemm_rs_context, gemm_rs
 from triton_dist.utils import nvshmem_barrier_all_on_stream
@@ -91,7 +92,7 @@ def __init__(self, rank=0, world_size=8, group=None):
         self.wo = None
         self.ag_ctx = None
         self.rs_ctx = None
-        self.ctx = None
+        self.ar_ctx = None

     def _init_parameters(self, self_attn: nn.Module, verbose=False):
         self.q_size = self_attn.q_proj.weight.shape[0] // self.world_size
@@ -117,13 +118,11 @@ def _init_parameters(self, self_attn: nn.Module, verbose=False):
         if verbose:
             print(f"[RANK {self.rank}] Attn initialized with parameters: qkv ({self.wqkv.shape}, o ({self.wo.shape}))")

-    def _init_ctx(self, max_M, gemm_stream, ag_intranode_stream, ag_internode_stream, BLOCK_M, BLOCK_N, BLOCK_K,
-                  stages):
+    def _init_ctx(self, max_M, ag_intranode_stream, ag_internode_stream, BLOCK_M, BLOCK_N, BLOCK_K, stages):
         self.ag_ctx = AllGatherGEMMTensorParallelContext(
             N_per_rank=self.ag_N_per_rank, K=self.K, tensor_dtype=self.dtype, rank=self.rank, num_ranks=self.world_size,
-            num_local_ranks=self.world_size, max_M=max_M, gemm_stream=gemm_stream,
-            ag_intranode_stream=ag_intranode_stream, ag_internode_stream=ag_internode_stream, BLOCK_M=BLOCK_M,
-            BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, stages=stages,
+            num_local_ranks=self.world_size, max_M=max_M, ag_intranode_stream=ag_intranode_stream,
+            ag_internode_stream=ag_internode_stream, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, stages=stages,
             all_gather_method=get_auto_all_gather_method(self.world_size, self.world_size))
         self.rs_ctx = create_gemm_rs_context(
             max_M=max_M,
@@ -141,27 +140,20 @@ def _init_ctx(self, max_M, gemm_stream, ag_intranode_stream, ag_internode_stream
         nvshmem_barrier_all_on_stream(torch.cuda.current_stream())
         torch.cuda.synchronize()

-    def _init_AR_ctx(self, max_M, method, dtype=torch.bfloat16, signal_stages=1):
+    def _init_AR_ctx(self, max_M, method: AllReduceMethod, dtype=torch.bfloat16):
         self.ar_method = method
-        input_tensor = torch.empty([max_M, self.K], dtype=dtype, device="meta")
-        self.ctx = create_allreduce_ctx(
-            numel=max_M * self.K,
-            dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+        self.ar_ctx = create_allreduce_ctx(
+            workspace_nbytes=max_M * self.K * dtype.itemsize, rank=self.rank, world_size=self.world_size,
             local_world_size=self.world_size,  # TODO(houqi.1993) does not support multiple nodes now.
-            method=method,
-            signal_stages=signal_stages,
         )
-        self.ar_output = torch.empty_like(input_tensor, device="cuda", dtype=dtype).contiguous()

     def finalize(self):
         if self.ag_ctx:
             self.ag_ctx.finailize()
         if self.rs_ctx:
             self.rs_ctx.finalize()
-        if self.ctx:
-            self.ctx.finalize()
+        if self.ar_ctx:
+            self.ar_ctx.finalize()

     @torch.inference_mode()
     def apply_rotary_pos_emb(self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor,
@@ -274,12 +266,8 @@ def dist_triton_AR_fwd(self, x, position_ids, cos_sin_cache, kv_cache, layer_idx

         out = torch.nn.functional.linear(out.view(bsz, q_len, -1), self.wo).view(bsz * q_len, -1)
         if self.world_size > 1:
-            out = all_reduce(
-                input=out.contiguous(),
-                output=self.ar_output,
-                method=self.ar_method,
-                ctx=self.ctx,
-            )
+            out_allreduce = torch.empty_like(out)
+            out = all_reduce(x=out.contiguous(), output=out_allreduce, method=self.ar_method, ctx=self.ar_ctx)
         return out.view(bsz, q_len, -1)

     def fwd(self, x: torch.Tensor, position_ids: torch.Tensor, cos_sin_cache: torch.Tensor, kv_cache, layer_idx: int):
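A quick way to sanity-check the refactored attention path is to compare the triton_dist call against the stock collective. A sketch, assuming torch.distributed is initialized and ar_ctx was built by _init_AR_ctx above; the tolerances are a loose bf16 guess:

ref = out.clone()
torch.distributed.all_reduce(ref, group=group)        # NCCL reference sum
got = all_reduce(x=out.contiguous(), output=torch.empty_like(out),
                 method=AllReduceMethod.TwoShot_Multimem, ctx=ar_ctx)
torch.testing.assert_close(got, ref, rtol=1e-2, atol=1e-2)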

python/triton_dist/layers/nvidia/tp_mlp.py

Lines changed: 27 additions & 43 deletions
@@ -27,6 +27,7 @@
 from torch import nn
 import torch.distributed

+from triton_dist.kernels.allreduce import AllReduceMethod
 from triton_dist.kernels.nvidia.allgather_gemm import AllGatherGEMMTensorParallelContext, get_auto_all_gather_method, ag_gemm
 from triton_dist.kernels.nvidia import create_gemm_rs_context, gemm_rs
 from triton_dist.utils import nvshmem_barrier_all_on_stream
@@ -64,7 +65,7 @@ def __init__(self, rank=0, world_size=8, group=None):
         self.down_proj = None
         self.ag_ctx = None
         self.rs_ctx = None
-        self.ctx = None
+        self.ar_ctx = None

     def _init_parameters(self, mlp: nn.Module, verbose=False):
         """
@@ -91,14 +92,13 @@ def _init_parameters(self, mlp: nn.Module, verbose=False):
             f"[RANK {self.rank}] MLP initialized with parameters: gate_up_proj shape: {self.gate_up_proj.shape}, down_proj shape: {self.down_proj.shape}"
         )

-    def _init_ctx(self, max_M, gemm_stream, ag_intranode_stream, ag_internode_stream, BLOCK_M, BLOCK_N, BLOCK_K,
-                  stages):
+    def _init_ctx(self, max_M, ag_intranode_stream, ag_internode_stream, BLOCK_M, BLOCK_N, BLOCK_K, stages):
+        # TODO(houqi.1993) BLOCK_SIZE should not be part of arguments, but be determined on forward.
         """Initializes contexts for triton_dist AllGather-GEMM and GEMM-ReduceScatter operations."""
         self.ag_ctx = AllGatherGEMMTensorParallelContext(
             N_per_rank=self.ag_N_per_rank, K=self.K, tensor_dtype=self.dtype, rank=self.rank, num_ranks=self.world_size,
-            num_local_ranks=self.world_size, max_M=max_M, gemm_stream=gemm_stream,
-            ag_intranode_stream=ag_intranode_stream, ag_internode_stream=ag_internode_stream, BLOCK_M=BLOCK_M,
-            BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, stages=stages,
+            num_local_ranks=self.world_size, max_M=max_M, ag_intranode_stream=ag_intranode_stream,
+            ag_internode_stream=ag_internode_stream, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, stages=stages,
             all_gather_method=get_auto_all_gather_method(self.world_size, self.world_size))
         self.rs_ctx = create_gemm_rs_context(
             max_M=max_M,
@@ -121,8 +121,8 @@ def finalize(self):
             self.ag_ctx.finailize()
         if self.rs_ctx:
             self.rs_ctx.finalize()
-        if self.ctx:
-            self.ctx.finalize()
+        if self.ar_ctx:
+            self.ar_ctx.finalize()

     @torch.inference_mode()
     def torch_fwd(self, x):
@@ -140,7 +140,7 @@ def torch_fwd(self, x):
         return out

     @torch.inference_mode()
-    def dist_triton_fwd(self, x, ag_gemm_persistent=False, gemm_rs_persistent=False, autotune=True):
+    def dist_triton_fwd(self, x: torch.Tensor, ag_gemm_persistent=False, gemm_rs_persistent=False, autotune=True):
         """
         triton_dist forward pass for TP.
         This version uses ag_gemm and gemm_rs.
@@ -165,22 +165,16 @@ def dist_triton_fwd(self, x, ag_gemm_persistent=False, gemm_rs_persistent=False,
         out = out.view(bsz, seq, -1)
         return out

-    def _init_AR_ctx(self, M, method, dtype=torch.bfloat16, signal_stages=1):
+    def _init_AR_ctx(self, max_M, method: AllReduceMethod, dtype=torch.bfloat16):
         self.ar_method = method
         N = self.down_proj.shape[0]
-        self.ctx = create_allreduce_ctx(
-            numel=M * N,
-            dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+        self.ar_ctx = create_allreduce_ctx(
+            workspace_nbytes=max_M * N * dtype.itemsize, rank=self.rank, world_size=self.world_size,
             local_world_size=self.world_size,  # TODO(houqi.1993) does not support multiple nodes now.
-            method=method,
-            signal_stages=signal_stages,
         )
-        self.ar_output = torch.empty((M, N), device="cuda", dtype=dtype).contiguous()

     @torch.inference_mode()
-    def dist_triton_AR_fwd(self, x):
+    def dist_triton_AR_fwd(self, x: torch.Tensor):
         """
         triton_dist AR forward pass for TP.
         This version uses gemm + gemm + AllReduce
@@ -189,9 +183,11 @@ def dist_triton_AR_fwd(self, x):
         out_fused = torch.nn.functional.linear(x, self.gate_up_proj)
         wg, w1 = torch.chunk(out_fused, 2, dim=-1)
         out = self.act_fn(wg) * w1
-        out = torch.nn.functional.linear(out, self.down_proj).view_as(self.ar_output)
+        out = torch.nn.functional.linear(out, self.down_proj).view_as(x)
         if self.world_size > 1:
-            out = all_reduce(out.contiguous(), self.ar_output, method=self.ar_method, ctx=self.ctx)
+            out_ar = torch.empty_like(out)
+            assert self.ar_ctx is not None, "AllReduce context is not initialized."
+            out = all_reduce(out.contiguous(), out_ar, method=self.ar_method, ctx=self.ar_ctx)
         return out.view_as(x)

     @torch.inference_mode()
@@ -203,20 +199,13 @@ def torch_ag_gemm(self, x: torch.Tensor):
         """
         Reference PyTorch forward pass using AllGather-GEMM.
         """
-
         M_per_rank, K = x.shape
         M = M_per_rank * self.world_size
-
-        if not hasattr(self, 'ag_buffer'):
-            self.ag_buffer = torch.empty([M, K], dtype=x.dtype, device="cuda")
-
+        ag_buffer = torch.empty([M, K], dtype=x.dtype, device="cuda")
         # ag
-        torch.distributed.all_gather_into_tensor(self.ag_buffer, x, group=self.group)
-
+        torch.distributed.all_gather_into_tensor(ag_buffer, x, group=self.group)
         # gemm
-        golden = torch.matmul(self.ag_buffer, self.gate_up_proj.T)
-
-        return golden
+        return torch.matmul(ag_buffer, self.gate_up_proj.T)

     @torch.inference_mode()
     def dist_triton_ag_gemm(self, x: torch.Tensor, persistent=True, autotune=False):
@@ -225,8 +214,8 @@ def dist_triton_ag_gemm(self, x: torch.Tensor, persistent=True, autotune=False):
         This version uses ag_gemm.
         x: input tensor, shape [batch_size * seq_len, hidden_size]
         """
-        out = ag_gemm(x, self.gate_up_proj, ctx=self.ag_ctx, persistent=persistent, autotune=autotune)
-        return out
+        assert self.ag_ctx is not None
+        return ag_gemm(x, self.gate_up_proj, ctx=self.ag_ctx, persistent=persistent, autotune=autotune)

     @torch.inference_mode()
     def torch_gemm_rs(self, x: torch.Tensor):
@@ -235,16 +224,11 @@ def torch_gemm_rs(self, x: torch.Tensor):
         """
         # x: [M, K]
         M, K = x.shape
-        if not hasattr(self, 'rs_buffer'):
-            self.rs_buffer = torch.empty([M // self.world_size, self.down_proj.shape[0]], dtype=x.dtype, device="cuda")
-
+        rs_buffer = torch.empty([M // self.world_size, self.down_proj.shape[0]], dtype=x.dtype, device="cuda")
         # gemm
         gemm_out = torch.matmul(x, self.down_proj.T)
-
-        # rs
-        torch.distributed.reduce_scatter_tensor(self.rs_buffer, gemm_out, group=self.group)
-
-        return self.rs_buffer
+        torch.distributed.reduce_scatter_tensor(rs_buffer, gemm_out, group=self.group)
+        return rs_buffer

     @torch.inference_mode()
     def dist_triton_gemm_rs(self, x: torch.Tensor, persistent=False):
@@ -253,5 +237,5 @@ def dist_triton_gemm_rs(self, x: torch.Tensor, persistent=False):
         This version uses gemm_rs.
         x: input tensor, shape [batch_size * seq_len, hidden_size]
         """
-        out = gemm_rs(x, self.down_proj, self.rs_ctx, persistent=persistent, fuse_scatter=True)
-        return out
+        assert self.rs_ctx is not None
+        return gemm_rs(x, self.down_proj, self.rs_ctx, persistent=persistent, fuse_scatter=True)
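For orientation, the semantics dist_triton_AR_fwd implements can be written with stock PyTorch collectives. A reference sketch; the weight shards, the process group, and the SiLU activation are assumptions mirroring torch_fwd above:

# Each rank holds a column shard of gate_up_proj and a row shard of down_proj.
h = torch.nn.functional.linear(x, gate_up_proj_shard)        # [M, 2 * N_per_rank]
wg, w1 = torch.chunk(h, 2, dim=-1)
act = torch.nn.functional.silu(wg) * w1                      # act_fn assumed to be SiLU
partial = torch.nn.functional.linear(act, down_proj_shard)   # partial sums, [M, K]
torch.distributed.all_reduce(partial, group=group)           # sum partials across ranks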

python/triton_dist/models/engine.py

Lines changed: 2 additions & 2 deletions
@@ -28,6 +28,7 @@
 from tqdm import tqdm
 from datetime import datetime

+from triton_dist.kernels.allreduce import AllReduceMethod
 from triton_dist.models.kv_cache import KV_Cache
 from triton_dist.models import AutoLLM, AutoTokenizer, ModelConfig
 from triton_dist.models.utils import logger, sample_token
@@ -128,7 +129,7 @@ def serve(self, input_ids: torch.Tensor, gen_len: int):
             self.model.init_triton_dist_ctx(max_M=bsz)
         elif self.backend == 'triton_dist_AR':
             self.model.set_fwd(mode='triton_dist_AR')
-            self.model.init_triton_dist_AR_ctx(max_M=bsz, ar_method='two_shot_ld_reduce')
+            self.model.init_triton_dist_AR_ctx(max_M=bsz, ar_method=AllReduceMethod.TwoShot_Multimem)

         if self.no_graph:
@@ -184,4 +185,3 @@ def run(input_ids, position_ids):
         print(self.tokenizer.batch_decode(output_ids, skip_special_tokens=True))

         del self.model_launch
-        torch.distributed.destroy_process_group()
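The backend used to select the method by string ('two_shot_ld_reduce'); with the enum, a thin mapping keeps any string-facing CLI or config working. A hypothetical helper, using only the two enum members visible in this commit:

_AR_METHODS = {
    "two_shot_multimem": AllReduceMethod.TwoShot_Multimem,
    "double_tree": AllReduceMethod.DoubleTree,
}

def parse_ar_method(name: str) -> AllReduceMethod:
    """Map a CLI/config string onto an AllReduceMethod member."""
    try:
        return _AR_METHODS[name.lower()]
    except KeyError:
        raise ValueError(f"unknown allreduce method {name!r}; expected one of {sorted(_AR_METHODS)}")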

python/triton_dist/models/qwen.py

Lines changed: 8 additions & 11 deletions
@@ -30,6 +30,8 @@
 from transformers import Qwen3ForCausalLM, Qwen3Config
 from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer

+from triton_dist.kernels.allreduce import AllReduceMethod
+
 if not torch.cuda.is_available():
     raise RuntimeError("CUDA is not available. Please ensure you have a compatible GPU and CUDA installed.")
 try:
@@ -173,14 +175,11 @@ def init_triton_dist_ctx(self, max_M: int = 4096):
             self.ag_intranode_stream = [torch.cuda.Stream(priority=-1) for i in range(self.world_size)]
         else:
             raise RuntimeError(f"Unsupported platform: {PLATFORM}. Supported platforms are 'nvidia' and 'amd'.")
-        self.gemm_stream = torch.cuda.Stream()
         self.ag_internode_stream = torch.cuda.Stream()
-        self.layers[0].attn._init_ctx(max_M=max_M, gemm_stream=self.gemm_stream,
-                                      ag_intranode_stream=self.ag_intranode_stream,
+        self.layers[0].attn._init_ctx(max_M=max_M, ag_intranode_stream=self.ag_intranode_stream,
                                       ag_internode_stream=self.ag_internode_stream, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
                                       BLOCK_K=BLOCK_K, stages=stages)
-        self.layers[0].mlp._init_ctx(max_M=max_M, gemm_stream=self.gemm_stream,
-                                     ag_intranode_stream=self.ag_intranode_stream,
+        self.layers[0].mlp._init_ctx(max_M=max_M, ag_intranode_stream=self.ag_intranode_stream,
                                      ag_internode_stream=self.ag_internode_stream, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
                                      BLOCK_K=BLOCK_K, stages=stages)
         for layer in self.layers[1:]:
@@ -191,17 +190,15 @@ def init_triton_dist_ctx(self, max_M: int = 4096):

         self.use_ar = False

-    def init_triton_dist_AR_ctx(self, max_M: int = 128, ar_method: str = 'two_shot_ld_reduce'):
+    def init_triton_dist_AR_ctx(self, max_M: int = 128, ar_method: AllReduceMethod = AllReduceMethod.DoubleTree):
         self.layers[0].attn._init_AR_ctx(max_M=max_M, method=ar_method, dtype=self.dtype)
-        self.layers[0].mlp._init_AR_ctx(M=max_M, method=ar_method, dtype=self.dtype)
+        self.layers[0].mlp._init_AR_ctx(max_M=max_M, method=ar_method, dtype=self.dtype)

         for layer in self.layers[1:]:
-            layer.attn.ctx = self.layers[0].attn.ctx
+            layer.attn.ar_ctx = self.layers[0].attn.ar_ctx
             layer.attn.ar_method = self.layers[0].attn.ar_method
-            layer.attn.ar_output = self.layers[0].attn.ar_output
-            layer.mlp.ctx = self.layers[0].mlp.ctx
+            layer.mlp.ar_ctx = self.layers[0].mlp.ar_ctx
             layer.mlp.ar_method = self.layers[0].mlp.ar_method
-            layer.mlp.ar_output = self.layers[0].mlp.ar_output
         self.use_ar = True

     def finalize(self):
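Sharing layer 0's ar_ctx with every later layer works because the workspace is sized for the worst case (max_M rows of the largest allreduced activation); any tensor whose byte payload fits in workspace_nbytes can reuse it. A hypothetical check that spells out the invariant (not part of this commit):

def fits_workspace(t: torch.Tensor, workspace_nbytes: int) -> bool:
    # A tensor can go through a shared allreduce workspace iff its payload fits.
    return t.numel() * t.element_size() <= workspace_nbytes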
