@@ -50,6 +50,7 @@ def parse_arguments() -> argparse.Namespace:
5050 parser .add_argument ("--varlen_k" , action = "store_true" , help = "Variable length K dimension" )
5151 parser .add_argument ("--gather_A" , action = "store_true" , help = "Gather A" )
5252 parser .add_argument ("--use_tma_gather" , action = "store_true" , help = "Use TMA gather4 for A" )
53+ parser .add_argument ("--max_swizzle_size" , type = int , default = 8 , help = "Max swizzle size" )
5354 parser .add_argument ("--skip_ref_check" , action = "store_true" , help = "Skip reference checking" )
5455
5556 args = parser .parse_args ()
@@ -109,7 +110,12 @@ def run(args):
109110 total_k = k * l
110111 cu_seqlens_k = torch .arange (0 , l + 1 , dtype = torch .int32 , device = device ) * k
111112 # m-major A, n-major B for varlen_k
112- A = torch .randn (total_k , m , dtype = torch .bfloat16 , device = device ).T
113+ if gather_A :
114+ larger_k = total_k * 2
115+ A = torch .randn (larger_k , m , dtype = torch .bfloat16 , device = device ).T
116+ A_idx = torch .randperm (larger_k , dtype = torch .int32 , device = device )[:total_k ]
117+ else :
118+ A = torch .randn (total_k , m , dtype = torch .bfloat16 , device = device ).T
113119 B = torch .randn (total_k , n , dtype = torch .bfloat16 , device = device ).T
114120 D = torch .empty (l , m , n , dtype = torch .bfloat16 , device = device )
115121 else :
@@ -127,6 +133,7 @@ def fn():
127133 pingpong = args .pingpong ,
128134 persistent = persistent ,
129135 is_dynamic_persistent = args .dynamic_persistent ,
136+ max_swizzle_size = args .max_swizzle_size ,
130137 cu_seqlens_m = cu_seqlens_m ,
131138 cu_seqlens_k = cu_seqlens_k ,
132139 A_idx = A_idx ,
@@ -146,7 +153,8 @@ def fn():
146153 ])
147154 elif varlen_k :
148155 ref = torch .stack ([
149- A [:, cu_seqlens_k [i ]:cu_seqlens_k [i + 1 ]] @
156+ (A [:, A_idx [cu_seqlens_k [i ]:cu_seqlens_k [i + 1 ]]] if gather_A
157+ else A [:, cu_seqlens_k [i ]:cu_seqlens_k [i + 1 ]]) @
150158 B [:, cu_seqlens_k [i ]:cu_seqlens_k [i + 1 ]].T
151159 for i in range (l )
152160 ])
0 commit comments