 from torch import nn
 import torch.distributed

+from triton_dist.kernels.allreduce import AllReduceMethod
 from triton_dist.kernels.nvidia.allgather_gemm import AllGatherGEMMTensorParallelContext, get_auto_all_gather_method, ag_gemm
 from triton_dist.kernels.nvidia import create_gemm_rs_context, gemm_rs
 from triton_dist.utils import nvshmem_barrier_all_on_stream
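
Annotation: the kernels imported above assume an initialized NCCL process group and one CUDA device bound per rank before the layer is constructed. A minimal setup sketch, assuming a standard `torchrun` launch; nothing below is added by this change:

```python
# Hypothetical setup, e.g. under `torchrun --nproc-per-node=8 demo.py`.
# Only standard torch.distributed calls are used here.
import torch
import torch.distributed

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
```
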
@@ -64,7 +65,7 @@ def __init__(self, rank=0, world_size=8, group=None):
         self.down_proj = None
         self.ag_ctx = None
         self.rs_ctx = None
-        self.ctx = None
+        self.ar_ctx = None

     def _init_parameters(self, mlp: nn.Module, verbose=False):
         """
@@ -91,14 +92,13 @@ def _init_parameters(self, mlp: nn.Module, verbose=False):
                 f"[RANK {self.rank}] MLP initialized with parameters: gate_up_proj shape: {self.gate_up_proj.shape}, down_proj shape: {self.down_proj.shape}"
             )

-    def _init_ctx(self, max_M, gemm_stream, ag_intranode_stream, ag_internode_stream, BLOCK_M, BLOCK_N, BLOCK_K,
-                  stages):
+    def _init_ctx(self, max_M, ag_intranode_stream, ag_internode_stream, BLOCK_M, BLOCK_N, BLOCK_K, stages):
+        # TODO(houqi.1993) BLOCK_SIZE should not be part of arguments, but be determined on forward.
         """Initializes contexts for triton_dist AllGather-GEMM and GEMM-ReduceScatter operations."""
         self.ag_ctx = AllGatherGEMMTensorParallelContext(
             N_per_rank=self.ag_N_per_rank, K=self.K, tensor_dtype=self.dtype, rank=self.rank, num_ranks=self.world_size,
-            num_local_ranks=self.world_size, max_M=max_M, gemm_stream=gemm_stream,
-            ag_intranode_stream=ag_intranode_stream, ag_internode_stream=ag_internode_stream, BLOCK_M=BLOCK_M,
-            BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, stages=stages,
+            num_local_ranks=self.world_size, max_M=max_M, ag_intranode_stream=ag_intranode_stream,
+            ag_internode_stream=ag_internode_stream, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, stages=stages,
             all_gather_method=get_auto_all_gather_method(self.world_size, self.world_size))
         self.rs_ctx = create_gemm_rs_context(
             max_M=max_M,
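
Annotation: with `gemm_stream` dropped from the signature, a call site now looks like the sketch below. `mlp`, `max_M`, and the tile sizes are illustrative assumptions, not values prescribed by this diff:

```python
# Hypothetical call site after this change; block sizes and max_M are assumed.
mlp._init_ctx(
    max_M=8192,
    ag_intranode_stream=torch.cuda.Stream(),  # ordinary CUDA side streams
    ag_internode_stream=torch.cuda.Stream(),
    BLOCK_M=128, BLOCK_N=256, BLOCK_K=64,
    stages=3,
)
```
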
@@ -121,8 +121,8 @@ def finalize(self):
             self.ag_ctx.finailize()
         if self.rs_ctx:
             self.rs_ctx.finalize()
-        if self.ctx:
-            self.ctx.finalize()
+        if self.ar_ctx:
+            self.ar_ctx.finalize()

     @torch.inference_mode()
     def torch_fwd(self, x):
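
Annotation: since `finalize()` now tears down `ar_ctx` alongside `ag_ctx` and `rs_ctx`, callers presumably want it on every exit path. One hedged pattern, with `mlp` and `x` illustrative:

```python
# Sketch of a cleanup pattern (not part of this diff).
try:
    y = mlp.dist_triton_fwd(x)
finally:
    mlp.finalize()  # releases whichever of ag_ctx / rs_ctx / ar_ctx were created
```
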
@@ -140,7 +140,7 @@ def torch_fwd(self, x):
         return out

     @torch.inference_mode()
-    def dist_triton_fwd(self, x, ag_gemm_persistent=False, gemm_rs_persistent=False, autotune=True):
+    def dist_triton_fwd(self, x: torch.Tensor, ag_gemm_persistent=False, gemm_rs_persistent=False, autotune=True):
         """
         triton_dist forward pass for TP.
         This version uses ag_gemm and gemm_rs.
@@ -165,22 +165,16 @@ def dist_triton_fwd(self, x, ag_gemm_persistent=False, gemm_rs_persistent=False,
         out = out.view(bsz, seq, -1)
         return out

-    def _init_AR_ctx(self, M, method, dtype=torch.bfloat16, signal_stages=1):
+    def _init_AR_ctx(self, max_M, method: AllReduceMethod, dtype=torch.bfloat16):
         self.ar_method = method
         N = self.down_proj.shape[0]
-        self.ctx = create_allreduce_ctx(
-            numel=M * N,
-            dtype=dtype,
-            rank=self.rank,
-            world_size=self.world_size,
+        self.ar_ctx = create_allreduce_ctx(
+            workspace_nbytes=max_M * N * dtype.itemsize, rank=self.rank, world_size=self.world_size,
             local_world_size=self.world_size,  # TODO(houqi.1993) does not support multiple nodes now.
-            method=method,
-            signal_stages=signal_stages,
         )
-        self.ar_output = torch.empty((M, N), device="cuda", dtype=dtype).contiguous()

     @torch.inference_mode()
-    def dist_triton_AR_fwd(self, x):
+    def dist_triton_AR_fwd(self, x: torch.Tensor):
         """
         triton_dist AR forward pass for TP.
         This version uses gemm + gemm + AllReduce
@@ -189,9 +183,11 @@ def dist_triton_AR_fwd(self, x):
         out_fused = torch.nn.functional.linear(x, self.gate_up_proj)
         wg, w1 = torch.chunk(out_fused, 2, dim=-1)
         out = self.act_fn(wg) * w1
-        out = torch.nn.functional.linear(out, self.down_proj).view_as(self.ar_output)
+        out = torch.nn.functional.linear(out, self.down_proj).view_as(x)
         if self.world_size > 1:
-            out = all_reduce(out.contiguous(), self.ar_output, method=self.ar_method, ctx=self.ctx)
+            out_ar = torch.empty_like(out)
+            assert self.ar_ctx is not None, "AllReduce context is not initialized."
+            out = all_reduce(out.contiguous(), out_ar, method=self.ar_method, ctx=self.ar_ctx)
         return out.view_as(x)

     @torch.inference_mode()
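
Annotation: taken together, the two AR hunks size the context by `workspace_nbytes` instead of `numel`/`dtype`, and allocate the output per call rather than caching `self.ar_output`. A usage sketch under assumptions; the `AllReduceMethod` member name is a guess, check `triton_dist.kernels.allreduce` for the real enum values:

```python
# Hedged sketch of the AR path after this change; `mlp`, `max_M`, and the
# enum member are assumptions for illustration.
from triton_dist.kernels.allreduce import AllReduceMethod

mlp._init_AR_ctx(max_M=8192,
                 method=AllReduceMethod.TwoShot,  # member name assumed
                 dtype=torch.bfloat16)
y = mlp.dist_triton_AR_fwd(x)  # gemm + gemm + all_reduce, output shaped like x
```
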
@@ -203,20 +199,13 @@ def torch_ag_gemm(self, x: torch.Tensor):
         """
         Reference PyTorch forward pass using AllGather-GEMM.
         """
-
         M_per_rank, K = x.shape
         M = M_per_rank * self.world_size
-
-        if not hasattr(self, 'ag_buffer'):
-            self.ag_buffer = torch.empty([M, K], dtype=x.dtype, device="cuda")
-
+        ag_buffer = torch.empty([M, K], dtype=x.dtype, device="cuda")
         # ag
-        torch.distributed.all_gather_into_tensor(self.ag_buffer, x, group=self.group)
-
+        torch.distributed.all_gather_into_tensor(ag_buffer, x, group=self.group)
         # gemm
-        golden = torch.matmul(self.ag_buffer, self.gate_up_proj.T)
-
-        return golden
+        return torch.matmul(ag_buffer, self.gate_up_proj.T)

     @torch.inference_mode()
     def dist_triton_ag_gemm(self, x: torch.Tensor, persistent=True, autotune=False):
@@ -225,8 +214,8 @@ def dist_triton_ag_gemm(self, x: torch.Tensor, persistent=True, autotune=False):
         This version uses ag_gemm.
         x: input tensor, shape [batch_size * seq_len, hidden_size]
         """
-        out = ag_gemm(x, self.gate_up_proj, ctx=self.ag_ctx, persistent=persistent, autotune=autotune)
-        return out
+        assert self.ag_ctx is not None
+        return ag_gemm(x, self.gate_up_proj, ctx=self.ag_ctx, persistent=persistent, autotune=autotune)

     @torch.inference_mode()
     def torch_gemm_rs(self, x: torch.Tensor):
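
Annotation: with `torch_ag_gemm` now allocating its gather buffer per call, it doubles cleanly as a reference for `dist_triton_ag_gemm`. A hedged parity check; the tolerances are assumptions for bf16, not values from this diff:

```python
# Illustrative correctness check (not part of this diff).
ref = mlp.torch_ag_gemm(x)  # NCCL all_gather_into_tensor + matmul
out = mlp.dist_triton_ag_gemm(x, persistent=True, autotune=False)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```
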
@@ -235,16 +224,11 @@ def torch_gemm_rs(self, x: torch.Tensor):
         """
         Reference PyTorch forward pass using GEMM-ReduceScatter.
         """
         # x: [M, K]
         M, K = x.shape
-        if not hasattr(self, 'rs_buffer'):
-            self.rs_buffer = torch.empty([M // self.world_size, self.down_proj.shape[0]], dtype=x.dtype, device="cuda")
-
+        rs_buffer = torch.empty([M // self.world_size, self.down_proj.shape[0]], dtype=x.dtype, device="cuda")
         # gemm
         gemm_out = torch.matmul(x, self.down_proj.T)
-
-        # rs
-        torch.distributed.reduce_scatter_tensor(self.rs_buffer, gemm_out, group=self.group)
-
-        return self.rs_buffer
+        torch.distributed.reduce_scatter_tensor(rs_buffer, gemm_out, group=self.group)
+        return rs_buffer

     @torch.inference_mode()
     def dist_triton_gemm_rs(self, x: torch.Tensor, persistent=False):
@@ -253,5 +237,5 @@ def dist_triton_gemm_rs(self, x: torch.Tensor, persistent=False):
         This version uses gemm_rs.
         x: input tensor, shape [batch_size * seq_len, hidden_size]
         """
-        out = gemm_rs(x, self.down_proj, self.rs_ctx, persistent=persistent, fuse_scatter=True)
-        return out
+        assert self.rs_ctx is not None
+        return gemm_rs(x, self.down_proj, self.rs_ctx, persistent=persistent, fuse_scatter=True)
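
Annotation: same idea for the GEMM-ReduceScatter pair; a hedged parity check against the per-call torch reference, with assumed tolerances:

```python
# Illustrative correctness check for the RS path (not part of this diff).
ref = mlp.torch_gemm_rs(x)  # matmul + NCCL reduce_scatter_tensor
out = mlp.dist_triton_gemm_rs(x, persistent=False)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```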