|
14 | 14 | # limitations under the License. |
15 | 15 | """ |
16 | 16 |
|
| 17 | +import concurrent.futures |
17 | 18 | import contextlib |
18 | 19 | import copy |
19 | 20 | import hashlib |
|
26 | 27 | from contextlib import ExitStack |
27 | 28 | from functools import wraps |
28 | 29 | from pathlib import Path |
| 30 | +from typing import Optional |
29 | 31 |
|
30 | 32 | import paddle |
| 33 | +import paddle.distributed as dist |
| 34 | +import safetensors |
31 | 35 | from paddleformers.transformers import PretrainedModel |
32 | 36 | from paddleformers.transformers.model_utils import load_tp_checkpoint |
33 | 37 | from paddleformers.utils.log import logger |
|
36 | 40 | from tqdm import tqdm |
37 | 41 |
|
38 | 42 | from fastdeploy import envs |
39 | | -from fastdeploy.config import FDConfig |
| 43 | +from fastdeploy.config import FDConfig, LoadConfig |
40 | 44 | from fastdeploy.model_executor.layers.linear import KVBatchLinear |
41 | 45 | from fastdeploy.model_executor.utils import multi_switch_config_context |
42 | 46 |
|
| 47 | +DEFAULT_NUM_THREADS = 8 |
| 48 | + |
43 | 49 |
|
def natural_key(s: str):
    """Build a sort key that orders embedded digit runs numerically.

    Splits *s* on runs of digits (keeping them via the capture group) and
    converts each digit run to ``int``, so e.g. "layer10" sorts after
    "layer2" instead of lexicographically before it.
    """

    def _coerce(token: str):
        # Digit runs compare as numbers; everything else compares as text.
        return int(token) if token.isdigit() else token

    return [_coerce(part) for part in re.split(r"(\d+)", s)]
@@ -111,13 +117,21 @@ def get_model_path(fd_config: FDConfig): |
111 | 117 | return model_path |
112 | 118 |
|
113 | 119 |
|
114 | | -def get_weight_iterator(model_path: str): |
| 120 | +def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None): |
115 | 121 | files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path) |
116 | 122 | if use_safetensors: |
117 | | - if is_layers_are_grouped: |
118 | | - weights_iterator = safetensors_weights_iterator(files_list) |
| 123 | + extra_config = load_config.model_loader_extra_config if load_config else None |
| 124 | + if extra_config is not None and extra_config.get("enable_multithread_load", False): |
| 125 | + weights_iterator = multi_thread_safetensors_weights_iterator( |
| 126 | + files_list, |
| 127 | + max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS), |
| 128 | + disable_mmap=extra_config.get("disable_mmap", False), |
| 129 | + ) |
119 | 130 | else: |
120 | | - weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map) |
| 131 | + if is_layers_are_grouped: |
| 132 | + weights_iterator = safetensors_weights_iterator(files_list) |
| 133 | + else: |
| 134 | + weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map) |
121 | 135 | else: |
122 | 136 | weights_iterator = pdparams_weight_iterator(files_list) |
123 | 137 |
|
@@ -401,6 +415,52 @@ def safetensors_weights_iterator(safe_tensor_list: list[str]): |
401 | 415 | yield name, param |
402 | 416 |
|
403 | 417 |
|
def multi_thread_safetensors_weights_iterator(safe_tensor_list, max_workers: int = 4, disable_mmap: bool = False):
    """
    Iterate over safetensors weights using multi-threaded loading.

    Args:
        safe_tensor_list: List of safetensors file paths to load.
        max_workers: Maximum number of threads for concurrent loading. Defaults to 4.
        disable_mmap: If True, load files into memory directly instead of using memory-mapped
            files. Useful when mmap is not supported or causes issues.

    Yields:
        Tuple[str, paddle.Tensor]: Weight name and corresponding tensor.
    """
    try:
        # Show the progress bar on rank 0 only, to avoid duplicated output
        # in distributed runs; fall back to enabled if dist is uninitialized.
        enable_tqdm = dist.get_rank() == 0
    except Exception:
        enable_tqdm = True

    def _load_file(st_file: str):
        # Load one shard; read raw bytes to bypass mmap when requested.
        if disable_mmap:
            with open(st_file, "rb") as f:
                return safetensors.paddle.load(f.read())
        return safetensors.paddle.load_file(st_file, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_load_file, st_file) for st_file in safe_tensor_list]

        # tqdm's `disable` flag already covers the no-progress-bar case, so a
        # single code path replaces the previous if/else duplication (where the
        # `if enable_tqdm` branch passed an always-False `disable=not enable_tqdm`).
        futures_iter = tqdm(
            concurrent.futures.as_completed(futures),
            total=len(futures),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm,
        )

        for future in futures_iter:
            state_dict = future.result()
            # Yield weights as each shard finishes; completion order is
            # nondeterministic across shards by design.
            yield from state_dict.items()
| 463 | + |
404 | 464 | def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]): |
405 | 465 | """ |
406 | 466 | safetensors_weights_iterator_ordered |
|
0 commit comments