Skip to content

Commit bbdc7ba

Browse files
WANDY666 and wangzaijun authored
Merge q,kv (#1199)
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
1 parent d3397d7 commit bbdc7ba

File tree

10 files changed

+380
-4
lines changed

10 files changed

+380
-4
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
ROWMMWeight,
55
KVROWNMMWeight,
66
ROWBMMWeight,
7+
QKVROWNMMWeight,
78
COLMMWeight,
89
)
910
from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .mm_weight import (
22
MMWeightTpl,
33
)
4-
from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight
4+
from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight, QKVROWNMMWeight
55
from .colmm_weight import COLMMWeight

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,50 @@ def _get_tp_padded_head_num(self, head_num: int):
9292
)
9393

9494

95+
class QKVROWNMMWeight(MMWeightTpl):
    """Row-sliced matmul weight that holds the Q, K and V projections together.

    The per-rank output width of each projection is derived from the head
    counts: ``(head_num // tp_world_size) * head_dim``. The three widths are
    handed to ``MMWeightTpl`` as ``out_dims = [q, kv, kv]``.

    Args:
        in_dim: Input dimension forwarded unchanged to ``MMWeightTpl``.
        q_head_num: Total number of query heads; must be divisible by the
            tensor-parallel world size.
        kv_head_num: Total number of key/value heads; must be divisible by
            the tensor-parallel world size.
        head_dim: Size of a single attention head.
        weight_names: Checkpoint weight name(s), forwarded to ``MMWeightTpl``.
        data_type: Torch dtype, forwarded to ``MMWeightTpl``.
        bias_names: Optional checkpoint bias name(s), forwarded to
            ``MMWeightTpl``.
        quant_method: Quantization method, forwarded to ``MMWeightTpl``.
        tp_rank: Rank to slice for; defaults to ``get_current_rank_in_dp()``.
        tp_world_size: Number of ranks to slice across; defaults to
            ``get_dp_world_size()``.

    Raises:
        AssertionError: If ``q_head_num`` or ``kv_head_num`` is not divisible
            by the resolved tensor-parallel world size.
    """

    def __init__(
        self,
        in_dim: int,
        q_head_num: int,
        kv_head_num: int,
        head_dim: int,
        weight_names: Union[str, List[str]],
        data_type: torch.dtype,
        bias_names: Optional[Union[str, List[str]]] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        # Resolve the parallel layout first: the divisibility checks and the
        # per-rank sizes below both depend on it.
        self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
        self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
        # No head replication here: both head counts are required to divide
        # evenly across ranks (see the assertions below).
        self.repeat_times = 1
        assert q_head_num % self.tp_world_size_ == 0, (
            "q_head_num must be divisible by tp_world_size_, "
            f"but found: {q_head_num} % {self.tp_world_size_}"
        )
        # The ", " separator was missing from this message, which glued
        # "tp_world_size_" and "but found" together.
        assert kv_head_num % self.tp_world_size_ == 0, (
            "kv_head_num must be divisible by tp_world_size_, "
            f"but found: {kv_head_num} % {self.tp_world_size_}"
        )
        q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim
        kv_hidden_size = (kv_head_num // self.tp_world_size_) * head_dim
        # Order is Q, K, V; K and V share the same per-rank width.
        out_dims = [q_hidden_size, kv_hidden_size, kv_hidden_size]
        super().__init__(
            in_dim=in_dim,
            out_dims=out_dims,
            weight_names=weight_names,
            data_type=data_type,
            bias_names=bias_names,
            quant_method=quant_method,
            tp_rank=self.tp_rank_,
            tp_world_size=self.tp_world_size_,
        )
        # NOTE(review): presumably MMWeightTpl.__init__ sets self.quant_method
        # (including when quant_method is None) -- confirm against the base class.
        self.param_slicer = get_row_slice_mixin(
            self.quant_method.method_name,
            tp_rank=self.tp_rank_,
            tp_world_size=self.tp_world_size_,
            repeat_times=self.repeat_times,
        )
137+
138+
95139
class ROWBMMWeight(BMMWeightTpl):
96140
def __init__(
97141
self,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
{
2+
"1024": {
3+
"BLOCK_SIZE_K": 64,
4+
"BLOCK_SIZE_M": 16,
5+
"BLOCK_SIZE_N": 128,
6+
"GROUP_SIZE_M": 64,
7+
"NEED_TRANS": false,
8+
"num_stages": 2,
9+
"num_warps": 4
10+
},
11+
"128": {
12+
"BLOCK_SIZE_K": 64,
13+
"BLOCK_SIZE_M": 16,
14+
"BLOCK_SIZE_N": 128,
15+
"GROUP_SIZE_M": 16,
16+
"NEED_TRANS": false,
17+
"num_stages": 3,
18+
"num_warps": 4
19+
},
20+
"2048": {
21+
"BLOCK_SIZE_K": 32,
22+
"BLOCK_SIZE_M": 32,
23+
"BLOCK_SIZE_N": 128,
24+
"GROUP_SIZE_M": 16,
25+
"NEED_TRANS": false,
26+
"num_stages": 3,
27+
"num_warps": 4
28+
},
29+
"256": {
30+
"BLOCK_SIZE_K": 64,
31+
"BLOCK_SIZE_M": 16,
32+
"BLOCK_SIZE_N": 128,
33+
"GROUP_SIZE_M": 1,
34+
"NEED_TRANS": false,
35+
"num_stages": 2,
36+
"num_warps": 4
37+
},
38+
"512": {
39+
"BLOCK_SIZE_K": 64,
40+
"BLOCK_SIZE_M": 16,
41+
"BLOCK_SIZE_N": 128,
42+
"GROUP_SIZE_M": 1,
43+
"NEED_TRANS": false,
44+
"num_stages": 4,
45+
"num_warps": 4
46+
},
47+
"64": {
48+
"BLOCK_SIZE_K": 64,
49+
"BLOCK_SIZE_M": 16,
50+
"BLOCK_SIZE_N": 128,
51+
"GROUP_SIZE_M": 1,
52+
"NEED_TRANS": false,
53+
"num_stages": 2,
54+
"num_warps": 4
55+
},
56+
"8": {
57+
"BLOCK_SIZE_K": 32,
58+
"BLOCK_SIZE_M": 16,
59+
"BLOCK_SIZE_N": 64,
60+
"GROUP_SIZE_M": 1,
61+
"NEED_TRANS": false,
62+
"num_stages": 2,
63+
"num_warps": 4
64+
},
65+
"800": {
66+
"BLOCK_SIZE_K": 64,
67+
"BLOCK_SIZE_M": 16,
68+
"BLOCK_SIZE_N": 128,
69+
"GROUP_SIZE_M": 32,
70+
"NEED_TRANS": false,
71+
"num_stages": 2,
72+
"num_warps": 4
73+
},
74+
"8192": {
75+
"BLOCK_SIZE_K": 64,
76+
"BLOCK_SIZE_M": 64,
77+
"BLOCK_SIZE_N": 128,
78+
"GROUP_SIZE_M": 32,
79+
"NEED_TRANS": false,
80+
"num_stages": 2,
81+
"num_warps": 4
82+
}
83+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_K": 128,
4+
"BLOCK_SIZE_M": 16,
5+
"BLOCK_SIZE_N": 64,
6+
"GROUP_SIZE_M": 1,
7+
"NEED_TRANS": false,
8+
"num_stages": 4,
9+
"num_warps": 4
10+
},
11+
"100": {
12+
"BLOCK_SIZE_K": 128,
13+
"BLOCK_SIZE_M": 16,
14+
"BLOCK_SIZE_N": 128,
15+
"GROUP_SIZE_M": 1,
16+
"NEED_TRANS": false,
17+
"num_stages": 3,
18+
"num_warps": 4
19+
},
20+
"1024": {
21+
"BLOCK_SIZE_K": 32,
22+
"BLOCK_SIZE_M": 64,
23+
"BLOCK_SIZE_N": 128,
24+
"GROUP_SIZE_M": 64,
25+
"NEED_TRANS": false,
26+
"num_stages": 3,
27+
"num_warps": 4
28+
},
29+
"128": {
30+
"BLOCK_SIZE_K": 128,
31+
"BLOCK_SIZE_M": 16,
32+
"BLOCK_SIZE_N": 128,
33+
"GROUP_SIZE_M": 32,
34+
"NEED_TRANS": false,
35+
"num_stages": 2,
36+
"num_warps": 8
37+
},
38+
"16": {
39+
"BLOCK_SIZE_K": 64,
40+
"BLOCK_SIZE_M": 16,
41+
"BLOCK_SIZE_N": 128,
42+
"GROUP_SIZE_M": 1,
43+
"NEED_TRANS": false,
44+
"num_stages": 3,
45+
"num_warps": 4
46+
},
47+
"256": {
48+
"BLOCK_SIZE_K": 128,
49+
"BLOCK_SIZE_M": 32,
50+
"BLOCK_SIZE_N": 128,
51+
"GROUP_SIZE_M": 16,
52+
"NEED_TRANS": false,
53+
"num_stages": 2,
54+
"num_warps": 4
55+
},
56+
"32": {
57+
"BLOCK_SIZE_K": 128,
58+
"BLOCK_SIZE_M": 16,
59+
"BLOCK_SIZE_N": 64,
60+
"GROUP_SIZE_M": 16,
61+
"NEED_TRANS": false,
62+
"num_stages": 3,
63+
"num_warps": 4
64+
},
65+
"64": {
66+
"BLOCK_SIZE_K": 128,
67+
"BLOCK_SIZE_M": 16,
68+
"BLOCK_SIZE_N": 128,
69+
"GROUP_SIZE_M": 32,
70+
"NEED_TRANS": false,
71+
"num_stages": 2,
72+
"num_warps": 4
73+
},
74+
"8": {
75+
"BLOCK_SIZE_K": 128,
76+
"BLOCK_SIZE_M": 16,
77+
"BLOCK_SIZE_N": 128,
78+
"GROUP_SIZE_M": 32,
79+
"NEED_TRANS": false,
80+
"num_stages": 3,
81+
"num_warps": 8
82+
}
83+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE": 256,
4+
"num_warps": 4
5+
},
6+
"100": {
7+
"BLOCK_SIZE": 128,
8+
"num_warps": 8
9+
},
10+
"1024": {
11+
"BLOCK_SIZE": 256,
12+
"num_warps": 4
13+
},
14+
"128": {
15+
"BLOCK_SIZE": 256,
16+
"num_warps": 8
17+
},
18+
"16": {
19+
"BLOCK_SIZE": 128,
20+
"num_warps": 8
21+
},
22+
"256": {
23+
"BLOCK_SIZE": 128,
24+
"num_warps": 8
25+
},
26+
"32": {
27+
"BLOCK_SIZE": 128,
28+
"num_warps": 8
29+
},
30+
"64": {
31+
"BLOCK_SIZE": 128,
32+
"num_warps": 8
33+
},
34+
"8": {
35+
"BLOCK_SIZE": 128,
36+
"num_warps": 8
37+
}
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
{
2+
"1": {
3+
"BLOCK_DIM": 256,
4+
"BLOCK_M": 2,
5+
"NUM_STAGE": 2,
6+
"num_warps": 8
7+
},
8+
"100": {
9+
"BLOCK_DIM": 1024,
10+
"BLOCK_M": 1,
11+
"NUM_STAGE": 1,
12+
"num_warps": 8
13+
},
14+
"1024": {
15+
"BLOCK_DIM": 1024,
16+
"BLOCK_M": 1,
17+
"NUM_STAGE": 4,
18+
"num_warps": 1
19+
},
20+
"128": {
21+
"BLOCK_DIM": 1024,
22+
"BLOCK_M": 1,
23+
"NUM_STAGE": 1,
24+
"num_warps": 16
25+
},
26+
"16": {
27+
"BLOCK_DIM": 128,
28+
"BLOCK_M": 1,
29+
"NUM_STAGE": 1,
30+
"num_warps": 2
31+
},
32+
"256": {
33+
"BLOCK_DIM": 1024,
34+
"BLOCK_M": 1,
35+
"NUM_STAGE": 4,
36+
"num_warps": 2
37+
},
38+
"32": {
39+
"BLOCK_DIM": 128,
40+
"BLOCK_M": 1,
41+
"NUM_STAGE": 4,
42+
"num_warps": 4
43+
},
44+
"64": {
45+
"BLOCK_DIM": 128,
46+
"BLOCK_M": 1,
47+
"NUM_STAGE": 4,
48+
"num_warps": 4
49+
},
50+
"8": {
51+
"BLOCK_DIM": 1024,
52+
"BLOCK_M": 1,
53+
"NUM_STAGE": 1,
54+
"num_warps": 16
55+
}
56+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
{
2+
"1024": {
3+
"BLOCK_M": 1,
4+
"BLOCK_N": 256,
5+
"NUM_STAGES": 2,
6+
"num_warps": 4
7+
},
8+
"128": {
9+
"BLOCK_M": 1,
10+
"BLOCK_N": 256,
11+
"NUM_STAGES": 1,
12+
"num_warps": 8
13+
},
14+
"2048": {
15+
"BLOCK_M": 1,
16+
"BLOCK_N": 256,
17+
"NUM_STAGES": 1,
18+
"num_warps": 1
19+
},
20+
"256": {
21+
"BLOCK_M": 1,
22+
"BLOCK_N": 256,
23+
"NUM_STAGES": 1,
24+
"num_warps": 8
25+
},
26+
"512": {
27+
"BLOCK_M": 1,
28+
"BLOCK_N": 128,
29+
"NUM_STAGES": 2,
30+
"num_warps": 4
31+
},
32+
"64": {
33+
"BLOCK_M": 1,
34+
"BLOCK_N": 64,
35+
"NUM_STAGES": 4,
36+
"num_warps": 1
37+
},
38+
"8": {
39+
"BLOCK_M": 1,
40+
"BLOCK_N": 64,
41+
"NUM_STAGES": 4,
42+
"num_warps": 1
43+
},
44+
"800": {
45+
"BLOCK_M": 1,
46+
"BLOCK_N": 256,
47+
"NUM_STAGES": 2,
48+
"num_warps": 1
49+
},
50+
"8192": {
51+
"BLOCK_M": 8,
52+
"BLOCK_N": 256,
53+
"NUM_STAGES": 4,
54+
"num_warps": 1
55+
}
56+
}

0 commit comments

Comments
 (0)