Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions fastdeploy/model_executor/load_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from tqdm import tqdm

from fastdeploy import envs
from fastdeploy.config import FDConfig, LoadConfig
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.utils import multi_switch_config_context

Expand Down Expand Up @@ -72,6 +72,11 @@ def layers_are_grouped(keys):
return True


def values_are_naturally_ordered(values):
"""Check if values are sorted in natural order."""
return list(values) == sorted(values, key=natural_key)

This comment was marked as outdated.

This comment was marked as outdated.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 values_are_naturally_ordered 缺少单元测试覆盖

该函数是此 PR 核心修复逻辑的入口,但测试文件中未新增对应测试用例。现有 test_get_weight_iterator_ordered_and_kv_scale 的 weight_map 全部指向同一个 shard,无法覆盖「多 shard 文件中 values 无序」的触发场景。

建议补充:

  1. 单测 values_are_naturally_ordered(有序/无序各一个 case)
  2. test_get_weight_iterator_ordered_and_kv_scale 风格下添加多 shard、values 无序的集成场景(如 layer 0 → shard-2, layer 1 → shard-1)

此逻辑是 OOM 修复的关键路径,回归测试缺失时风险较高。



def pdparams_weight_iterator(paddle_file_list: list[str]):
for pdparams_file in tqdm(
paddle_file_list,
Expand Down Expand Up @@ -117,18 +122,20 @@ def get_model_path(fd_config: FDConfig):
return model_path


def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None):
def get_weight_iterator(model_path: str, fd_config: Optional[FDConfig] = None):
files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)
if use_safetensors:
load_config = fd_config.load_config if fd_config else None
extra_config = load_config.model_loader_extra_config if load_config else None
parallel_config = fd_config.parallel_config if fd_config else None
if extra_config is not None and extra_config.get("enable_multithread_load", False):
weights_iterator = multi_thread_safetensors_weights_iterator(
files_list,
max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
disable_mmap=extra_config.get("disable_mmap", False),
)
else:
if is_layers_are_grouped:
if is_layers_are_grouped or (parallel_config is not None and parallel_config.tensor_parallel_size == 1):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 新增的 TP=1 短路条件缺少注释说明意图

tensor_parallel_size == 1 时,强制走 safetensors_weights_iterator(files_list)(按文件名顺序逐文件加载),绕过了 is_layers_are_grouped 的检查。这是合理的优化(单卡场景下权重加载顺序对正确性无影响,且可避免 ordered iterator 跨 shard 跳读引发的 OOM),但直接阅读时逻辑意图不明显。

建议加上注释说明:

if is_layers_are_grouped or (
    # For TP=1, sequential file-by-file loading is always safe and avoids
    # the OOM risk of safetensors_weights_iterator_ordered jumping between
    # shard files in non-sequential order.
    parallel_config is not None and parallel_config.tensor_parallel_size == 1
):

weights_iterator = safetensors_weights_iterator(files_list)
else:
weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
Expand Down Expand Up @@ -532,7 +539,10 @@ def get_all_weights_file(model_path: str):
with index_file.open("r") as f:
weight_map = json.load(f)["weight_map"]
keys = list(weight_map.keys())
is_layers_are_grouped = layers_are_grouped(keys)
values = list(weight_map.values())
is_keys_orders = layers_are_grouped(keys)

This comment was marked as outdated.

This comment was marked as outdated.

is_values_naturally_ordered = values_are_naturally_ordered(values)
is_layers_are_grouped = is_keys_orders and is_values_naturally_ordered
ordered_weight_map = {
key: str(model_path / weight_map[key]) for key in sorted(weight_map.keys(), key=natural_key)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def clean_memory_fragments(self) -> None:
@measure_time()
def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
model_path = get_model_path(fd_config)
weights_iterator = get_weight_iterator(model_path, fd_config.load_config)
weights_iterator = get_weight_iterator(model_path, fd_config)
if enable_cache:
load_weights_from_cache(model, weights_iterator)
else:
Expand Down
Loading