|
40 | 40 | from tqdm import tqdm |
41 | 41 |
|
42 | 42 | from fastdeploy import envs |
43 | | -from fastdeploy.config import FDConfig, LoadConfig |
| 43 | +from fastdeploy.config import FDConfig |
44 | 44 | from fastdeploy.model_executor.layers.linear import KVBatchLinear |
45 | 45 | from fastdeploy.model_executor.utils import multi_switch_config_context |
46 | 46 |
|
@@ -72,6 +72,11 @@ def layers_are_grouped(keys): |
72 | 72 | return True |
73 | 73 |
|
74 | 74 |
|
def values_are_naturally_ordered(values):
    """Check whether values already appear in natural sort order.

    Args:
        values: Iterable of items that ``natural_key`` can turn into a sort
            key. It may be a list or a one-shot iterator.

    Returns:
        bool: True when the sequence equals itself sorted with
        ``natural_key`` (an empty input counts as ordered), otherwise False.
    """
    # Materialize once: iterating ``values`` twice would exhaust a generator
    # or other iterator on the first pass and make the comparison always fail.
    items = list(values)
    return items == sorted(items, key=natural_key)
| 78 | + |
| 79 | + |
75 | 80 | def pdparams_weight_iterator(paddle_file_list: list[str]): |
76 | 81 | for pdparams_file in tqdm( |
77 | 82 | paddle_file_list, |
@@ -117,18 +122,20 @@ def get_model_path(fd_config: FDConfig): |
117 | 122 | return model_path |
118 | 123 |
|
119 | 124 |
|
120 | | -def get_weight_iterator(model_path: str, load_config: Optional[LoadConfig] = None): |
| 125 | +def get_weight_iterator(model_path: str, fd_config: Optional[FDConfig] = None): |
121 | 126 | files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path) |
122 | 127 | if use_safetensors: |
| 128 | + load_config = fd_config.load_config if fd_config else None |
123 | 129 | extra_config = load_config.model_loader_extra_config if load_config else None |
| 130 | + parallel_config = fd_config.parallel_config if fd_config else None |
124 | 131 | if extra_config is not None and extra_config.get("enable_multithread_load", False): |
125 | 132 | weights_iterator = multi_thread_safetensors_weights_iterator( |
126 | 133 | files_list, |
127 | 134 | max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS), |
128 | 135 | disable_mmap=extra_config.get("disable_mmap", False), |
129 | 136 | ) |
130 | 137 | else: |
131 | | - if is_layers_are_grouped: |
| 138 | + if is_layers_are_grouped or (parallel_config is not None and parallel_config.tensor_parallel_size == 1): |
132 | 139 | weights_iterator = safetensors_weights_iterator(files_list) |
133 | 140 | else: |
134 | 141 | weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map) |
@@ -532,7 +539,10 @@ def get_all_weights_file(model_path: str): |
532 | 539 | with index_file.open("r") as f: |
533 | 540 | weight_map = json.load(f)["weight_map"] |
534 | 541 | keys = list(weight_map.keys()) |
535 | | - is_layers_are_grouped = layers_are_grouped(keys) |
| 542 | + values = list(weight_map.values()) |
| 543 | + is_keys_orders = layers_are_grouped(keys) |
| 544 | + is_values_naturally_ordered = values_are_naturally_ordered(values) |
| 545 | + is_layers_are_grouped = is_keys_orders and is_values_naturally_ordered |
536 | 546 | ordered_weight_map = { |
537 | 547 | key: str(model_path / weight_map[key]) for key in sorted(weight_map.keys(), key=natural_key) |
538 | 548 | } |
|
0 commit comments