
Commit b048de9

Author: Developer
Message: clean code
1 parent a00abc2 · commit b048de9

22 files changed: +405 −105 lines


lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 2 deletions
@@ -11,7 +11,6 @@
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
-from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter
 from lightllm.common.kv_cache_mem_manager import MemoryManager
 from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
 from lightllm.common.req_manager import ReqManager
@@ -165,7 +164,6 @@ def _init_quant(self):
         logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")
 
     def _init_weights(self, start_layer_index=0):
-        reset_moe_layer_counter()
         self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config)
         self.trans_layers_weight = [
             self.transformer_weight_class(

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,6 @@
 from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank
 from lightllm.utils.log_utils import init_logger
-from lightllm.common.basemodel.routing_manager import get_next_moe_layer_index
 
 logger = init_logger(__name__)
 
@@ -34,9 +33,9 @@ def __init__(
         num_fused_shared_experts: int = 0,
         layer_num: int = 0,
         network_config: Dict[str, Any] = None,
+        moe_layer_index: int = 0,
     ) -> None:
         super().__init__(data_type=data_type)
-        self.moe_layer_index = get_next_moe_layer_index()
         self.w1_weight_name = gate_proj_name
         self.w2_weight_name = down_proj_name
         self.w3_weight_name = up_proj_name
@@ -52,6 +51,7 @@ def __init__(
         self.enable_ep_moe = get_env_start_args().enable_ep_moe
         self.n_routed_experts = n_routed_experts
         self.num_fused_shared_experts = num_fused_shared_experts
+        self.moe_layer_index = moe_layer_index
         self._init_config(network_config)
         self._init_redundancy_expert_params()
         self._init_parallel_params()
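This refactor replaces the process-global get_next_moe_layer_index() counter with an index passed in by the caller, so weight construction no longer depends on instantiation order. A minimal sketch of how a caller could derive the index while walking the layers (hypothetical illustration; first_k_dense_replace is an assumed config key, not part of this diff):

from typing import Any, Dict, List, Tuple

def assign_moe_layer_indexes(num_layers: int, network_config: Dict[str, Any]) -> List[Tuple[int, int]]:
    """Return (layer_num, moe_layer_index) pairs; -1 marks dense layers."""
    pairs = []
    moe_layer_index = 0
    for layer_num in range(num_layers):
        # DeepSeek-style layout assumed: first k layers dense, the rest MoE
        if layer_num >= network_config.get("first_k_dense_replace", 0):
            pairs.append((layer_num, moe_layer_index))
            moe_layer_index += 1  # counts MoE layers only, matching per-MoE-layer capture buffers
        else:
            pairs.append((layer_num, -1))
    return pairs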

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,7 @@ def __init__(
         num_fused_shared_experts: int = 0,
         layer_num: int = 0,
         network_config: Dict[str, Any] = None,
+        moe_layer_index: int = 0,
     ) -> None:
         network_config["norm_topk_prob"] = None
         super().__init__(
@@ -63,6 +64,7 @@ def __init__(
             num_fused_shared_experts=num_fused_shared_experts,
             layer_num=layer_num,
             network_config=network_config,
+            moe_layer_index=moe_layer_index,
         )
 
         self.hidden_size = network_config["hidden_size"]
@@ -150,6 +152,7 @@ def experts(
         topk_weights, topk_ids = self._router(router_logits, top_k)
 
+        # Rollout router replay
         if g_routing_capture_manager is not None:
             g_routing_capture_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)
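The capture call above records each MoE layer's chosen expert ids so that a rollout's routing decisions can later be returned or replayed. A toy sketch of what such a buffer conceptually holds (only capture()'s argument names come from this diff; the class and layout are assumptions):

import torch

class ToyRoutingCapture:
    """Illustrative only: one int32 slot per (MoE layer, token, top-k choice)."""

    def __init__(self, num_moe_layers: int, max_tokens: int, topk: int):
        self.buf = torch.zeros(num_moe_layers, max_tokens, topk, dtype=torch.int32)

    def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0):
        # store this layer's routing for the current batch; the real manager
        # also tracks microbatch_index, which this toy version ignores
        n = topk_ids.shape[0]
        self.buf[moe_layer_index, :n] = topk_ids.to(torch.int32)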

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 from .mm_weight import (
     MMWeightTpl,
 )
-from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight
+from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight, QKVROWNMMWeight
 from .colmm_weight import COLMMWeight

lightllm/common/basemodel/routing_manager.py

Lines changed: 9 additions & 69 deletions
@@ -9,61 +9,10 @@
 logger = init_logger(__name__)
 
 
-class SharedRoutingConfig:
-    """Shared MoE routing configuration across processes."""
-
-    def __init__(self):
-        service_name = get_unique_server_name()
-        self._shm = SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32)
-
-    @property
-    def num_moe_layers(self) -> int:
-        return int(self._shm.arr[0])
-
-    @num_moe_layers.setter
-    def num_moe_layers(self, value: int):
-        self._shm.arr[0] = value
-
-    @property
-    def topk(self) -> int:
-        return int(self._shm.arr[1])
-
-    @topk.setter
-    def topk(self, value: int):
-        self._shm.arr[1] = value
-
-    def is_initialized(self) -> bool:
-        return self.num_moe_layers > 0 and self.topk > 0
-
-
-_shared_routing_config: Optional[SharedRoutingConfig] = None
-
-
-def get_shared_routing_config() -> SharedRoutingConfig:
-    """Get or create the shared routing config."""
-    global _shared_routing_config
-    if _shared_routing_config is None:
-        _shared_routing_config = SharedRoutingConfig()
-    return _shared_routing_config
-
-
-_moe_layer_counter: int = 0
-
-
-def reset_moe_layer_counter() -> None:
-    global _moe_layer_counter
-    _moe_layer_counter = 0
-
-
-def get_next_moe_layer_index() -> int:
-    global _moe_layer_counter
-    idx = _moe_layer_counter
-    _moe_layer_counter += 1
-    return idx
-
-
-def get_moe_layer_count() -> int:
-    return _moe_layer_counter
+def get_routing_config_shm() -> SharedArray:
+    """Get shared memory for MoE routing config: arr[0]=num_moe_layers, arr[1]=topk."""
+    service_name = get_unique_server_name()
+    return SharedArray(f"{service_name}_routing_config", shape=(2,), dtype=np.int32)
 
 
 class RoutingCaptureManager:
@@ -170,15 +119,11 @@ def create_routing_capture_manager(
     )
 
 
-def init_routing_capture(model) -> None:
-    if not getattr(model.args, "enable_return_routed_experts", False):
-        return
-
+def init_routing_capture(model, num_moe_layers: int) -> None:
     if get_current_rank_in_dp() != 0:
-        logger.info("Skipping routing capture initialization on non-zero rank")
+        # Skipping routing capture initialization on non-zero rank
         return
 
-    num_moe_layers = get_moe_layer_count()
     if num_moe_layers == 0:
         logger.warning(
             "enable_return_routed_experts is set but no MoE layers found. " "Routing capture will not be enabled."
@@ -204,12 +149,7 @@ def init_routing_capture(model) -> None:
         enable_overlap=enable_overlap,
     )
 
-    shared_config = get_shared_routing_config()
-    shared_config.num_moe_layers = num_moe_layers
-    shared_config.topk = topk
+    shm = get_routing_config_shm()
+    shm.arr[0] = num_moe_layers
+    shm.arr[1] = topk
     logger.info(f"Shared routing config set: num_moe_layers={num_moe_layers}, topk={topk}")
-
-
-def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None:
-    if g_routing_capture_manager is not None:
-        g_routing_capture_manager.flush_to_cpu_async(mem_indexes, microbatch_index)
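With the SharedRoutingConfig wrapper gone, a consumer process reads the two int32 slots directly. A minimal sketch, assuming only the get_routing_config_shm() helper shown in this diff:

from typing import Optional, Tuple

def read_routing_config() -> Optional[Tuple[int, int]]:
    """Return (num_moe_layers, topk), or None if the producer has not written them yet."""
    shm = get_routing_config_shm()
    num_moe_layers, topk = int(shm.arr[0]), int(shm.arr[1])
    if num_moe_layers <= 0 or topk <= 0:
        # plays the role of the removed SharedRoutingConfig.is_initialized()
        return None
    return num_moe_layers, topk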
Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+{
+    "1024": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "128": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 16,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_K": 32,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 16,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "256": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 1,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "512": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 1,
+        "NEED_TRANS": false,
+        "num_stages": 4,
+        "num_warps": 4
+    },
+    "64": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 1,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "8": {
+        "BLOCK_SIZE_K": 32,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "GROUP_SIZE_M": 1,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "800": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "8192": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    }
+}
Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+{
+    "1": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "GROUP_SIZE_M": 1,
+        "NEED_TRANS": false,
+        "num_stages": 4,
+        "num_warps": 4
+    },
+    "100": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 1,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_K": 32,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 64,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "128": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 8
+    },
+    "16": {
+        "BLOCK_SIZE_K": 64,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 1,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "256": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 16,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "32": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "GROUP_SIZE_M": 16,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 4
+    },
+    "64": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "8": {
+        "BLOCK_SIZE_K": 128,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "GROUP_SIZE_M": 32,
+        "NEED_TRANS": false,
+        "num_stages": 3,
+        "num_warps": 8
+    }
+}
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+{
+    "1": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 4
+    },
+    "100": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 8
+    },
+    "1024": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 4
+    },
+    "128": {
+        "BLOCK_SIZE": 256,
+        "num_warps": 8
+    },
+    "16": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 8
+    },
+    "256": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 8
+    },
+    "32": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 8
+    },
+    "64": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 8
+    },
+    "8": {
+        "BLOCK_SIZE": 128,
+        "num_warps": 8
+    }
+}
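The three JSON files above match the repository's pattern of pre-tuned Triton kernel configurations keyed by an integer shape bucket (apparently a token or expert-token count). The loader is not part of this commit; a hedged sketch of a typical nearest-bucket lookup:

import json
from functools import lru_cache

@lru_cache(maxsize=None)
def _load_config_table(path: str) -> dict:
    # parse once per path; keys in the JSON are strings, so convert to int
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_kernel_config(path: str, num_tokens: int) -> dict:
    table = _load_config_table(path)
    # nearest tuned bucket; the real selection policy is an assumption here
    key = min(table, key=lambda k: abs(k - num_tokens))
    return table[key]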
