@@ -69,11 +69,11 @@ def _set_cluster_n(self):
6969 @cute .jit
7070 def __call__ (
7171 self ,
72- mX : cute .Tensor ,
73- mW : Optional [cute .Tensor ],
74- mB : Optional [cute .Tensor ],
75- mRes : Optional [cute .Tensor ],
76- mO : cute .Tensor ,
72+ mX : cute .Tensor , # (b, N) or (b, H, N)
73+ mW : Optional [cute .Tensor ], # (N,) or (H, N)
74+ mB : Optional [cute .Tensor ], # (N,) or (H, N)
75+ mRes : Optional [cute .Tensor ], # (b, N) or (b, H, N)
76+ mO : cute .Tensor , # (b, N) or (b, H, N)
7777 mResO : Optional [cute .Tensor ],
7878 mRstd : Optional [cute .Tensor ],
7979 mMean : Optional [cute .Tensor ],
@@ -93,13 +93,16 @@ def __call__(
9393 for mT in (mW , mB )
9494 ]
9595 mRstd , mMean = [
96- layout_utils .expand (mT , dim = 1 , size = self .N ) if const_expr (mT is not None ) else None
96+ layout_utils .expand (mT , dim = cute .rank (mT ), size = self .N )
97+ if const_expr (mT is not None )
98+ else None
9799 for mT in (mRstd , mMean )
98100 ]
101+ num_heads = mX .shape [1 ] if const_expr (cute .rank (mX ) == 3 ) else 1
99102 self .kernel (
100103 mX , mW , mB , mRes , mO , mResO , mRstd , mMean , eps , tiler_mn , tiled_copy , threads_per_row
101104 ).launch (
102- grid = [cute .ceil_div (mX .shape [0 ], tiler_mn [0 ]), self .cluster_n , 1 ],
105+ grid = [cute .ceil_div (mX .shape [0 ], tiler_mn [0 ]), self .cluster_n , num_heads ],
103106 block = [num_threads , 1 , 1 ],
104107 cluster = [1 , self .cluster_n , 1 ] if const_expr (self .cluster_n > 1 ) else None ,
105108 stream = stream ,
@@ -122,7 +125,7 @@ def kernel(
122125 threads_per_row : cutlass .Constexpr [int ],
123126 ):
124127 tidx , _ , _ = cute .arch .thread_idx ()
125- bidx , _ , _ = cute .arch .block_idx ()
128+ bidx , _ , bidz = cute .arch .block_idx ()
126129 cluster_y = const_expr (0 ) if const_expr (self .cluster_n == 1 ) else cute .arch .block_idx ()[1 ]
127130 tv_layout = tiled_copy .layout_tv_tiled
128131
@@ -138,9 +141,16 @@ def kernel(
138141 )
139142 reduction_buffer , mbar_ptr = self ._allocate_reduction_buffer_and_mbar (smem , tv_layout )
140143
141- shape = mX .shape
144+ # Slice per head
145+ if const_expr (cute .rank (mX ) == 3 ):
146+ mX , mW , mB , mRes , mO , mResO , mRstd , mMean = [
147+ mT [None , bidz , None ] if const_expr (mT is not None ) else None
148+ for mT in (mX , mW , mB , mRes , mO , mResO , mRstd , mMean )
149+ ]
150+
151+ shape = (cute .size (mX , mode = [0 ]), cute .size (mX , mode = [1 ]))
142152 idX = cute .make_identity_tensor (shape )
143- # slice for CTAs
153+ # Slice for CTAs
144154 gX , gRes , gO , gResO , gRstd , gMean , cX = [
145155 cute .local_tile (mT , tiler_mn , (bidx , cluster_y )) if mT is not None else None
146156 for mT in (mX , mRes , mO , mResO , mRstd , mMean , idX )
@@ -323,7 +333,7 @@ def _rmsnorm_fwd(
323333 """RMSNorm/LayerNorm forward pass.
324334 Args:
325335 x: Input tensor of shape (M, N)
326- weight: Optional weight tensor of shape (N,)
336+ weight: Optional weight tensor of shape (N,) or (H, N) for per-head weight
327337 eps: Small value for numerical stability
328338 is_layernorm: If True, compute LayerNorm instead of RMSNorm
329339 Returns:
@@ -337,7 +347,8 @@ def _rmsnorm_fwd(
337347 if residual is not None :
338348 assert residual .dtype in supported_types , "Residual must be float16, bfloat16, or float32"
339349
340- _ , N = x .shape
350+ N = x .size (- 1 )
351+ per_head = (weight is not None and weight .dim () == 2 ) or (bias is not None and bias .dim () == 2 )
341352 dtype , out_dtype , weight_dtype , bias_dtype , res_dtype , res_out_dtype = [
342353 torch2cute_dtype_map [t .dtype ] if t is not None else None
343354 for t in [x , out , weight , bias , residual , residual_out ]
@@ -353,6 +364,7 @@ def _rmsnorm_fwd(
353364 rstd is not None ,
354365 mean is not None ,
355366 is_layernorm ,
367+ per_head ,
356368 )(x , weight , bias , residual , out , residual_out , rstd , mean , eps )
357369
358370
@@ -372,8 +384,11 @@ def _rmsnorm_fwd_fake(
372384 # See softmax.py _softmax_fwd_fake for why register_fake is needed.
373385 from quack .cache_utils import COMPILE_ONLY
374386
375- if COMPILE_ONLY and not isinstance (x .size (1 ), torch .SymInt ):
376- N = x .size (1 )
387+ if COMPILE_ONLY and not isinstance (x .size (- 1 ), torch .SymInt ):
388+ N = x .size (- 1 )
389+ per_head = (weight is not None and weight .dim () == 2 ) or (
390+ bias is not None and bias .dim () == 2
391+ )
377392 dtype , out_dtype , weight_dtype , bias_dtype , res_dtype , res_out_dtype = [
378393 torch2cute_dtype_map [t .dtype ] if t is not None else None
379394 for t in [x , out , weight , bias , residual , residual_out ]
@@ -389,6 +404,7 @@ def _rmsnorm_fwd_fake(
389404 rstd is not None ,
390405 mean is not None ,
391406 is_layernorm ,
407+ per_head ,
392408 )
393409 _compile_rmsnorm_bwd (
394410 N ,
@@ -400,6 +416,7 @@ def _rmsnorm_fwd_fake(
400416 res_dtype ,
401417 res_out_dtype ,
402418 weight is not None ,
419+ per_head ,
403420 )
404421
405422
@@ -415,16 +432,23 @@ def _compile_rmsnorm_fwd(
415432 has_rstd ,
416433 has_mean ,
417434 is_layernorm ,
435+ per_head ,
418436):
419437 batch_sym = cute .sym_int ()
438+ head_sym = cute .sym_int () if per_head else None
439+ batch_shape = (batch_sym , head_sym ) if per_head else (batch_sym ,)
420440 all_dtypes = [dtype , out_dtype , res_dtype , weight_dtype , bias_dtype , res_out_dtype ]
421441 div = math .gcd (N , * (128 // dt .width for dt in all_dtypes if dt is not None ))
422442 x_cute , out_cute , res_cute , res_out_cute = [
423- fake_tensor (dt , (batch_sym , N ), div ) for dt in [dtype , out_dtype , res_dtype , res_out_dtype ]
443+ fake_tensor (dt , (* batch_shape , N ), div )
444+ for dt in [dtype , out_dtype , res_dtype , res_out_dtype ]
445+ ]
446+ weight_shape = (head_sym , N ) if per_head else (N ,)
447+ weight_cute , bias_cute = [
448+ fake_tensor (dt , weight_shape , div ) for dt in [weight_dtype , bias_dtype ]
424449 ]
425- weight_cute , bias_cute = [fake_tensor (dt , (N ,), div ) for dt in [weight_dtype , bias_dtype ]]
426- rstd_cute = fake_tensor (Float32 , (batch_sym ,)) if has_rstd else None
427- mean_cute = fake_tensor (Float32 , (batch_sym ,)) if has_mean else None
450+ rstd_cute = fake_tensor (Float32 , batch_shape ) if has_rstd else None
451+ mean_cute = fake_tensor (Float32 , batch_shape ) if has_mean else None
428452 return cute .compile (
429453 RMSNorm (dtype , N , is_layernorm = is_layernorm ),
430454 x_cute ,
@@ -456,7 +480,7 @@ def rmsnorm_fwd(
456480 # so that _layer_norm_fwd_impl doesn't have to return them.
457481 out_dtype = x .dtype if out_dtype is None else out_dtype
458482 out = torch .empty_like (x , dtype = out_dtype )
459- rstd = torch .empty (x .shape [0 ], device = x .device , dtype = torch .float32 ) if store_rstd else None
483+ rstd = torch .empty (* x .shape [: - 1 ], device = x .device , dtype = torch .float32 ) if store_rstd else None
460484 if residual is not None :
461485 residual_dtype = residual .dtype
462486 if residual is not None or (residual_dtype is not None and residual_dtype != x .dtype ):
@@ -476,7 +500,7 @@ def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
476500 x_f32 = x .float ()
477501 if residual is not None :
478502 residual_f32 = residual .float ()
479- x_f32 += residual_f32
503+ x_f32 = x_f32 + residual_f32
480504 x_norm = x_f32 / (torch .sqrt (torch .mean (x_f32 .square (), dim = - 1 , keepdim = True ) + eps ))
481505 out = x_norm * w if w is not None else x_norm
482506 if bias is not None :
@@ -565,10 +589,11 @@ def __call__(
565589 layout_utils .expand (mW , dim = 0 , size = tiler_mn [0 ]) if const_expr (mW is not None ) else None
566590 )
567591 num_blocks = sm_count
592+ num_heads = mX .shape [1 ] if const_expr (cute .rank (mX ) == 3 ) else 1
568593 self .kernel (
569594 mX , mW , mdO , mdResO , mRstd , mdX , mdW , mdB , mdRes , tiler_mn , tiled_copy , threads_per_row
570595 ).launch (
571- grid = [num_blocks , self .cluster_n , 1 ],
596+ grid = [num_blocks , self .cluster_n , num_heads ],
572597 block = [num_threads , 1 , 1 ],
573598 cluster = [1 , self .cluster_n , 1 ] if self .cluster_n > 1 else None ,
574599 stream = stream ,
@@ -591,11 +616,19 @@ def kernel(
591616 threads_per_row : cutlass .Constexpr [int ],
592617 ):
593618 tidx , _ , _ = cute .arch .thread_idx ()
594- bidx_start , _ , _ = cute .arch .block_idx ()
619+ bidx_start , _ , bidz = cute .arch .block_idx ()
595620 gdim , _ , _ = cute .arch .grid_dim ()
596621 cluster_y = const_expr (0 ) if const_expr (self .cluster_n == 1 ) else cute .arch .block_idx ()[1 ]
597622 tv_layout = tiled_copy .layout_tv_tiled
598623
624+ # Slice per head
625+ if const_expr (cute .rank (mX ) == 3 ):
626+ mX , mW , mdO , mdResO , mdX , mdW , mdB , mdRes = [
627+ mT [None , bidz , None ] if const_expr (mT is not None ) else None
628+ for mT in (mX , mW , mdO , mdResO , mdX , mdW , mdB , mdRes )
629+ ]
630+ mRstd = mRstd [None , bidz ]
631+
599632 shape = mX .shape
600633 M , N = shape [0 ], shape [1 ]
601634 is_even_N = const_expr (shape [1 ] == tiler_mn [1 ] * self .cluster_n )
@@ -895,22 +928,21 @@ def _rmsnorm_bwd(
895928) -> None :
896929 """RMSNorm backward pass.
897930 Args:
898- x: Input tensor of shape (M, N)
899- weight: Optional weight tensor of shape (N,)
900- dout: Upstream gradients tensor of shape (M, N)
901- rstd: Reciprocal standard deviation tensor of shape (M,)
931+ x: Input tensor of shape (M, N) or (M, H, N) for per-head
932+ weight: Optional weight tensor of shape (N,) or (H, N) for per-head
933+ dout: Upstream gradients tensor of shape (M, N) or (M, H, N)
934+ rstd: Reciprocal standard deviation tensor of shape (M,) or (M, H)
902935 Returns:
903936 Tuple of (dx, dw) where:
904937 - dx: Input gradients tensor of same shape as x
905938 - dw: Weight gradients tensor of same shape as weight (or None if weight is None)
906939 """
907- assert x .dim () == 2 , "Input must be 2D"
940+ assert x .dim () in ( 2 , 3 ), "Input must be 2D or 3D "
908941 assert x .is_cuda , "Input tensor must be on CUDA device"
909942 supported_types = {torch .float16 , torch .bfloat16 , torch .float32 }
910943 assert x .dtype in supported_types , "Unsupported dtype"
944+ per_head = x .dim () == 3
911945 if weight is not None :
912- assert weight .dim () == 1 , "Weight must be 1D"
913- assert x .shape [- 1 ] == weight .shape [0 ], "Last dimension of input must match weight dimension"
914946 assert weight .is_cuda , "Weight tensor must be on CUDA device"
915947 assert weight .dtype in supported_types , "Weight must be float32, float16 or bfloat16"
916948 if dresidual_out is not None :
@@ -924,7 +956,7 @@ def _rmsnorm_bwd(
924956 assert dresidual .is_cuda
925957 assert dresidual .dtype in supported_types , "Residual must be float16, bfloat16, or float32"
926958
927- N = x .size (1 )
959+ N = x .size (- 1 )
928960 if dw_partial is None and db_partial is None :
929961 assert sm_count is not None
930962 else :
@@ -943,6 +975,7 @@ def _rmsnorm_bwd(
943975 dres_dtype ,
944976 dres_out_dtype ,
945977 dw_partial is not None ,
978+ per_head ,
946979 )(x , weight , dout , dresidual_out , rstd , dx , dw_partial , dresidual , db_partial , sm_count )
947980
948981
@@ -962,8 +995,9 @@ def _rmsnorm_bwd_fake(
962995 # See softmax.py _softmax_fwd_fake for why register_fake is needed.
963996 from quack .cache_utils import COMPILE_ONLY
964997
965- if COMPILE_ONLY and not isinstance (x .size (1 ), torch .SymInt ):
966- N = x .size (1 )
998+ if COMPILE_ONLY and not isinstance (x .size (- 1 ), torch .SymInt ):
999+ N = x .size (- 1 )
1000+ per_head = x .dim () == 3
9671001 if dw_partial is None and db_partial is None and sm_count is None :
9681002 return
9691003 dtype , dout_dtype , dx_dtype , weight_dtype , dres_dtype , dres_out_dtype = [
@@ -980,6 +1014,7 @@ def _rmsnorm_bwd_fake(
9801014 dres_dtype ,
9811015 dres_out_dtype ,
9821016 dw_partial is not None ,
1017+ per_head ,
9831018 )
9841019
9851020
@@ -994,18 +1029,23 @@ def _compile_rmsnorm_bwd(
9941029 dres_dtype ,
9951030 dres_out_dtype ,
9961031 has_dw_partial ,
1032+ per_head = False ,
9971033):
9981034 batch_sym , batch_partial_sym = cute .sym_int (), cute .sym_int ()
1035+ head_sym = cute .sym_int () if per_head else None
1036+ batch_shape = (batch_sym , head_sym ) if per_head else (batch_sym ,)
9991037 all_dtypes = [dtype , dout_dtype , dx_dtype , dres_dtype , dres_out_dtype ]
10001038 div = math .gcd (N , * (128 // dt .width for dt in all_dtypes if dt is not None ))
10011039 x_cute , dout_cute , dx_cute , dres_out_cute , dres_cute = [
1002- fake_tensor (dt , (batch_sym , N ), div )
1040+ fake_tensor (dt , (* batch_shape , N ), div )
10031041 for dt in [dtype , dout_dtype , dx_dtype , dres_out_dtype , dres_dtype ]
10041042 ]
1005- weight_cute = fake_tensor (weight_dtype , (N ,), div )
1006- rstd_cute = fake_tensor (Float32 , (batch_sym ,))
1007- dw_partial_cute = fake_tensor (Float32 , (batch_partial_sym , N ), div ) if has_dw_partial else None
1008- db_partial_cute = fake_tensor (Float32 , (batch_partial_sym , N ), div ) if has_db_partial else None
1043+ weight_shape = (head_sym , N ) if per_head else (N ,)
1044+ weight_cute = fake_tensor (weight_dtype , weight_shape , div )
1045+ rstd_cute = fake_tensor (Float32 , batch_shape )
1046+ dw_shape = (batch_partial_sym , head_sym , N ) if per_head else (batch_partial_sym , N )
1047+ dw_partial_cute = fake_tensor (Float32 , dw_shape , div ) if has_dw_partial else None
1048+ db_partial_cute = fake_tensor (Float32 , dw_shape , div ) if has_db_partial else None
10091049 return cute .compile (
10101050 RMSNormBackward (dtype , N ),
10111051 x_cute ,
@@ -1033,19 +1073,27 @@ def rmsnorm_bwd(
10331073 has_residual : bool = False ,
10341074) -> Tuple [Tensor , Optional [Tensor ], Optional [Tensor ], Optional [Tensor ]]:
10351075 device = x .device
1036- N = x .size (1 )
1076+ N = x .size (- 1 )
1077+ per_head = x .dim () == 3
10371078 dx = torch .empty_like (x )
10381079 if dresidual_out is not None and dresidual_out .dtype != dx .dtype :
10391080 dresidual = torch .empty_like (x , dtype = dresidual_out .dtype )
10401081 else :
10411082 dresidual = None
10421083 sm_count = _get_sm_count (N , device )
1084+ if per_head :
1085+ H = x .size (1 )
1086+ sm_count = max (round (sm_count / H ), 1 )
1087+ else :
1088+ H = None
10431089 if weight is not None :
10441090 # Always store partial gradients in fp32 for numerical accuracy
1045- dw_partial = torch .empty (sm_count , N , device = device , dtype = torch .float32 )
1091+ dw_shape = (sm_count , H , N ) if per_head else (sm_count , N )
1092+ dw_partial = torch .empty (dw_shape , device = device , dtype = torch .float32 )
10461093 else :
10471094 dw_partial = None
1048- db_partial = torch .empty (sm_count , N , device = device , dtype = torch .float32 ) if has_bias else None
1095+ db_shape = (sm_count , H , N ) if per_head else (sm_count , N )
1096+ db_partial = torch .empty (db_shape , device = device , dtype = torch .float32 ) if has_bias else None
10491097
10501098 _rmsnorm_bwd (
10511099 x , weight , dout , rstd , dx , dw_partial , db_partial , dresidual_out , dresidual , sm_count
@@ -1074,10 +1122,14 @@ def forward(
10741122 prenorm = False ,
10751123 ):
10761124 x_shape_og = x .shape
1125+ per_head = (weight is not None and weight .dim () == 2 ) or (
1126+ bias is not None and bias .dim () == 2
1127+ )
1128+ last_shape = x_shape_og [- 1 :] if not per_head else x_shape_og [- 2 :]
10771129 # Flatten input, ensuring last dim is contiguous
1078- x = _ensure_contiguous (x .reshape (- 1 , x . shape [ - 1 ] ))
1130+ x = _ensure_contiguous (x .reshape (- 1 , * last_shape ))
10791131 if residual is not None :
1080- residual = _ensure_contiguous (residual .reshape (- 1 , residual . shape [ - 1 ] ))
1132+ residual = _ensure_contiguous (residual .reshape (- 1 , * last_shape ))
10811133 need_grad = any (ctx .needs_input_grad [:3 ])
10821134 out , residual_out , rstd = rmsnorm_fwd (
10831135 x ,
@@ -1091,6 +1143,7 @@ def forward(
10911143 )
10921144 ctx .save_for_backward (x if residual is None else residual_out , weight , rstd )
10931145 ctx .has_bias = bias is not None
1146+ ctx .per_head = per_head
10941147 ctx .eps = eps
10951148 ctx .x_shape_og = x_shape_og
10961149 ctx .residual_dtype = residual .dtype if residual is not None else None
@@ -1104,14 +1157,16 @@ def forward(
11041157 def backward (ctx , dout , * args ):
11051158 x , weight , rstd = ctx .saved_tensors
11061159 has_bias = ctx .has_bias
1160+ per_head = ctx .per_head
1161+ x_shape_og = ctx .x_shape_og
1162+ last_shape = x_shape_og [- 2 :] if per_head else x_shape_og [- 1 :]
11071163 if ctx .prenorm and ctx .residual_dtype is not None :
11081164 dresidual_out = args [0 ]
1109- dresidual_out = _ensure_contiguous (dresidual_out .reshape (- 1 , dresidual_out . shape [ - 1 ] ))
1165+ dresidual_out = _ensure_contiguous (dresidual_out .reshape (- 1 , * last_shape ))
11101166 else :
11111167 dresidual_out = None
1112- x_shape_og = ctx .x_shape_og
1113- # Reshape dout to match the flattened shape used in forward
1114- dout = _ensure_contiguous (dout .reshape (- 1 , dout .shape [- 1 ]))
1168+ # Reshape dout to match the shape used in forward
1169+ dout = _ensure_contiguous (dout .reshape (- 1 , * last_shape ))
11151170 dx , dw , db , dresidual = rmsnorm_bwd (
11161171 x ,
11171172 weight ,
0 commit comments