forked from PaddlePaddle/FastDeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathload_weight_utils.py
More file actions
634 lines (537 loc) · 25.5 KB
/
load_weight_utils.py
File metadata and controls
634 lines (537 loc) · 25.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import concurrent.futures
import contextlib
import copy
import hashlib
import inspect
import json
import os
import pickle
import re
import time
from contextlib import ExitStack
from functools import wraps
from pathlib import Path
from typing import Optional
import paddle
import paddle.distributed as dist
import safetensors
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.model_utils import load_tp_checkpoint
from paddleformers.utils.log import logger
from paddleformers.utils.safetensors import fast_safe_open
from safetensors import safe_open
from tqdm import tqdm
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.utils import multi_switch_config_context
# Worker count used by get_weight_iterator for multi-threaded safetensors loading
# when model_loader_extra_config enables "enable_multithread_load" but gives no "num_threads".
DEFAULT_NUM_THREADS: int = 8
def natural_key(s: str):
    """
    Build a sort key that orders embedded digit runs numerically ("a2" < "a10").

    The string is split around runs of digits (the runs are kept); digit runs become
    ints and everything else stays a str. Because the split alternates text/digits,
    keys of different strings always compare same-typed elements position by position.
    """
    key = []
    for token in re.split(r"(\d+)", s):
        key.append(int(token) if token.isdigit() else token)
    return key
def layers_are_grouped(keys):
    """
    Return True if all keys of each ``layers.<N>`` index appear as one contiguous run.

    Keys that do not contain ``layers.<N>`` are ignored and do not break a run.
    Returns False as soon as a layer index reappears after a different layer was seen.
    """
    layer_pattern = re.compile(r"layers\.(\d+)")
    visited = set()
    previous = None
    for key in keys:
        match = layer_pattern.search(key)
        if match is None:
            continue
        layer_id = int(match.group(1))
        if layer_id == previous:
            # Still inside the current run.
            continue
        if layer_id in visited:
            # A layer we already finished shows up again: not grouped.
            return False
        visited.add(layer_id)
        previous = layer_id
    return True
def values_are_naturally_ordered(values):
    """
    Check if values are sorted in natural order (see ``natural_key``).

    Accepts any iterable, including one-shot iterators/generators: the input is
    materialized exactly once, so it is not exhausted before the comparison.
    Values whose natural keys compare equal (e.g. "1" and "01") count as ordered,
    which matches comparing the input against a stable sort of itself.
    """
    keys = [natural_key(value) for value in values]
    # Non-decreasing adjacent keys <=> a stable sort would leave the sequence unchanged.
    return all(left <= right for left, right in zip(keys, keys[1:]))
def pdparams_weight_iterator(paddle_file_list: list[str]):
    """
    Yield ``(name, tensor)`` pairs from a list of ``.pdparams`` checkpoint shards.

    Each shard is loaded whole with ``paddle.load`` and its entries are yielded in
    dict order; the local reference to the shard dict is released before the next
    shard is loaded.
    """
    shard_paths = tqdm(paddle_file_list, desc="Loading pdparams checkpoint shards")
    for shard_path in shard_paths:
        shard = paddle.load(shard_path)
        for entry in shard.items():
            yield entry
        del shard
def load_weights_from_cache(model, weights_iterator):
    """
    Copy ``(name, tensor)`` pairs from ``weights_iterator`` into ``model``'s parameters.

    - Names that are not model parameters are logged and skipped.
    - A shape mismatch raises ``ValueError``.
    - If a loaded name contains "embeddings" and ``model.tie_word_embeddings`` is truthy,
      the transposed embedding (cast to the lm_head dtype) is also written into
      ``model.lm_head.linear.weight``.
    - Finally every ``KVBatchLinear`` sublayer runs ``process_weights_after_loading()``.

    Raises:
        ValueError: on a parameter/weight shape mismatch.
    """
    param_by_name = dict(model.named_parameters())
    for name, tensor in weights_iterator:
        if name not in param_by_name:
            logger.info(f"{name} is not in model parameters.")
            continue
        target = param_by_name[name]
        if target.shape != tensor.shape:
            raise ValueError(
                f"Shape mismatch between loaded weight {name}: {tensor.shape}, expected shape: {target.shape}"
            )
        target.copy_(tensor, False)
        if "embeddings" in name and getattr(model, "tie_word_embeddings", False):
            head_weight = model.lm_head.linear.weight
            head_weight.set_value(tensor.transpose([1, 0]).astype(head_weight.dtype))

    kv_batch_layers = (layer for _, layer in model.named_sublayers() if isinstance(layer, KVBatchLinear))
    for layer in kv_batch_layers:
        layer.process_weights_after_loading()
def get_model_path(fd_config: FDConfig):
    """
    Return the directory this rank should load weights from.

    If ``fd_config.model_config.model`` contains more than one ``rank*`` sub-directory,
    the checkpoint is treated as pre-sharded: the count must equal the tensor-parallel
    size, ``fd_config.load_config.is_pre_sharded`` is set to True, and
    ``<model>/rank<tensor_parallel_rank>`` is returned. Otherwise the model directory
    itself is returned.

    Raises:
        ValueError: if the number of rank directories differs from the TP size.
    """
    base_path = fd_config.model_config.model
    rank_dirs = [
        entry
        for entry in os.listdir(base_path)
        if entry.startswith("rank") and os.path.isdir(os.path.join(base_path, entry))
    ]
    if len(rank_dirs) <= 1:
        return base_path

    parallel_config = fd_config.parallel_config
    local_rank = parallel_config.tensor_parallel_rank
    if parallel_config.tensor_parallel_size != len(rank_dirs):
        raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
    fd_config.load_config.is_pre_sharded = True
    return os.path.join(base_path, f"rank{local_rank}")
def get_weight_iterator(model_path: str, fd_config: Optional[FDConfig] = None):
    """
    Yield ``(name, tensor)`` for every weight under ``model_path``.

    The reader is chosen from the checkpoint layout and config:
      - pdparams files -> ``pdparams_weight_iterator``;
      - safetensors with ``enable_multithread_load`` in ``model_loader_extra_config``
        -> ``multi_thread_safetensors_weights_iterator``;
      - safetensors whose layers are already grouped, or with TP size 1
        -> ``safetensors_weights_iterator`` (file by file);
      - otherwise -> ``safetensors_weights_iterator_ordered`` (natural key order).
    Afterwards, entries from ``<model_path>/kv_cache_scale.json`` are yielded if that file exists.
    """
    files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped = get_all_weights_file(model_path)

    if not use_safetensors:
        weights_iterator = pdparams_weight_iterator(files_list)
    else:
        load_config = None
        parallel_config = None
        extra_config = None
        if fd_config:
            load_config = fd_config.load_config
            parallel_config = fd_config.parallel_config
        if load_config:
            extra_config = load_config.model_loader_extra_config

        if extra_config is not None and extra_config.get("enable_multithread_load", False):
            weights_iterator = multi_thread_safetensors_weights_iterator(
                files_list,
                max_workers=extra_config.get("num_threads", DEFAULT_NUM_THREADS),
                disable_mmap=extra_config.get("disable_mmap", False),
            )
        elif is_layers_are_grouped or (parallel_config is not None and parallel_config.tensor_parallel_size == 1):
            weights_iterator = safetensors_weights_iterator(files_list)
        else:
            weights_iterator = safetensors_weights_iterator_ordered(ordered_weight_map)

    yield from weights_iterator

    scale_file = Path(model_path) / "kv_cache_scale.json"
    if scale_file.exists():
        yield from kv_cache_scale_iterator(str(scale_file))
def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
    """
    Decide whether weights can be loaded from a previously saved on-disk cache.

    Caching is considered only when ``envs.FD_ENABLE_MODEL_LOAD_CACHE`` is set and
    ``fd_config.quant_config`` is not None. The cache directory is
    ``<model_dir>/<weight_cache_path>/<md5>``, with the md5 taken over the pickled
    string ``"<model_type>_<quant name>_<tp_size>_<ep_size>"``.

    Returns:
        tuple ``(enable_cache, weight_cache_dir, weight_cache_context)``:
          - enable_cache: True only if that cache directory already exists.
          - weight_cache_dir: the cache directory path, or None when caching is inactive.
          - weight_cache_context: if the cache exists, a ``multi_switch_config_context``
            that switches ``quant_config.is_checkpoint_bf16`` to False; otherwise
            ``contextlib.nullcontext()``.
    """
    if not (envs.FD_ENABLE_MODEL_LOAD_CACHE and fd_config.quant_config is not None):
        return False, None, contextlib.nullcontext()

    cache_root = os.path.join(fd_config.model_config.model, weight_cache_path)
    # model_type + quantization + tp_size + ep_size (only tp is supported for now)
    key_parts = [
        fd_config.model_config.model_type,
        fd_config.quant_config.name(),
        str(fd_config.parallel_config.tensor_parallel_size),
        str(fd_config.parallel_config.expert_parallel_size),
    ]
    # NOTE(review): the hash covers pickle output, which may differ between Python
    # versions' default pickle protocols; a version change would just miss the cache.
    hash_key = hashlib.md5(pickle.dumps("_".join(key_parts))).hexdigest()
    weight_cache_dir = os.path.join(cache_root, hash_key)

    if not os.path.exists(weight_cache_dir):
        return False, weight_cache_dir, contextlib.nullcontext()

    logger.info(
        f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
    )
    cache_context = multi_switch_config_context(
        (fd_config.quant_config, "is_checkpoint_bf16", False),
    )
    return True, weight_cache_dir, cache_context
def save_model(model_arg_name="model", config_arg_name="fd_config"):
    """
    Build a decorator that routes model loading through the on-disk weight cache.

    The decorated function must take the model and the FDConfig as arguments; they are
    looked up by name via ``model_arg_name`` / ``config_arg_name``. Around each call:

    * If ``is_weight_cache_enabled`` reports an existing cache, ``fd_config.model_config.model``
      is switched to this rank's cache directory while the function runs.
    * Afterwards, when ``FD_ENABLE_MODEL_LOAD_CACHE`` is set and the quant config has
      ``is_checkpoint_bf16`` (dynamic quantization), ``model.state_dict()`` is saved to
      ``<cache_dir>/rank<tp_rank>/cache.pdparams`` unless that directory already exists.
      The skip reason ("cache disabled" / "weights already cached") is logged.

    Args:
        model_arg_name: name of the decorated function's model parameter.
        config_arg_name: name of the decorated function's FDConfig parameter.

    Returns:
        The decorator; the wrapped function returns the original result unchanged.

    Raises:
        AssertionError: if the bound config or model argument is None.
    """

    @measure_time("Model saving")
    def _save_model(model_dict, weight_cache_dir):
        # NOTE: ProcessGroupNCCL does not support the deepcopy protocol, so Group is patched
        # to deep-copy as itself and serialize via repr() before paddle.save runs.
        paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
        paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
        paddle.save(model_dict, weight_cache_dir)

    def _rank_cache_dir(weight_cache_dir, fd_config):
        """Return this tensor-parallel rank's sub-directory of the weight cache."""
        return os.path.join(weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}")

    def decorator(func):
        # The wrapped function's signature never changes, so inspect it once.
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()
            fd_config = bound_args.arguments.get(config_arg_name, None)
            model = bound_args.arguments.get(model_arg_name, None)
            # Validate before is_weight_cache_enabled dereferences fd_config.
            assert fd_config is not None, "fd_config cannot be None"
            assert model is not None, "model cannot be None"
            enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)

            if enable_cache:
                context = multi_switch_config_context(
                    (fd_config.model_config, "model", _rank_cache_dir(weight_cache_dir, fd_config))
                )
            else:
                context = contextlib.nullcontext()

            with context:
                result = func(*args, **kwargs)
                if not envs.FD_ENABLE_MODEL_LOAD_CACHE:
                    reason = "cache disabled"
                    logger.info(f"Skip saving ,{reason}")
                    return result
                quant_config = fd_config.quant_config
                if quant_config is None or not getattr(quant_config, "is_checkpoint_bf16", False):
                    # Save cache only for dynamic quantization
                    return result
                if weight_cache_dir is None:
                    return result
                tp_weight_cache_dir = _rank_cache_dir(weight_cache_dir, fd_config)
                if os.path.exists(tp_weight_cache_dir):
                    reason = "weights already cached"
                    logger.info(f"Skip saving ,{reason}")
                    return result
                logger.info(f"Saving model to {tp_weight_cache_dir}")
                os.makedirs(tp_weight_cache_dir, exist_ok=True)
                _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
            return result

        return wrapper

    return decorator
def measure_time(prefix: str = "Model loading"):
    """
    Decorator factory that logs how long each call of the wrapped function took.

    Args:
        prefix: text placed before "took N seconds" in the info log line.

    Returns:
        A decorator. The wrapped function keeps the original's metadata
        (``functools.wraps``) and returns its result unchanged. Nothing is logged
        if the call raises.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic, so system clock adjustments cannot produce
            # negative or skewed durations the way time.time() can.
            start = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start
            logger.info(f"{prefix} took {elapsed:.3f} seconds")
            return result

        return wrapper

    return decorator
def load_reordered_experts(model_path: str, key_name: str):
    """
    Load one tensor from a sharded safetensors checkpoint onto the current device.

    Looks ``key_name`` up in ``<model_path>/model.safetensors.index.json``, opens the shard
    the index points to, and copies the tensor to
    ``paddle.framework._current_expected_place()``.

    Args:
        model_path: checkpoint directory containing the index and the shards.
        key_name: full weight name to load.

    Returns:
        The loaded paddle Tensor.

    Raises:
        KeyError: if ``key_name`` is not listed in the index, or the shard the index
            points to does not contain it.
    """
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    with open(index_path, "r") as f:
        weight_map = json.load(f)["weight_map"]
    if key_name not in weight_map:
        raise KeyError(f"{key_name} is not listed in {index_path}")

    safetensor_path = os.path.join(model_path, weight_map[key_name])
    with safe_open(safetensor_path, framework="np", device="cpu") as f:
        if key_name not in f.keys():
            # Previously this fell through and failed with UnboundLocalError on `weight`.
            raise KeyError(f"{key_name} is listed in {index_path} but missing from {safetensor_path}")
        weight = f.get_tensor(key_name)
        # Convert while the file is still open, as the original code did.
        weight = paddle.Tensor(weight, zero_copy=True)
        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
    return weight
def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfig, return_numpy: bool = False):
    """
    Load the expert-parallel (EP) slice of a sharded safetensors checkpoint.

    Loads every non-expert weight listed in ``model.safetensors.index.json``. For every MoE
    layer it also loads the expert weights/scales in this rank's expert range and the
    ``up_gate_proj.activation_scale`` of every expert (needed for EP w4a8).

    When tensor parallelism is also enabled, the class's tensor-parallel split actions are
    applied to loaded weights, except for the local expert keys and, with sequence-parallel
    MoE, the ``self_attn.o_proj`` weight/bias of MoE layers.

    Args:
        cls: model class providing ``_get_tensor_parallel_mappings``.
        model_path: directory containing the index file and the safetensors shards.
        fd_config: FastDeploy configuration.
        return_numpy: if True keep the numpy arrays read from disk; otherwise convert each to
            a paddle Tensor on ``paddle.framework._current_expected_place()``.

    Returns:
        dict mapping weight name to the loaded array/tensor.
    """
    from itertools import chain

    parallel_config = fd_config.parallel_config
    model_config = fd_config.model_config

    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    filtered_map = {k: v for k, v in weight_list.items() if ".experts." not in k}
    num_local_ffn_keys = []

    def get_expert_ranges(fd_config):
        """
        Generate the expert index range(s) owned by this rank.

        Args:
            fd_config: FastDeploy configuration object.

        Returns:
            If ``moe_num_experts`` is a list, a chain of two ranges:
              1. ``[num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)``
              2. the same range shifted by ``moe_num_experts[0]``.
            Otherwise only the first range.
        """
        base_range = range(
            fd_config.parallel_config.num_experts_start_offset,
            fd_config.parallel_config.num_experts_start_offset + fd_config.parallel_config.num_experts_per_rank,
        )
        if isinstance(fd_config.model_config.moe_num_experts, list):
            return chain(
                base_range,
                range(
                    base_range.start + fd_config.model_config.moe_num_experts[0],
                    base_range.stop + fd_config.model_config.moe_num_experts[0],
                ),
            )
        return base_range

    prefix_layer_name = (
        "mtp_block" if getattr(fd_config.speculative_config, "model_type", "main") == "mtp" else "layers"
    )
    moe_num_experts = model_config.moe_num_experts
    if isinstance(moe_num_experts, list):
        moe_num_experts = moe_num_experts[0]

    for i in range(model_config.moe_layer_start_index, model_config.num_hidden_layers):
        experts_prefix = f"ernie.{prefix_layer_name}.{i}.mlp.experts"
        for j in get_expert_ranges(fd_config):
            # Map redundant expert IDs back to actual expert IDs for weight loading
            j = j % moe_num_experts
            expert_prefix = f"{experts_prefix}.{j}"
            num_local_ffn_keys.extend(
                [
                    f"{expert_prefix}.up_gate_proj.weight",
                    f"{expert_prefix}.down_proj.weight",
                    f"{expert_prefix}.up_gate_proj.quant_weight",
                    f"{expert_prefix}.down_proj.quant_weight",
                    f"{expert_prefix}.up_gate_proj.weight_scale",
                    f"{expert_prefix}.down_proj.weight_scale",
                    f"{expert_prefix}.down_proj.activation_scale",
                    # Single up_gate_proj.activation_scale shared by all experts of the layer.
                    # Uses prefix_layer_name like every other key (was hard-coded "layers").
                    f"{experts_prefix}.up_gate_proj.activation_scale",
                ]
            )
        # For EP w4a8, every expert's up_gate_proj activation_scale is needed.
        num_local_ffn_keys.extend(
            f"{experts_prefix}.{j}.up_gate_proj.activation_scale" for j in range(moe_num_experts)
        )

    for k in num_local_ffn_keys:
        if k in weight_list:
            filtered_map[k] = weight_list[k]

    tp_enabled = parallel_config.tensor_parallel_size > 1
    if tp_enabled:
        # Keys that must not be split across TP ranks; a set keeps lookups O(1).
        no_tp_action_keys = set(num_local_ffn_keys)
        if parallel_config.use_sequence_parallel_moe:
            for i in range(model_config.moe_layer_start_index, model_config.num_hidden_layers):
                for k in (
                    f"ernie.{prefix_layer_name}.{i}.self_attn.o_proj.weight",
                    f"ernie.{prefix_layer_name}.{i}.self_attn.o_proj.bias",
                ):
                    if k in weight_list:
                        no_tp_action_keys.add(k)
        tp_actions = cls._get_tensor_parallel_mappings(model_config.pretrained_config)
        new_actions = {k: v for k, v in tp_actions.items() if k not in no_tp_action_keys}

    # Group wanted keys by shard so each file is opened and scanned once, instead of
    # re-checking every wanted key against every file.
    keys_by_file = {}
    for k, file_name in filtered_map.items():
        keys_by_file.setdefault(file_name, []).append(k)

    state_dict = {}
    for safetensor_path, keys in tqdm(keys_by_file.items(), desc="Loading safetensor files", unit="file"):
        with safe_open(
            os.path.join(model_path, safetensor_path),
            framework="np",
            device="cpu",
        ) as f:
            keys_in_file = set(f.keys())
            for k in keys:
                if k not in keys_in_file:
                    continue
                weight = f.get_tensor(k)
                if tp_enabled and k in new_actions:
                    weight = new_actions[k](weight)
                if not return_numpy:
                    weight = paddle.Tensor(weight, zero_copy=True)
                    weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                state_dict[k] = weight
    return state_dict
def kv_cache_scale_iterator(kv_cache_scale_json_path):
    """
    Yield ``(name, tensor)`` for each entry of a kv-cache scale JSON file.

    The JSON object maps names to scale values. Each value is converted to a tensor of
    ``paddle.get_default_dtype()`` and multiplied by 448.0 (presumably the float8_e4m3fn
    maximum; confirm) before being yielded.
    """
    with open(kv_cache_scale_json_path, "r") as f:
        scales_by_name = json.load(f)
    for name, raw_scale in scales_by_name.items():
        scaled = paddle.to_tensor(raw_scale, dtype=paddle.get_default_dtype()) * 448.0
        yield name, scaled
def safetensors_weights_iterator(safe_tensor_list: list[str]):
    """
    Yield ``(name, tensor)`` for every tensor of every safetensors file, file by file.

    Files are read in list order with ``safe_open(framework="paddle", device="cpu")``;
    within a file, tensors follow the order of the file's ``keys()``.
    """
    for shard_path in tqdm(safe_tensor_list, desc="Loading safetensors checkpoint shards"):
        with safe_open(shard_path, framework="paddle", device="cpu") as reader:
            yield from ((tensor_name, reader.get_tensor(tensor_name)) for tensor_name in reader.keys())
def multi_thread_safetensors_weights_iterator(safe_tensor_list, max_workers: int = 4, disable_mmap: bool = False):
    """
    Iterate over safetensors weights using multi-threaded loading.

    All shards are submitted to a thread pool up front and yielded in completion order
    (not list order). Consumed shards are not retained for the rest of the iteration,
    so already-yielded shard dicts can be freed while later shards load.

    Args:
        safe_tensor_list: List of safetensors file paths to load.
        max_workers: Maximum number of threads for concurrent loading. Defaults to 4.
        disable_mmap: If True, read each file into memory and parse the bytes instead of
            using ``load_file`` (useful when mmap is unsupported or problematic).

    Yields:
        Tuple[str, paddle.Tensor]: Weight name and corresponding tensor.

    Raises:
        Any exception raised while loading a shard is re-raised when that shard is reached.
    """
    try:
        enable_tqdm = dist.get_rank() == 0
    except Exception:
        # Best effort: if the distributed environment is unavailable, show progress anyway.
        enable_tqdm = True

    def _load_file(st_file: str):
        if disable_mmap:
            with open(st_file, "rb") as f:
                return safetensors.paddle.load(f.read())
        return safetensors.paddle.load_file(st_file, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_load_file, st_file) for st_file in safe_tensor_list]
        completed = concurrent.futures.as_completed(futures)
        # Drop our list so as_completed holds the only references to the futures; otherwise
        # every finished Future (and the shard dict it holds) stays alive until the end.
        del futures
        for future in tqdm(
            completed,
            total=len(safe_tensor_list),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm,
        ):
            state_dict = future.result()
            del future
            yield from state_dict.items()
            del state_dict
def safetensors_weights_iterator_ordered(ordered_weight_map: dict[str, str]):
    """
    Yield ``(name, tensor)`` following the iteration order of ``ordered_weight_map``.

    ``ordered_weight_map`` maps each weight name to the file that contains it. At most one
    file is open at a time: whenever the file changes, the previous handle is closed via
    the ExitStack before the next one is opened, and the last handle is closed when the
    generator finishes or is closed.
    """
    entries = tqdm(ordered_weight_map.items(), desc="Loading safetensors weights")
    with ExitStack() as stack:
        active_path, reader = None, None
        for name, path in entries:
            if path != active_path:
                stack.close()
                reader = stack.enter_context(safe_open(path, framework="paddle", device="cpu"))
                active_path = path
            yield name, reader.get_tensor(name)
def fast_weights_iterator(safe_tensor_list: list[str]):
    """
    paddleformers' iterator for safetensors.

    Yields ``(name, slice)`` for every tensor of every file, where the slice comes from
    ``fast_safe_open(..., framework="np").get_slice(name)`` and is produced while the file
    is still open. NOTE(review): presumably a lazily materialized slice; confirm in paddleformers.
    """
    for path in tqdm(safe_tensor_list, desc="Loading safetensors checkpoint shards"):
        with fast_safe_open(path, framework="np") as handle:
            tensor_names = handle.keys()
            for tensor_name in tensor_names:
                yield tensor_name, handle.get_slice(tensor_name)
def load_pre_sharded_checkpoint(model_path: str, local_rank: int):
    """
    Load one rank's pre-sharded checkpoint from ``<model_path>/rank<local_rank>``.

    Every weight produced by ``get_weight_iterator`` for that directory, including any
    ``kv_cache_scale.json`` entries, is cloned into the returned dict. A later duplicate
    name overwrites an earlier one.
    """
    rank_dir = os.path.join(model_path, f"rank{local_rank}")
    return {name: weight.clone() for name, weight in get_weight_iterator(rank_dir)}
def get_all_weights_file(model_path: str):
    """
    Discover the checkpoint files under ``model_path`` and how they should be iterated.

    Layouts are checked in this order:
      1. One or more ``*.pdparams`` files (``scheduler.pdparams`` is ignored).
      2. A single ``model.safetensors`` file.
      3. Sharded safetensors described by ``model.safetensors.index.json``.

    Returns:
        tuple ``(files_list, ordered_weight_map, use_safetensors, is_layers_are_grouped)``:
          - files_list: paths (str) of the weight files to read.
          - ordered_weight_map: weight name -> full path of the file containing it, in
            natural order of weight names; empty for pdparams.
          - use_safetensors: False only for the pdparams layout.
          - is_layers_are_grouped: True for a single safetensors file; for an index, True
            when the index keys keep each layer contiguous and the file names are in
            natural order; False for pdparams.

    Raises:
        FileNotFoundError: if no pdparams, ``model.safetensors`` or index file is found.
    """
    model_path = Path(model_path)

    pdparams_files = [str(file) for file in model_path.glob("*.pdparams") if file.name != "scheduler.pdparams"]
    if pdparams_files:
        # pdparams shards are loaded whole, so their order does not matter.
        return pdparams_files, {}, False, False

    safe_model_path = model_path / "model.safetensors"
    if safe_model_path.exists():
        with safe_open(safe_model_path, framework="np", device="cpu") as f:
            key_name_list = sorted(f.keys(), key=natural_key)
        # Full path (like the sharded branch) so the map can be opened regardless of the CWD.
        single_file = str(safe_model_path)
        ordered_weight_map = {key: single_file for key in key_name_list}
        return [single_file], ordered_weight_map, True, True

    index_file = model_path / "model.safetensors.index.json"
    with index_file.open("r") as f:
        weight_map = json.load(f)["weight_map"]
    is_layers_are_grouped = layers_are_grouped(list(weight_map.keys())) and values_are_naturally_ordered(
        list(weight_map.values())
    )
    ordered_weight_map = {key: str(model_path / weight_map[key]) for key in sorted(weight_map, key=natural_key)}
    files_list = sorted({str(model_path / file_name) for file_name in weight_map.values()})
    return files_list, ordered_weight_map, True, is_layers_are_grouped
def deal_state_dict(state_dict):
    """
    Move every initialized, not-yet-pinned tensor in ``state_dict`` to CUDA pinned memory in place.

    For each such tensor, a copy is made on ``paddle.CUDAPinnedPlace()``, the original
    tensor's storage is cleared, and it is made to share the pinned copy's data, so the
    dict keeps the same tensor objects. Uninitialized tensors and tensors already on a
    pinned place are left untouched.
    """
    pinned_place = paddle.CUDAPinnedPlace()
    for tensor in state_dict.values():
        # Check initialization first so .place is only read on initialized tensors.
        if not tensor._is_initialized() or isinstance(tensor.place, paddle.CUDAPinnedPlace):
            continue
        pinned_copy = tensor._copy_to(pinned_place, True)
        pinned_storage = pinned_copy.value().get_tensor()
        own_storage = tensor.value().get_tensor()
        own_storage._clear()
        own_storage._share_data_with(pinned_storage)
def load_kv_cache_scale(fd_config, state_dict):
    """
    Add kv-cache k/v activation scales from a JSON file into ``state_dict`` (in place).

    Reads ``fd_config.model_config.kv_cache_quant_scale_path``. For every hidden layer it
    looks up ``ernie.<prefix>.<i>.self_attn.cachek_matmul.activation_scale`` and the
    matching ``cachev_matmul`` key, converts each to a tensor of
    ``paddle.get_default_dtype()`` and stores it multiplied by 448.0. If the file does not
    exist, a warning is logged and nothing is added.

    Raises:
        KeyError: if the JSON lacks a scale for some layer.
    """
    file_path = fd_config.model_config.kv_cache_quant_scale_path
    prefix_layer_name = fd_config.model_config.prefix_layer_name
    if not os.path.exists(file_path):
        logger.warning(f"No kv_cache_scale.json found at {file_path}, skipping...")
        return

    with open(file_path, "r") as f:
        data = json.load(f)
    for i in range(fd_config.model_config.num_hidden_layers):
        for cache_name in ("cachek_matmul", "cachev_matmul"):
            scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.{cache_name}.activation_scale"
            scale_tensor = paddle.to_tensor(data[scale_name], dtype=paddle.get_default_dtype())
            state_dict[scale_name] = scale_tensor * 448.0
        logger.info(f"Loaded kv cache scales for layer {i}.")
def load_composite_checkpoint(
    model_path: str,
    cls: PretrainedModel,
    fd_config: FDConfig,
    return_numpy=True,
):
    """
    Load a model state dict, choosing the strategy from the config and directory layout.

    Strategies, in priority order:
      1. Expert Parallel (EP): ``parallel_config.use_ep`` -> ``load_ep_checkpoint``.
      2. Pre-sharded: ``model_path`` has more than one ``rank*`` sub-directory; their count
         must equal the TP size and this rank's directory is loaded via
         ``load_pre_sharded_checkpoint``.
      3. Tensor Parallel (TP): otherwise ``load_tp_checkpoint`` splits the full checkpoint.

    If the quant config's ``kv_cache_quant_type`` is ``"float8_e4m3fn"``, kv-cache scales
    are added via ``load_kv_cache_scale``.

    Returns:
        The loaded state dict.

    Raises:
        ValueError: if the rank-directory count does not match the TP size, or nothing was loaded.
    """
    parallel_config = fd_config.parallel_config
    if parallel_config.use_ep:
        # NOTE(review): return_numpy is forced to True here regardless of the argument; confirm intended.
        state_dict = load_ep_checkpoint(cls, model_path, fd_config, return_numpy=True)
    else:
        rank_dirs = [
            name
            for name in os.listdir(model_path)
            if name.startswith("rank") and os.path.isdir(os.path.join(model_path, name))
        ]
        if len(rank_dirs) <= 1:
            fd_config.model_config.pretrained_config.use_sequence_parallel_moe = (
                parallel_config.use_sequence_parallel_moe
            )
            # NOTE: for very big model, cpu will be out of memory
            state_dict = load_tp_checkpoint(
                model_path,
                cls,
                fd_config.model_config.pretrained_config,
                return_numpy=return_numpy,
            )
        elif parallel_config.tensor_parallel_size != len(rank_dirs):
            raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
        else:
            state_dict = load_pre_sharded_checkpoint(model_path, parallel_config.tensor_parallel_rank)

    if not state_dict:
        raise ValueError("weight not found in state_dict !")

    kv_cache_quant_type = getattr(fd_config.quant_config, "kv_cache_quant_type", None)
    if kv_cache_quant_type == "float8_e4m3fn":
        load_kv_cache_scale(fd_config, state_dict)
    return state_dict