
Commit 14d4618

[Loader] add multi-thread model loading (#6877)
* multi-thread-loader
* fix ut
1 parent c1fb311 commit 14d4618

12 files changed

Lines changed: 105 additions & 7 deletions


docs/parameters.md

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ When using FastDeploy to deploy models (including offline inference and service
 | ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
 | ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |
 | ```load_choices``` | `str` | Weight loader selection, default: "default_v1". Supports "default", "default_v1", and "dummy". "default_v1" is used for loading torch weights and weight acceleration. "dummy" is used to quickly initialize weights with random values for testing. |
+| ```model_loader_extra_config``` | `dict[str]` | Additional configuration options for the model loader. Supports: <br> - `enable_multithread_load` (bool): Enable multi-threaded weight loading. <br> - `num_threads` (int): Number of threads for loading. Defaults to 8. <br> - `disable_mmap` (bool): Disable memory-mapped file access. Useful when mmap is not supported. <br> Example: `'{"enable_multithread_load": true, "num_threads": 8}'` |
 | ```max_encoder_cache``` | `int` | Maximum number of tokens in the encoder cache (use 0 to disable), default: -1 (auto-calculated) |
 | ```max_processor_cache``` | `float` | Maximum number of bytes (in GiB) in the processor cache (use 0 to disable), default: -1 (auto-calculated) |
 | ```api_key``` | `list[str]` | Validate API keys in the service request headers, supporting multiple key inputs. Same effect as environment variable `FD_API_KEY`, with higher priority. |
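For the new `model_loader_extra_config` row above, a minimal Python sketch of the value it accepts: the three supported keys with the defaults described in the table, and the JSON string form expected on the command line. The dict literal itself is illustrative.

```python
import json

# Illustrative sketch of model_loader_extra_config: the three supported keys
# with their defaults (enable_multithread_load is off unless set explicitly).
model_loader_extra_config = {
    "enable_multithread_load": True,  # default: False (single-threaded loading)
    "num_threads": 8,                 # default: 8
    "disable_mmap": False,            # default: False (keep memory-mapped reads)
}

# Serialized form, matching the example in the table:
print(json.dumps(model_loader_extra_config))
# {"enable_multithread_load": true, "num_threads": 8, "disable_mmap": false}
```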

docs/zh/parameters.md

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@
 | ```tool_call_parser``` | `str` | Specify the function call parser to use for extracting function call content from the model output. |
 | ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to register, so that parsers not included in the code repository can be registered; the code in these parsers must follow the repository's format. |
 | ```load_choices``` | `str` | Weight loader selection, default "default_v1". Supports "default" and "default_v1"; the latter is used for loading torch weights and weight acceleration. |
+| ```model_loader_extra_config``` | `dict[str]` | Additional configuration options for the model loader. Supports: <br> - `enable_multithread_load` (bool): Enable multi-threaded weight loading. <br> - `num_threads` (int): Number of loading threads, default 8. <br> - `disable_mmap` (bool): Disable memory-mapped file access; use when mmap is not supported. <br> Example: `'{"enable_multithread_load": true, "num_threads": 8}'` |
 | ```max_encoder_cache``` | `int` | Maximum number of tokens in the encoder cache (use 0 to disable), default -1 (auto-calculated) |
 | ```max_processor_cache``` | `float` | Maximum number of bytes (in GiB) in the processor cache (use 0 to disable), default -1 (auto-calculated) |
 | ```api_key``` | `list[str]` | Validate API keys in the service request headers, supporting multiple keys; same effect as the environment variable `FD_API_KEY`, with higher priority than the environment variable. |

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
@@ -1447,6 +1447,7 @@ def __init__(
         self.dynamic_load_weight: bool = False
         self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal", "rsync"]] = "normal"
         self.rsync_config: Optional[Dict[str, Any]] = None
+        self.model_loader_extra_config: Optional[Dict[str, Any]] = None
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
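The loop in this hunk only copies keys that already exist as attributes on the config object, so adding the `model_loader_extra_config` attribute is what makes the new key take effect. A small standalone sketch of that pattern; the class below is an illustrative stand-in, not the real LoadConfig.

```python
from typing import Any, Dict, Optional


class LoadConfigSketch:
    """Illustrative stand-in for the config object updated in the loop above."""

    def __init__(self, args: Dict[str, Any]):
        self.model_loader_extra_config: Optional[Dict[str, Any]] = None
        # Only keys matching an existing attribute are copied; unknown keys in
        # args are silently ignored, as in the hunk above.
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)


cfg = LoadConfigSketch({"model_loader_extra_config": {"enable_multithread_load": True}, "unknown": 1})
print(cfg.model_loader_extra_config)  # {'enable_multithread_load': True}
```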

fastdeploy/engine/args_utils.py

Lines changed: 16 additions & 0 deletions
@@ -496,6 +496,14 @@ class EngineArgs:
     - "default": default loader.
     - "default_v1": default_v1 loader.
     """
+    model_loader_extra_config: Optional[Dict[str, Any]] = None
+    """
+    Additional configuration options for the model loader.
+    Supports:
+    - enable_multithread_load (bool): Enable multi-threaded weight loading.
+    - num_threads (int): Number of threads for loading. Defaults to 8.
+    - disable_mmap (bool): Disable memory-mapped file access.
+    """
 
     lm_head_fp32: bool = False
     """
@@ -1091,6 +1099,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default/default_v1/dummy.",
         )
 
+        load_group.add_argument(
+            "--model-loader-extra-config",
+            type=json.loads,
+            default=EngineArgs.model_loader_extra_config,
+            help="Additional configuration for model loader (JSON format). "
+            'e.g., \'{"enable_multithread_load": true, "num_threads": 8}\'',
+        )
+
         # CacheConfig parameters group
         cache_group = parser.add_argument_group("Cache Configuration")
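A minimal, self-contained sketch of how the new flag behaves once registered, using the standard-library argparse.ArgumentParser as a stand-in for FastDeploy's FlexibleArgumentParser; the flag name and `type=json.loads` come from the hunk above, everything else is illustrative.

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-loader-extra-config",
    type=json.loads,  # the raw CLI string is parsed straight into a dict
    default=None,
    help="Additional configuration for model loader (JSON format).",
)

args = parser.parse_args(
    ["--model-loader-extra-config", '{"enable_multithread_load": true, "num_threads": 8}']
)
print(args.model_loader_extra_config)
# {'enable_multithread_load': True, 'num_threads': 8}
```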

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 0 deletions
@@ -2483,6 +2483,7 @@ def _start_worker_service(self):
             f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
             f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
             f" --load_choices {self.cfg.load_config.load_choices}"
+            f" --model_loader_extra_config '{json.dumps(self.cfg.load_config.model_loader_extra_config)}'"
             f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
             f" --ips {ips}"
             f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
@@ -613,6 +613,7 @@ def _start_worker_service(self):
             f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
             f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
             f" --load_choices {self.cfg.load_config.load_choices}"
+            f" --model_loader_extra_config '{json.dumps(self.cfg.load_config.model_loader_extra_config)}'"
             f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
             f" --ips {ips}"
             f" --max_encoder_cache {self.cfg.cache_config.max_encoder_cache}"

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 65 additions & 5 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import concurrent.futures
 import contextlib
 import copy
 import hashlib
@@ -26,8 +27,11 @@
 from contextlib import ExitStack
 from functools import wraps
 from pathlib import Path
+from typing import Optional
 
 import paddle
+import paddle.distributed as dist
+import safetensors
 from paddleformers.transformers import PretrainedModel
 from paddleformers.transformers.model_utils import load_tp_checkpoint
 from paddleformers.utils.log import logger
@@ -36,10 +40,12 @@
 from tqdm import tqdm
 
 from fastdeploy import envs
-from fastdeploy.config import FDConfig
+from fastdeploy.config import FDConfig, LoadConfig
 from fastdeploy.model_executor.layers.linear import KVBatchLinear
 from fastdeploy.model_executor.utils import multi_switch_config_context
 
+DEFAULT_NUM_THREADS = 8
+
 
 def natural_key(s: str):
     return [int(t) if t.isdigit() else t for t in re.split(r"(\d+)", s)]
@@ -111,13 +117,21 @@ def get_model_path(fd_config: FDConfig):
     return model_path
 
 
-def get_weight_iterator(model_path: str):
+def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None):
     files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)
     if use_safetensors:
-        if is_layers_are_grouped:
-            weights_iterator = safetensors_weights_iterator(files_list)
+        extra_config = load_config.model_loader_extra_config if load_config else None
+        if extra_config is not None and extra_config.get("enable_multithread_load", False):
+            weights_iterator = multi_thread_safetensors_weights_iterator(
+                files_list,
+                max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
+                disable_mmap=extra_config.get("disable_mmap", False),
+            )
         else:
-            weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
+            if is_layers_are_grouped:
+                weights_iterator = safetensors_weights_iterator(files_list)
+            else:
+                weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
     else:
         weights_iterator = pdparams_weight_iterator(files_list)
 
@@ -401,6 +415,52 @@ def safetensors_weights_iterator(safe_tensor_list: list[str]):
         yield name, param
 
 
+def multi_thread_safetensors_weights_iterator(safe_tensor_list, max_workers: int = 4, disable_mmap: bool = False):
+    """
+    Iterate over safetensors weights using multi-threaded loading.
+
+    Args:
+        safe_tensor_list: List of safetensors file paths to load.
+        max_workers: Maximum number of threads for concurrent loading. Defaults to 4.
+        disable_mmap: If True, load files into memory directly instead of using memory-mapped
+            files. Useful when mmap is not supported or causes issues.
+
+    Yields:
+        Tuple[str, paddle.Tensor]: Weight name and corresponding tensor.
+    """
+    try:
+        enable_tqdm = dist.get_rank() == 0
+    except Exception:
+        enable_tqdm = True
+
+    def _load_file(st_file: str):
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.paddle.load(f.read())
+        else:
+            result = safetensors.paddle.load_file(st_file, device="cpu")
+
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(_load_file, st_file) for st_file in safe_tensor_list]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(safe_tensor_list),
+                desc="Multi-thread loading shards",
+                disable=not enable_tqdm,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state_dict = future.result()
+            for name, param in state_dict.items():
+                yield name, param
+
+
 def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]):
     """
     safetensors_weights_iterator_ordered
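A hedged usage sketch of the new iterator added above. The shard file names are hypothetical and running it requires fastdeploy and the shard files to be present; the import path, function name, and call signature come from this commit.

```python
from fastdeploy.model_executor.load_weight_utils import (
    multi_thread_safetensors_weights_iterator,
)

shard_files = [
    "model-00001-of-00002.safetensors",  # hypothetical shard paths
    "model-00002-of-00002.safetensors",
]

# Shards are loaded in a thread pool and yielded as they complete, so the order
# of (name, tensor) pairs is not guaranteed across files.
for name, tensor in multi_thread_safetensors_weights_iterator(
    shard_files, max_workers=8, disable_mmap=False
):
    print(name, tuple(tensor.shape))
```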

fastdeploy/model_executor/model_loader/default_loader_v1.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def clean_memory_fragments(self) -> None:
     @measure_time()
     def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
         model_path = get_model_path(fd_config)
-        weights_iterator = get_weight_iterator(model_path)
+        weights_iterator = get_weight_iterator(model_path, fd_config.load_config)
         if enable_cache:
             load_weights_from_cache(model, weights_iterator)
         else:

fastdeploy/worker/worker_process.py

Lines changed: 8 additions & 0 deletions
@@ -1028,6 +1028,14 @@ def parse_args():
         help="The format of the model weights to load. default/default_v1/dummy.",
     )
 
+    parser.add_argument(
+        "--model_loader_extra_config",
+        type=json.loads,
+        default=None,
+        help="Additional configuration for model loader (JSON format). "
+        'e.g., \'{"enable_multithread_load": true, "num_threads": 8}\'',
+    )
+
     parser.add_argument(
         "--ips",
         type=str,

tests/engine/test_engine.py

Lines changed: 7 additions & 1 deletion
@@ -45,7 +45,13 @@ def _make_cfg(**ov):
     cc.enable_prefix_caching = cc.enable_chunked_prefill = False
     cc.kv_cache_ratio, cc.kvcache_storage_backend, cc.num_cpu_blocks, cc.max_encoder_cache = 1.0, None, 0, 0
     cc.cache_transfer_protocol, cc.total_block_num = "tcp", 100
-    lc = ns(load_strategy="auto", rsync_config={}, dynamic_load_weight=False, load_choices="auto")
+    lc = ns(
+        load_strategy="auto",
+        rsync_config={},
+        dynamic_load_weight=False,
+        load_choices="auto",
+        model_loader_extra_config={},
+    )
     soc = ns(guided_decoding_backend=None, logits_processors=None, reasoning_parser="none")
     soc.disable_any_whitespace = False
     cfg = ns(model_config=mc, parallel_config=pc, scheduler_config=sc, cache_config=cc, load_config=lc)
