-
Notifications
You must be signed in to change notification settings - Fork 327
Expand file tree
/
Copy pathhf_load_utils.py
More file actions
executable file
·73 lines (64 loc) · 2.91 KB
/
hf_load_utils.py
File metadata and controls
executable file
·73 lines (64 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import os
import gc
from safetensors import safe_open
from tqdm import tqdm
import lightllm.utils.petrel_helper as utils
from lightllm.utils.dist_utils import get_current_device_id
from queue import Queue
from threading import Thread
def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
    """Load one checkpoint shard from ``weight_dir`` and hand it to the given layers.

    Used as the per-file worker for the thread pool in ``load_hf_weights``.

    Args:
        file_: Shard file name, joined onto ``weight_dir``.
        use_safetensors: If True, read the shard with ``safetensors.safe_open``
            onto the CPU; otherwise load it via ``PetrelHelper.load`` with
            ``map_location="cpu"``.
        pre_post_layer: Optional object whose ``load_hf_weights`` is called
            with the loaded weights.
        transformer_layer_list: Optional iterable of objects; each one's
            ``load_hf_weights`` is called with the same loaded weights.
        weight_dir: Directory that contains ``file_``.
    """
    # Workaround: when loading with multiple threads, the CUDA device inside
    # each worker thread switches back to device 0. Re-select this process's
    # device here so that bug does not occur.
    torch.cuda.set_device(get_current_device_id())

    path = os.path.join(weight_dir, file_)
    if use_safetensors:
        # Use safe_open as a context manager so the underlying file handle is
        # released as soon as every tensor has been read out.
        with safe_open(path, "pt", "cpu") as handle:
            weights = {name: handle.get_tensor(name) for name in handle.keys()}
    else:
        weights = utils.PetrelHelper.load(path, map_location="cpu")

    if pre_post_layer is not None:
        pre_post_layer.load_hf_weights(weights)
    if transformer_layer_list is not None:
        for layer in transformer_layer_list:
            layer.load_hf_weights(weights)

    # Drop this thread's reference to the shard and force a collection; the
    # intent is presumably to keep host memory bounded across shards.
    del weights
    gc.collect()
def _select_weight_files(files):
    """Choose which checkpoint files to load from a directory listing.

    Returns a ``(use_safetensors, candidate_files)`` pair. ``.safetensors``
    files are preferred; ``.bin`` files are used only when no ``.safetensors``
    file is present. ``candidate_files`` may be empty.
    """
    # Materialize once so both passes see the same entries even if the
    # listing is a one-shot iterator.
    files = list(files)
    safetensors_files = [name for name in files if name.endswith(".safetensors")]
    if safetensors_files:
        return True, safetensors_files
    return False, [name for name in files if name.endswith(".bin")]


def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
    """Load weights into ``pre_post_layer`` and/or ``transformer_layer_list``.

    Weights come from ``weight_dict`` when it is truthy; otherwise every shard
    file in ``weight_dir`` is loaded through ``load_func`` on a thread pool
    whose size is read from the ``LOADWORKER`` environment variable
    (default 1, clamped to at least 1).

    Args:
        data_type: Expected layer dtype, either a ``torch.dtype`` or a string.
            The string ``"fp16"`` means ``torch.float16``; any other string
            means ``torch.float32``.
        weight_dir: Directory listed with ``PetrelHelper.list`` to find shards.
        pre_post_layer: Optional object exposing ``data_type_`` and
            ``load_hf_weights``.
        transformer_layer_list: Optional list of such objects. Only the first
            element's ``data_type_`` is checked.
        weight_dict: Optional in-memory weights. When truthy, no files are read.

    Raises:
        AssertionError: If a layer's ``data_type_`` differs from ``data_type``,
            or if ``weight_dir`` contains neither ``.safetensors`` nor ``.bin``
            files.
    """
    if isinstance(data_type, str):
        data_type = torch.float16 if data_type == "fp16" else torch.float32
    if pre_post_layer is not None:
        assert pre_post_layer.data_type_ == data_type, "type is not right"
    # Truthiness rather than `is not None`: an empty list has no element to
    # compare against and would otherwise raise IndexError.
    if transformer_layer_list:
        assert transformer_layer_list[0].data_type_ == data_type, "type is not right"

    if weight_dict:
        # In-memory weights: hand them to the layers directly, no file I/O.
        if pre_post_layer is not None:
            pre_post_layer.load_hf_weights(weight_dict)
        if transformer_layer_list is not None:
            for layer in transformer_layer_list:
                layer.load_hf_weights(weight_dict)
        return

    files = utils.PetrelHelper.list(weight_dir, extension="all")
    use_safetensors, candidate_files = _select_weight_files(files)
    assert candidate_files, "can only support pytorch tensor and safetensors format for weights."

    from functools import partial
    from multiprocessing.pool import ThreadPool as Pool

    partial_func = partial(
        load_func,
        use_safetensors=use_safetensors,
        pre_post_layer=pre_post_layer,
        transformer_layer_list=transformer_layer_list,
        weight_dir=weight_dir,
    )
    # ThreadPool raises ValueError for fewer than one worker, so clamp.
    worker = max(1, int(os.environ.get("LOADWORKER", 1)))
    with Pool(worker) as p:
        iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1)
        desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers"
        # Draining the iterator waits for every shard and re-raises any
        # exception a worker hit.
        for _ in tqdm(iterator, total=len(candidate_files), desc=desc_str):
            pass