Skip to content

Commit 8ed5074

Browse files
authored
use rpyc in model_rpc.py (#1221)
1 parent f5ee4c3 commit 8ed5074

File tree

7 files changed

+80
-204
lines changed

7 files changed

+80
-204
lines changed

lightllm/models/bloom/layer_weights/transformer_layer_weight.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def _parse_config(self):
5656
self.n_embed = self.network_config_["n_embed"]
5757
self.n_head = self.network_config_["num_attention_heads"]
5858
self.n_inter = self.network_config_["n_embed"] * 4
59+
self.q_head_num_ = self.n_head
60+
self.k_head_num_ = self.n_head
61+
self.v_head_num_ = self.n_head
62+
self.o_head_num_ = self.n_head
5963
self.n_kv_head = self.network_config_["num_attention_heads"]
6064
self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head)
6165
# 计算生成alibi

lightllm/server/api_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
259259
)
260260
parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache")
261261

262-
parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size")
262+
parser.add_argument("--chunked_prefill_size", type=int, default=None, help="chunked prefill size")
263263
parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
264264
parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
265265
parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .sampling_params import SamplingParams
22
from .req import Req, FinishStatus
33
from .shm_req_manager import ShmReqManager
4-
from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray
54
from .start_args_type import StartArgs
65
from .atomic_lock import AtomicShmLock

lightllm/server/core/objs/rpc_shm.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

lightllm/server/router/manager.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,37 +115,28 @@ async def wait_to_model_ready(self):
115115
self.model_rpc_servers = []
116116
# 用于 kv move 管理进程 和 推理进程进行task信息的交互。
117117
self.info_queue: mp.Queue = mp.Queue()
118-
self.rpc_event = multiprocessing.Event()
119-
self.rpc_finished_event = multiprocessing.Event()
120118

121119
assert (self.world_size % self.nnodes) == 0
122120
node_world_size = self.world_size // self.nnodes
123121

124122
# Create tasks for parallel startup
125123
tasks = []
126124
for rank_id in range(self.node_rank * node_world_size, (self.node_rank + 1) * node_world_size):
125+
rank_in_node = rank_id % node_world_size
127126
task = asyncio.create_task(
128127
start_model_process(
129128
args=self.args,
130129
rank=rank_id,
131-
rank_in_node=rank_id % node_world_size,
130+
rank_in_node=rank_in_node,
132131
node_world_size=node_world_size,
133-
rpc_event=self.rpc_event,
134-
rpc_finished_event=self.rpc_finished_event,
135132
info_queue=self.info_queue,
136133
router_lock=self.router_lock,
137134
)
138135
)
139136
tasks.append(task)
140137

141138
# Wait for all tasks to complete in parallel
142-
self.model_rpc_servers = await asyncio.gather(*tasks)
143-
144-
self.model_rpc_client = ModelRpcClient(
145-
rpc_event=self.rpc_event,
146-
rpc_finished_event=self.rpc_finished_event,
147-
)
148-
139+
self.model_rpc_clients = await asyncio.gather(*tasks)
149140
kvargs = {
150141
"args": self.args,
151142
"rank_id": None, # 由后续处理填充真实数据
@@ -178,10 +169,19 @@ async def wait_to_model_ready(self):
178169
"pd_rpyc_ports": self.args.pd_node_infer_rpyc_ports, # 非 pd 模式可以不设置
179170
}
180171

181-
await self.model_rpc_client.init_model(kvargs=kvargs)
172+
# Call init_model on all model processes
173+
init_tasks = []
174+
for model_rpc_client in self.model_rpc_clients:
175+
init_tasks.append(model_rpc_client.init_model(kvargs=kvargs))
176+
await asyncio.gather(*init_tasks)
182177

183178
if self.max_total_token_num is None:
184-
self.max_total_token_num = await self.model_rpc_client.get_max_total_token_num()
179+
_tasks = []
180+
for model_rpc_client in self.model_rpc_clients:
181+
_tasks.append(model_rpc_client.get_max_total_token_num())
182+
_nums = await asyncio.gather(*_tasks)
183+
assert max(_nums) == min(_nums), "all rank must have same token num"
184+
self.max_total_token_num = _nums[0]
185185
self.args.max_total_token_num = self.max_total_token_num
186186
if not self.args.disable_dynamic_prompt_cache:
187187
self.radix_cache_client = RadixCacheReadOnlyClient(

0 commit comments

Comments (0)