@@ -120,13 +120,13 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p
120120 T .writes (var_NT_matmul_intermediate_rf_local [vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused , 0 , v0 , 0 , v1 ])
121121 var_NT_matmul_intermediate_rf_local [vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused , 0 , v0 , 0 , v1 ] = T .float16 (0 )
122122 for ax2_fused_u_fused_0 in T .serial (1 , annotations = {"pragma_auto_unroll_max_step" : 256 , "pragma_unroll_explicit" : 1 }):
123- for ax0 , ax1 , ax2_0 , ax3 in T .grid (1 , 1 , 1 , 2 ):
124- for ax2_1 in T .vectorized (1 ):
123+ for ax0 , ax1 , ax2_ax3_fused_0 in T .grid (1 , 1 , 1 ):
124+ for ax2_ax3_fused_1 in T .vectorized (2 ):
125125 with T .block ("lv1638_local" ):
126126 v0 = T .axis .spatial (1 , ax0 )
127127 v1 = T .axis .spatial (32 , ax0_fused_ax1_fused_fused_0 // n + ax1 )
128- v2 = T .axis .spatial (n , ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1 )
129- v3 = T .axis .spatial (128 , ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3 )
128+ v2 = T .axis .spatial (n , ax0_fused_ax1_fused_fused_0 % n )
129+ v3 = T .axis .spatial (128 , ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_ax3_fused_0 * 2 + ax2_ax3_fused_1 )
130130 T .reads (lv1638 [v0 , v1 , v2 , v3 ])
131131 T .writes (lv1638_local [v0 , v1 , v2 , v3 ])
132132 lv1638_local [v0 , v1 , v2 , v3 ] = lv1638 [v0 , v1 , v2 , v3 ]
@@ -224,11 +224,11 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
224224 T .writes (var_NT_matmul_intermediate_rf_local [vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused , 0 , 0 , v0 ])
225225 var_NT_matmul_intermediate_rf_local [vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused , 0 , 0 , v0 ] = T .float16 (0 )
226226 for ax1_0_fused_ax1_1_fused_0 in T .serial (32 , annotations = {"pragma_auto_unroll_max_step" : 256 , "pragma_unroll_explicit" : 1 }):
227- for ax0_0 , ax1 in T .grid ( 1 , 1 ):
227+ for ax0_ax1_fused in T .serial ( 1 ):
228228 for ax0_1 in T .vectorized (1 ):
229229 with T .block ("lv571_local" ):
230- v0 = T .axis .spatial (22016 , u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1 )
231- v1 = T .axis .spatial (512 , ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1 )
230+ v0 = T .axis .spatial (22016 , u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 )
231+ v1 = T .axis .spatial (512 , ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 )
232232 T .reads (lv571 [v0 , v1 ])
233233 T .writes (lv571_local [v0 , v1 ])
234234 lv571_local [v0 , v1 ] = lv571 [v0 , v1 ]
@@ -332,11 +332,11 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
332332 T .writes (var_NT_matmul_intermediate_rf_local [vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused , 0 , 0 , v0 ])
333333 var_NT_matmul_intermediate_rf_local [vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused , 0 , 0 , v0 ] = T .float16 (0 )
334334 for ax1_0_fused_ax1_1_fused_0 in T .serial (8 , annotations = {"pragma_auto_unroll_max_step" : 256 , "pragma_unroll_explicit" : 1 }):
335- for ax0_0 , ax1 in T . grid ( 1 , 1 ):
336- for ax0_1 in T .vectorized (1 ):
335+ for ax0_ax1_fused_0 in range ( 1 ):
336+ for ax0_ax1_fused_1 in T .vectorized (1 ):
337337 with T .block ("lv571_local" ):
338- v0 = T .axis .spatial (22016 , u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1 )
339- v1 = T .axis .spatial (512 , ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1 )
338+ v0 = T .axis .spatial (22016 , u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 )
339+ v1 = T .axis .spatial (512 , ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 )
340340 T .reads (lv571 [v0 , v1 ])
341341 T .writes (lv571_local [v0 , v1 ])
342342 lv571_local [v0 , v1 ] = lv571 [v0 , v1 ]
@@ -448,11 +448,11 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12
448448 T .writes (var_NT_matmul_intermediate_rf_local [vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused , 0 , 0 , v0 ])
449449 var_NT_matmul_intermediate_rf_local [vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused , 0 , 0 , v0 ] = T .float16 (0 )
450450 for ax1_0_fused_ax1_1_fused_0 in T .serial (8 , annotations = {"pragma_auto_unroll_max_step" : 256 , "pragma_unroll_explicit" : 1 }):
451- for ax0_0 , ax1 in T . grid ( 1 , 1 ):
452- for ax0_1 in T .vectorized (1 ):
451+ for ax0_ax1_fused_0 in range ( 1 ):
452+ for ax0_ax1_fused_1 in T .vectorized (1 ):
453453 with T .block ("lv771_local" ):
454- v0 = T .axis .spatial (32000 , u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1 )
455- v1 = T .axis .spatial (512 , ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1 )
454+ v0 = T .axis .spatial (32000 , u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 )
455+ v1 = T .axis .spatial (512 , ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 )
456456 T .reads (lv771 [v0 , v1 ])
457457 T .writes (lv771_local [v0 , v1 ])
458458 lv771_local [v0 , v1 ] = lv771 [v0 , v1 ]
@@ -572,11 +572,11 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T
572572 T .writes (var_NT_matmul_intermediate_rf_local [vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused , T .int64 (0 ), T .int64 (0 ), v0 ])
573573 var_NT_matmul_intermediate_rf_local [vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused , T .int64 (0 ), T .int64 (0 ), v0 ] = T .float16 (0 )
574574 for ax1_0_fused_ax1_1_fused_0 in T .serial (T .int64 (43 ), annotations = {"pragma_auto_unroll_max_step" : 256 , "pragma_unroll_explicit" : 1 }):
575- for ax0_0 , ax1 in T . grid ( T . int64 ( 1 ), T .int64 (1 )):
576- for ax0_1 in T .vectorized (T .int64 (1 )):
575+ for ax0_ax1_fused_0 in range ( T .int64 (1 )):
576+ for ax0_ax1_fused_1 in T .vectorized (T .int64 (1 )):
577577 with T .block ("lv575_local" ):
578- v0 = T .axis .spatial (T .int64 (4096 ), u_fused_ax0_fused_fused_0 * T .int64 (16 ) + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1 )
579- v1 = T .axis .spatial (T .int64 (1376 ), ax1_0_fused_ax1_1_fused_0 * T .int64 (32 ) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1 )
578+ v0 = T .axis .spatial (T .int64 (4096 ), u_fused_ax0_fused_fused_0 * T .int64 (16 ) + u_fused_ax0_fused_fused_1 )
579+ v1 = T .axis .spatial (T .int64 (1376 ), ax1_0_fused_ax1_1_fused_0 * T .int64 (32 ) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 )
580580 T .reads (lv575 [v0 , v1 ])
581581 T .writes (lv575_local [v0 , v1 ])
582582 lv575_local [v0 , v1 ] = lv575 [v0 , v1 ]
@@ -942,15 +942,16 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f
942942 T .writes (o_rf_local [vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused , v_expert_id_o , v0 ])
943943 o_rf_local [vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused , v_expert_id_o , v0 ] = T .float16 (0 )
944944 for ax1_fused_u_fused_0 in T .serial (32 , annotations = {"pragma_auto_unroll_max_step" : 256 , "pragma_unroll_explicit" : 1 }):
945- for ax0 , ax1_0 , ax2 in T .grid (1 , 1 , 8 ):
946- for ax1_1 in T .vectorized (1 ):
947- with T .block ("w_local" ):
948- v0 = T .axis .spatial (1 , ax0 )
949- v1 = T .axis .spatial (16384 , u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax1_0 + ax1_1 )
950- v2 = T .axis .spatial (4096 , ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 + ax2 )
951- T .reads (w [indptr [v_expert_id_o ] + v0 , v1 , v2 ])
952- T .writes (w_local [v0 , v1 , v2 ])
953- w_local [v0 , v1 , v2 ] = w [indptr [v_expert_id_o ] + v0 , v1 , v2 ]
945+ for ax0 in range (1 ):
946+ for ax1_ax2_fused_0 in range (8 ):
947+ for ax1_ax2_fused_1 in T .vectorized (1 ):
948+ with T .block ("w_local" ):
949+ v0 = T .axis .spatial (1 , ax0 )
950+ v1 = T .axis .spatial (16384 , u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 )
951+ v2 = T .axis .spatial (4096 , ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 + ax1_ax2_fused_0 + ax1_ax2_fused_1 )
952+ T .reads (w [indptr [v_expert_id_o ] + v0 , v1 , v2 ])
953+ T .writes (w_local [v0 , v1 , v2 ])
954+ w_local [v0 , v1 , v2 ] = w [indptr [v_expert_id_o ] + v0 , v1 , v2 ]
954955 for u_fused_ax0_fused_fused_2 , ax1_fused_u_fused_2 in T .grid (1 , 8 ):
955956 for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T .vectorized (1 ):
956957 with T .block ("gemv_rf_update" ):
0 commit comments