Skip to content

Commit 22da62c

Browse files
committed
Add values natural order check to layers grouped validation
1 parent 7bc29b5 commit 22da62c

2 files changed

Lines changed: 15 additions & 5 deletions

File tree

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from tqdm import tqdm
4141

4242
from fastdeploy import envs
43-
from fastdeploy.config import FDConfig, LoadConfig
43+
from fastdeploy.config import FDConfig
4444
from fastdeploy.model_executor.layers.linear import KVBatchLinear
4545
from fastdeploy.model_executor.utils import multi_switch_config_context
4646

@@ -72,6 +72,11 @@ def layers_are_grouped(keys):
7272
return True
7373

7474

75+
def values_are_naturally_ordered(values):
    """Check if values are already sorted in natural order.

    Args:
        values: Iterable of items accepted by ``natural_key`` (the caller in
            this file passes the shard file names from the weight map).

    Returns:
        True when the sequence is identical to itself sorted with
        ``key=natural_key``; False otherwise. An empty input yields True.
    """
    # Materialize once: a one-shot iterable (generator, map, ...) would be
    # exhausted by the first pass and then compared against an empty list.
    items = list(values)
    return items == sorted(items, key=natural_key)
78+
79+
7580
def pdparams_weight_iterator(paddle_file_list: list[str]):
7681
for pdparams_file in tqdm(
7782
paddle_file_list,
@@ -117,18 +122,20 @@ def get_model_path(fd_config: FDConfig):
117122
return model_path
118123

119124

120-
def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None):
125+
def get_weight_iterator(model_path: str, fd_config: Optional[FDConfig] = None):
121126
files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)
122127
if use_safetensors:
128+
load_config = fd_config.load_config if fd_config else None
123129
extra_config = load_config.model_loader_extra_config if load_config else None
130+
parallel_config = fd_config.parallel_config if fd_config else None
124131
if extra_config is not None and extra_config.get("enable_multithread_load", False):
125132
weights_iterator = multi_thread_safetensors_weights_iterator(
126133
files_list,
127134
max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
128135
disable_mmap=extra_config.get("disable_mmap", False),
129136
)
130137
else:
131-
if is_layers_are_grouped:
138+
if is_layers_are_grouped or parallel_config.tensor_parallel_size == 1:
132139
weights_iterator = safetensors_weights_iterator(files_list)
133140
else:
134141
weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)
@@ -532,7 +539,10 @@ def get_all_weights_file(model_path: str):
532539
with index_file.open("r") as f:
533540
weight_map = json.load(f)["weight_map"]
534541
keys = list(weight_map.keys())
535-
is_layers_are_grouped = layers_are_grouped(keys)
542+
values = list(weight_map.values())
543+
is_keys_orders = layers_are_grouped(keys)
544+
is_values_naturally_ordered = values_are_naturally_ordered(values)
545+
is_layers_are_grouped = is_keys_orders and is_values_naturally_ordered
536546
ordered_weight_map = {
537547
key: str(model_path / weight_map[key]) for key in sorted(weight_map.keys(), key=natural_key)
538548
}

fastdeploy/model_executor/model_loader/default_loader_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def clean_memory_fragments(self) -> None:
5757
@measure_time()
5858
def load_weights(self, model, fd_config: FDConfig, enable_cache: bool = False) -> None:
5959
model_path = get_model_path(fd_config)
60-
weights_iterator = get_weight_iterator(model_path, fd_config.load_config)
60+
weights_iterator = get_weight_iterator(model_path, fd_config)
6161
if enable_cache:
6262
load_weights_from_cache(model, weights_iterator)
6363
else:

0 commit comments

Comments
 (0)