Skip to content

Commit 5c80691

Browse files
authored
[Dlight] Enhance vectorization loading weight for gemv (#16878)
* [Dlight] Enhance vectorization loading weight for gemv
* Update gemv.py
1 parent 0a3fe22 commit 5c80691

2 files changed

Lines changed: 38 additions & 37 deletions

File tree

python/tvm/dlight/gpu/gemv.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""A rule for GEMV and DecodeGEMV."""
18-
import re
1918
from functools import reduce
2019
from typing import List, Optional, Union
2120

@@ -56,10 +55,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
5655

5756

5857
def get_bytes(dtype: Union[DataType, str]) -> int:
59-
num = re.findall(r"\d+", dtype)
60-
if len(num) != 1:
61-
raise ValueError(f"Cannot get bytes from {dtype}")
62-
return int(num[0]) // 8
58+
if isinstance(dtype, str):
59+
dtype = DataType(dtype)
60+
return dtype.bits * dtype.lanes // 8
6361

6462

6563
def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
@@ -297,10 +295,11 @@ def apply(
297295
Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local")
298296
sch.compute_at(Aq_local, r, preserve_unit_loops=True)
299297
s_local, r_local = sch.get_loops(block=Aq_local)[-2:]
300-
s_local, vec_load = sch.split(
301-
s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
298+
fused_load = sch.fuse(s_local, r_local)
299+
aq_vec_len = max(1, VEC_LOAD // get_bytes(sch.get(Aq_local).reads[0].buffer.dtype))
300+
fused_load, vec_load = sch.split(
301+
fused_load, factors=[None, aq_vec_len], preserve_unit_iters=True
302302
)
303-
sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1
304303
sch.vectorize(vec_load)
305304

306305
# load vector into shared memory, shape should be the whole vector
@@ -442,10 +441,12 @@ def apply(
442441

443442
TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
444443
SUPPORT_WARP_SHUFFLE = False
444+
VEC_LOAD = 1
445445
if target.kind.name == "cuda":
446446
VEC_C = 4
447447
LOAD_V_SHARED = True
448448
LOAD_V_VEC = 8
449+
VEC_LOAD = 4
449450
UNROLL = 256
450451
SUPPORT_WARP_SHUFFLE = True
451452
if isinstance(len_S, int):
@@ -522,7 +523,6 @@ def apply(
522523
else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1),
523524
)
524525
VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)
525-
VEC_LOAD = 1
526526

527527
return apply(
528528
sch,

tests/python/dlight/test_gpu_gemv.py

Lines changed: 29 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -120,13 +120,13 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p
120120
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1])
121121
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = T.float16(0)
122122
for ax2_fused_u_fused_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
123-
for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2):
124-
for ax2_1 in T.vectorized(1):
123+
for ax0, ax1, ax2_ax3_fused_0 in T.grid(1, 1, 1):
124+
for ax2_ax3_fused_1 in T.vectorized(2):
125125
with T.block("lv1638_local"):
126126
v0 = T.axis.spatial(1, ax0)
127127
v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1)
128-
v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
129-
v3 = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3)
128+
v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
129+
v3 = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_ax3_fused_0 * 2 + ax2_ax3_fused_1)
130130
T.reads(lv1638[v0, v1, v2, v3])
131131
T.writes(lv1638_local[v0, v1, v2, v3])
132132
lv1638_local[v0, v1, v2, v3] = lv1638[v0, v1, v2, v3]
@@ -224,11 +224,11 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
224224
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
225225
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
226226
for ax1_0_fused_ax1_1_fused_0 in T.serial(32, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
227-
for ax0_0, ax1 in T.grid(1, 1):
227+
for ax0_ax1_fused in T.serial(1):
228228
for ax0_1 in T.vectorized(1):
229229
with T.block("lv571_local"):
230-
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
231-
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
230+
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
231+
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
232232
T.reads(lv571[v0, v1])
233233
T.writes(lv571_local[v0, v1])
234234
lv571_local[v0, v1] = lv571[v0, v1]
@@ -332,11 +332,11 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
332332
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
333333
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
334334
for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
335-
for ax0_0, ax1 in T.grid(1, 1):
336-
for ax0_1 in T.vectorized(1):
335+
for ax0_ax1_fused_0 in range(1):
336+
for ax0_ax1_fused_1 in T.vectorized(1):
337337
with T.block("lv571_local"):
338-
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
339-
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
338+
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
339+
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
340340
T.reads(lv571[v0, v1])
341341
T.writes(lv571_local[v0, v1])
342342
lv571_local[v0, v1] = lv571[v0, v1]
@@ -448,11 +448,11 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12
448448
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
449449
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
450450
for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
451-
for ax0_0, ax1 in T.grid(1, 1):
452-
for ax0_1 in T.vectorized(1):
451+
for ax0_ax1_fused_0 in range(1):
452+
for ax0_ax1_fused_1 in T.vectorized(1):
453453
with T.block("lv771_local"):
454-
v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
455-
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
454+
v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
455+
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
456456
T.reads(lv771[v0, v1])
457457
T.writes(lv771_local[v0, v1])
458458
lv771_local[v0, v1] = lv771[v0, v1]
@@ -572,11 +572,11 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T
572572
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
573573
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
574574
for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
575-
for ax0_0, ax1 in T.grid(T.int64(1), T.int64(1)):
576-
for ax0_1 in T.vectorized(T.int64(1)):
575+
for ax0_ax1_fused_0 in range(T.int64(1)):
576+
for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
577577
with T.block("lv575_local"):
578-
v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
579-
v1 = T.axis.spatial(T.int64(1376), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
578+
v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1)
579+
v1 = T.axis.spatial(T.int64(1376), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
580580
T.reads(lv575[v0, v1])
581581
T.writes(lv575_local[v0, v1])
582582
lv575_local[v0, v1] = lv575[v0, v1]
@@ -942,15 +942,16 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f
942942
T.writes(o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0])
943943
o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0] = T.float16(0)
944944
for ax1_fused_u_fused_0 in T.serial(32, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
945-
for ax0, ax1_0, ax2 in T.grid(1, 1, 8):
946-
for ax1_1 in T.vectorized(1):
947-
with T.block("w_local"):
948-
v0 = T.axis.spatial(1, ax0)
949-
v1 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax1_0 + ax1_1)
950-
v2 = T.axis.spatial(4096, ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 + ax2)
951-
T.reads(w[indptr[v_expert_id_o] + v0, v1, v2])
952-
T.writes(w_local[v0, v1, v2])
953-
w_local[v0, v1, v2] = w[indptr[v_expert_id_o] + v0, v1, v2]
945+
for ax0 in range(1):
946+
for ax1_ax2_fused_0 in range(8):
947+
for ax1_ax2_fused_1 in T.vectorized(1):
948+
with T.block("w_local"):
949+
v0 = T.axis.spatial(1, ax0)
950+
v1 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
951+
v2 = T.axis.spatial(4096, ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 + ax1_ax2_fused_0 + ax1_ax2_fused_1)
952+
T.reads(w[indptr[v_expert_id_o] + v0, v1, v2])
953+
T.writes(w_local[v0, v1, v2])
954+
w_local[v0, v1, v2] = w[indptr[v_expert_id_o] + v0, v1, v2]
954955
for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(1, 8):
955956
for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(1):
956957
with T.block("gemv_rf_update"):

0 commit comments

Comments
 (0)