-
Notifications
You must be signed in to change notification settings - Fork 67
Expand file tree
/
Copy pathjiuge_weights_loader.py
More file actions
564 lines (505 loc) · 19.8 KB
/
jiuge_weights_loader.py
File metadata and controls
564 lines (505 loc) · 19.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
# File path: icinfer/engine/weights_loader.py
import os
import json
import torch
import transformers
from typing import Tuple
import math
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import ctypes
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import math
import torch
import transformers
from icinfer.engine.libinfinicore_infer import (
JiugeMetaCStruct,
JiugeWeightsCStruct,
DataType,
create_jiuge_model,
DeviceType,
)
from icinfer.config import Config
import logging
# Configure the root logger when this module is imported (INFO level with
# timestamp / logger name / level / message), then create the module logger
# used by the loaders below.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class LlamaWeightsNaming:
    """Map logical weight roles to Llama-style ``state_dict`` keys.

    Every method returns the key string for one tensor; methods taking ``i``
    return the key for transformer layer ``i``.
    """

    def input_embd(self):
        """Key of the token-embedding matrix."""
        return "model.embed_tokens.weight"

    def output_norm(self):
        """Key of the final normalization weight."""
        return "model.norm.weight"

    def output_embd(self):
        """Key of the output (LM head) projection."""
        return "lm_head.weight"

    def attn_norm(self, i):
        """Key of the pre-attention normalization weight of layer ``i``."""
        return f"model.layers.{i}.input_layernorm.weight"

    def attn_q(self, i):
        """Key of the query projection weight of layer ``i``."""
        return f"model.layers.{i}.self_attn.q_proj.weight"

    def attn_k(self, i):
        """Key of the key projection weight of layer ``i``."""
        return f"model.layers.{i}.self_attn.k_proj.weight"

    def attn_v(self, i):
        """Key of the value projection weight of layer ``i``."""
        return f"model.layers.{i}.self_attn.v_proj.weight"

    def attn_o(self, i):
        """Key of the attention output projection weight of layer ``i``."""
        return f"model.layers.{i}.self_attn.o_proj.weight"

    def attn_q_b(self, i):
        """Key of the query projection bias of layer ``i``."""
        return f"model.layers.{i}.self_attn.q_proj.bias"

    def attn_k_b(self, i):
        """Key of the key projection bias of layer ``i``."""
        return f"model.layers.{i}.self_attn.k_proj.bias"

    def attn_v_b(self, i):
        """Key of the value projection bias of layer ``i``."""
        return f"model.layers.{i}.self_attn.v_proj.bias"

    def attn_q_norm(self, i):
        """Key of the per-head query normalization weight of layer ``i``."""
        return f"model.layers.{i}.self_attn.q_norm.weight"

    def attn_k_norm(self, i):
        """Key of the per-head key normalization weight of layer ``i``."""
        return f"model.layers.{i}.self_attn.k_norm.weight"

    def ffn_norm(self, i):
        """Key of the post-attention normalization weight of layer ``i``."""
        return f"model.layers.{i}.post_attention_layernorm.weight"

    def gate(self, i):
        """Key of the MLP gate projection weight of layer ``i``."""
        return f"model.layers.{i}.mlp.gate_proj.weight"

    def up(self, i):
        """Key of the MLP up projection weight of layer ``i``."""
        return f"model.layers.{i}.mlp.up_proj.weight"

    def down(self, i):
        """Key of the MLP down projection weight of layer ``i``."""
        return f"model.layers.{i}.mlp.down_proj.weight"

    # Declared static so the check works both as ``LlamaWeightsNaming.match(sd)``
    # and on an instance; without the decorator an instance call would pass the
    # instance as ``state_dict`` and fail with a TypeError.
    @staticmethod
    def match(state_dict):
        """Return True when ``state_dict`` uses this naming scheme.

        The check looks for the final norm weight and the first layer's query
        projection weight.
        """
        return (
            "model.norm.weight" in state_dict
            and "model.layers.0.self_attn.q_proj.weight" in state_dict
        )
class JiugeMetaFromLlama(JiugeMetaCStruct):
    """Jiuge model metadata struct populated from a Llama-style HF config.

    Besides the fields handed to the base-class constructor, the instance
    carries Python-side attributes read by the weight loader: ``scale_input``,
    ``scale_output``, ``scale_o``, ``scale_down`` and ``torch_dtype_logits``.
    """

    def __init__(self, config, dtype=torch.float16, max_tokens=None):
        """Build the metadata from ``config``.

        Args:
            config: Model configuration exposing the usual Hugging Face
                attributes (``hidden_size``, ``num_attention_heads``, ...)
                plus ``kvcache_block_size``.
            dtype: torch dtype used for logits. float16, float32 and bfloat16
                are supported; anything else falls back to float16 with a
                warning.
            max_tokens: Context length to advertise; when None,
                ``config.max_position_embeddings`` is used.
        """
        if dtype == torch.float16:
            dt_ = DataType.INFINI_DTYPE_F16
        elif dtype == torch.float32:
            dt_ = DataType.INFINI_DTYPE_F32
        elif dtype == torch.bfloat16:
            dt_ = DataType.INFINI_DTYPE_BF16
        else:
            # Keep the torch dtype and the enum in sync: the weight loader
            # casts tensors to ``torch_dtype_logits`` while the struct
            # advertises ``dt_logits``, so they must describe the same type.
            logger.warning(
                "Unsupported logits dtype %s; falling back to torch.float16", dtype
            )
            dtype = torch.float16
            dt_ = DataType.INFINI_DTYPE_F16

        # Scaling factors default to identity and are only overridden for
        # fm9g / minicpm configs that carry all three scaling attributes.
        self.scale_input = 1.0
        self.scale_output = 1.0
        self.scale_o = 1.0
        self.scale_down = 1.0
        if (
            config.model_type in ["fm9g", "minicpm"]
            and hasattr(config, "scale_emb")
            and hasattr(config, "scale_depth")
            and hasattr(config, "dim_model_base")
        ):
            self.scale_input = config.scale_emb
            self.scale_output = config.hidden_size // config.dim_model_base
            self.scale_o = config.scale_depth / math.sqrt(config.num_hidden_layers)
            self.scale_down = config.scale_depth / math.sqrt(config.num_hidden_layers)

        dim_model_base = (
            config.dim_model_base
            if hasattr(config, "dim_model_base")
            else config.hidden_size
        )

        # Load longrope configuration
        rope_type = 0  # 0 = standard, 1 = longrope
        original_max_position_embeddings = 0
        short_factor_ptr = None
        long_factor_ptr = None
        # The ctypes arrays are stored on self so the memory behind the
        # pointers passed to the struct is not garbage collected.
        self._short_factor_array = None
        self._long_factor_array = None

        # Handle both dict and object config
        if hasattr(config, "rope_scaling"):
            rope_scaling = config.rope_scaling
        elif isinstance(config, dict) and "rope_scaling" in config:
            rope_scaling = config["rope_scaling"]
        else:
            rope_scaling = {}

        # ``rope_scaling`` may be None or another non-dict value; those are
        # treated as "no scaling".
        if isinstance(rope_scaling, dict):
            rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get(
                "type", ""
            )
            if rope_scaling_type == "longrope":
                rope_type = 1
                if isinstance(config, dict):
                    fallback_max_pos = config.get(
                        "original_max_position_embeddings", 0
                    )
                else:
                    fallback_max_pos = getattr(
                        config, "original_max_position_embeddings", 0
                    )
                original_max_position_embeddings = rope_scaling.get(
                    "original_max_position_embeddings", fallback_max_pos
                )

                short_factor_list = rope_scaling.get("short_factor", [])
                long_factor_list = rope_scaling.get("long_factor", [])
                if short_factor_list and long_factor_list:
                    # Convert to ctypes arrays; one factor per rotary pair,
                    # i.e. half of the head dimension.
                    half_dh = (config.hidden_size // config.num_attention_heads) // 2
                    if (
                        len(short_factor_list) == half_dh
                        and len(long_factor_list) == half_dh
                    ):
                        self._short_factor_array = (c_float * half_dh)(
                            *short_factor_list
                        )
                        self._long_factor_array = (c_float * half_dh)(
                            *long_factor_list
                        )
                        short_factor_ptr = ctypes.cast(
                            self._short_factor_array, POINTER(c_float)
                        )
                        long_factor_ptr = ctypes.cast(
                            self._long_factor_array, POINTER(c_float)
                        )
                    else:
                        # NOTE(review): rope_type stays 1 while both factor
                        # pointers stay NULL here — confirm the native side
                        # tolerates that combination.
                        logger.warning(
                            f"Longrope factor arrays have wrong length: "
                            f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}"
                        )

        super().__init__(
            dt_logits=dt_,
            nlayer=config.num_hidden_layers,
            d=config.hidden_size,
            nh=config.num_attention_heads,
            nkvh=(
                config.num_key_value_heads
                if hasattr(config, "num_key_value_heads")
                else config.num_attention_heads
            ),
            dh=config.hidden_size // config.num_attention_heads,
            di=config.intermediate_size,
            dctx=(config.max_position_embeddings if max_tokens is None else max_tokens),
            dvoc=config.vocab_size,
            kvcache_block_size=config.kvcache_block_size,
            dim_model_base=dim_model_base,
            epsilon=config.rms_norm_eps,
            theta=(config.rope_theta if hasattr(config, "rope_theta") else 100000.0),
            end_token=2,
            rope_type=rope_type,
            original_max_position_embeddings=original_max_position_embeddings,
            short_factor=short_factor_ptr,
            long_factor=long_factor_ptr,
        )
        self.torch_dtype_logits = dtype
class JiugeWeightsImpl(JiugeWeightsCStruct):
    """Host-side weight table for the Jiuge model.

    Pulls tensors out of ``state_dict``, casts and rearranges them, keeps the
    resulting tensors alive as Python attributes, and stores their raw
    addresses (``data_ptr()``) either directly or in ``c_void_p`` arrays with
    one entry per layer.
    NOTE(review): the pointer attributes (``input_embd``, ``attn_qkv``, ...)
    are presumably fields declared by ``JiugeWeightsCStruct`` — confirm.
    """

    def __init__(
        self,
        meta,
        naming,
        state_dict,
        torch_dt_mat=torch.float16,
        torch_dt_norm=torch.float32,
        ndev=1,
        transpose_weight=True,
    ):
        """Prepare all weights for ``ndev`` devices.

        Args:
            meta: Metadata object providing ``nlayer``, ``nh``, ``nkvh``,
                ``dh``, ``d``, ``di``, the ``scale_*`` factors and
                ``torch_dtype_logits``.
            naming: Object whose methods return the ``state_dict`` key for
                each weight role.
            state_dict: Mapping from key to torch tensor.
            torch_dt_mat: torch dtype for projection matrices.
            torch_dt_norm: torch dtype for normalization weights.
            ndev: Number of devices the weights are partitioned for.
            transpose_weight: Selects between the two storage layouts built
                below for the linear weights.

        Raises:
            ValueError: If ``torch_dt_mat`` or ``torch_dt_norm`` is not
                float16, float32 or bfloat16.
            AssertionError: If the head counts or ``di`` are not divisible as
                required for the requested partitioning.
        """
        nlayer = meta.nlayer
        nh = meta.nh
        nkvh = meta.nkvh
        dh = meta.dh
        d = meta.d
        di = meta.di
        scale_input = meta.scale_input
        # Read here but not used anywhere else in this constructor.
        scale_output = meta.scale_output
        scale_o = meta.scale_o
        scale_down = meta.scale_down
        # Heads and the intermediate dimension must split evenly across devices.
        assert nh % nkvh == 0
        assert nh % ndev == 0
        assert nkvh % ndev == 0
        assert di % ndev == 0
        torch_dt_logits = meta.torch_dtype_logits
        # Translate the torch dtypes into the backend's dtype enum.
        if torch_dt_mat == torch.float16:
            self.dt_mat = DataType.INFINI_DTYPE_F16
        elif torch_dt_mat == torch.float32:
            self.dt_mat = DataType.INFINI_DTYPE_F32
        elif torch_dt_mat == torch.bfloat16:
            self.dt_mat = DataType.INFINI_DTYPE_BF16
        else:
            raise ValueError("Unsupported proj weight data type")
        if torch_dt_norm == torch.float16:
            self.dt_norm = DataType.INFINI_DTYPE_F16
        elif torch_dt_norm == torch.float32:
            self.dt_norm = DataType.INFINI_DTYPE_F32
        elif torch_dt_norm == torch.bfloat16:
            self.dt_norm = DataType.INFINI_DTYPE_BF16
        else:
            raise ValueError("Unsupported norm weight data type")
        # If one of the two embedding tensors is missing from the checkpoint,
        # the other one is used in its place.
        input_embd_naming = (
            naming.input_embd()
            if naming.input_embd() in state_dict
            else naming.output_embd()
        )
        output_embd_naming = (
            naming.output_embd()
            if naming.output_embd() in state_dict
            else naming.input_embd()
        )
        self.transpose_linear_weights = 1 if transpose_weight else 0
        self.nlayer = nlayer
        # Input embedding, cast to the logits dtype and pre-multiplied by scale_input.
        self.input_embd_tensor = (
            state_dict[input_embd_naming].to(torch_dt_logits) * scale_input
        )
        self.input_embd = self.input_embd_tensor.data_ptr()
        self.output_norm_tensor = (
            state_dict[naming.output_norm()].to(torch_dt_norm)
        )
        self.output_norm = self.output_norm_tensor.data_ptr()
        self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
        if not transpose_weight:
            # Store the transposed matrix as its own contiguous buffer.
            self.output_embd_tensor = self.output_embd_tensor.transpose(
                0, 1
            ).contiguous()
        self.output_embd = self.output_embd_tensor.data_ptr()
        # Per-layer pre-attention norm weights and their pointer array.
        self.attn_norm_tensors = [
            state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
        ]
        self.attn_norm_ptrs = [
            self.attn_norm_tensors[i].data_ptr() for i in range(nlayer)
        ]
        self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs)

        def qkv_slices(_i):
            """Return layer ``_i``'s Q, K, V head slices, grouped per device."""
            # Q and K: view each head's dh rows as [2, dh/2] and swap those two
            # axes, so result row (head, j, t) is original row
            # head*dh + t*(dh/2) + j.
            # NOTE(review): presumably converts a half-split rotary layout into
            # an interleaved one expected by the backend — confirm.
            _Q = (
                state_dict[naming.attn_q(_i)]
                .reshape([nh, 2, dh // 2, d])
                .transpose(1, 2)
            )
            _K = (
                state_dict[naming.attn_k(_i)]
                .reshape([nkvh, 2, dh // 2, d])
                .transpose(1, 2)
            )
            # V is only viewed with the same rank and trailing shape as Q/K so
            # the slices can be concatenated; its row order is unchanged.
            _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d])
            _result = []
            _nh = nh // ndev
            _nkvh = nkvh // ndev
            # For every device append its Q heads, then its K heads, then its V heads.
            for _idev in range(ndev):
                _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :])
                _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :])
                _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
            return _result

        # One fused QKV tensor per layer, concatenated along the head axis.
        self.qkv_tensor = [
            torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer)
        ]
        if not transpose_weight:
            # Regroup as [ndev, rows_per_device, d] and swap the last two axes
            # so each device's block is stored as [d, rows_per_device].
            for i in range(nlayer):
                self.qkv_tensor[i] = (
                    self.qkv_tensor[i]
                    .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d)
                    .transpose(1, 2)
                    .contiguous()
                )
        self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
        self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)

        def qkv_b_slices(_i):
            """Return layer ``_i``'s Q, K, V bias slices, flattened, grouped per device."""
            # Same [2, dh/2] axis swap as the Q/K weights, applied to the biases.
            _QB = (
                state_dict[naming.attn_q_b(_i)]
                .reshape([nh, 2, dh // 2])
                .transpose(1, 2)
            )
            _KB = (
                state_dict[naming.attn_k_b(_i)]
                .reshape([nkvh, 2, dh // 2])
                .transpose(1, 2)
            )
            _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2])
            _result = []
            _nh = nh // ndev
            _nkvh = nkvh // ndev
            for _idev in range(ndev):
                _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten())
                _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten())
                _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten())
            return _result

        # QKV biases are optional; layer 0's query bias decides for all layers.
        if naming.attn_q_b(0) in state_dict:
            self.qkv_b_tensors = [
                torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer)
            ]
            self.qkv_b_tensor_ptrs = [
                self.qkv_b_tensors[i].data_ptr() for i in range(nlayer)
            ]
            self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs)
        else:
            self.attn_qkv_b = None
        # Q/K norm weights are optional as well; layer 0's q_norm key decides.
        # They get the same [2, dh/2] axis swap as the Q/K projections.
        if naming.attn_q_norm(0) in state_dict:
            self.attn_q_norm_tensors = [
                state_dict[naming.attn_q_norm(i)]
                .reshape([2, dh // 2])
                .transpose(0, 1)
                .contiguous()
                .to(torch_dt_norm)
                for i in range(nlayer)
            ]
            self.attn_q_norm_ptrs = [
                self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer)
            ]
            self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs)
            self.attn_k_norm_tensors = [
                state_dict[naming.attn_k_norm(i)]
                .reshape([2, dh // 2])
                .transpose(0, 1)
                .contiguous()
                .to(torch_dt_norm)
                for i in range(nlayer)
            ]
            self.attn_k_norm_ptrs = [
                self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer)
            ]
            self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs)
        else:
            self.attn_q_norm = None
            self.attn_k_norm = None
        # Attention output projection, multiplied by scale_o. With
        # transpose_weight the second axis is split into ndev chunks and the
        # device axis is moved to the front; otherwise the matrix is transposed
        # as a whole.
        self.attn_o_tensor = [
            (
                state_dict[naming.attn_o(i)]
                .to(torch_dt_mat)
                .reshape([d, ndev, nh // ndev * dh])
                .transpose(0, 1)
                .contiguous()
                if transpose_weight
                else state_dict[naming.attn_o(i)]
                .transpose(0, 1)
                .to(torch_dt_mat)
                .contiguous()
            )
            * scale_o
            for i in range(nlayer)
        ]
        self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
        self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs)
        # Per-layer post-attention norm weights and their pointer array.
        self.ffn_norm_tensors = [
            state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
        ]
        self.ffn_norm_ptrs = [
            self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer)
        ]
        self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs)

        def gate_up_slices(_i):
            """Return layer ``_i``'s gate and up row slices, grouped per device."""
            _result = []
            _di = di // ndev
            # For every device append its gate rows followed by its up rows.
            for _idev in range(ndev):
                _start = _idev * _di
                _end = (_idev + 1) * _di
                _result.append(state_dict[naming.gate(_i)][_start:_end, :])
                _result.append(state_dict[naming.up(_i)][_start:_end, :])
            return _result

        # One fused gate/up tensor per layer.
        self.gate_up_tensors = [
            torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer)
        ]
        if not transpose_weight:
            # Same per-device regrouping and axis swap as for the fused QKV tensor.
            for i in range(nlayer):
                self.gate_up_tensors[i] = (
                    self.gate_up_tensors[i]
                    .reshape(ndev, 2 * di // ndev, d)
                    .transpose(1, 2)
                    .contiguous()
                )
        self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
        self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)
        # MLP down projection, multiplied by scale_down; laid out like attn_o.
        self.ffn_down_tensor = [
            (
                state_dict[naming.down(i)]
                .to(torch_dt_mat)
                .reshape([d, ndev, di // ndev])
                .transpose(0, 1)
                .contiguous()
                if transpose_weight
                else state_dict[naming.down(i)]
                .transpose(0, 1)
                .to(torch_dt_mat)
                .contiguous()
            )
            * scale_down
            for i in range(nlayer)
        ]
        self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
        self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs)
def load_weights_to_cpu(
    config: Config, device: DeviceType
) -> Tuple[JiugeMetaCStruct, JiugeWeightsImpl]:
    """Load model weights on the host and pack them into C-compatible structs.

    Reuses the weight-loading logic of the old infiniinfer: the checkpoint is
    read on the CPU and converted into the metadata and weight structs.

    Args:
        config: Engine configuration; ``max_model_len``, ``model_path``,
            ``tensor_parallel_size`` and ``hf_config`` are read.
        device: Target device type; it only selects the linear-weight layout.

    Returns:
        A ``(meta, weights)`` tuple.

    Raises:
        ValueError: If ``hf_config.model_type`` is not supported, or a
            ``qwen2`` checkpoint does not use the Llama weight naming.
    """

    def load_all_safetensors_from_dir(dir_path_: str):
        """Merge every ``*.safetensors`` file in ``dir_path_`` into one dict."""
        tensors_ = {}
        for file in sorted(Path(dir_path_).glob("*.safetensors")):
            # The context manager releases each shard's file handle once its
            # tensors have been read.
            with safetensors.safe_open(file, "pt") as data_:
                for name_ in data_.keys():
                    tensors_[name_] = data_.get_tensor(name_)
        return tensors_

    max_tokens = config.max_model_len
    model_dir_path = config.model_path
    ndev = config.tensor_parallel_size
    hf_config = config.hf_config
    model_type = hf_config.model_type

    # Model types loaded through transformers; they differ only in the class
    # used for ``from_pretrained``.
    hf_model_classes = {
        "llama": transformers.LlamaForCausalLM,
        "fm9g": transformers.AutoModelForCausalLM,
        "fm9g7b": transformers.AutoModelForCausalLM,
    }

    print("Loading model weights to host...")
    load_start_time = time.time()
    transpose_weight = (
        device != DeviceType.DEVICE_TYPE_ASCEND
    )  # y = xW is faster than y=xW^T on Ascend

    if model_type in hf_model_classes:
        logger.info(f"{model_type} load start.")
        model = hf_model_classes[model_type].from_pretrained(
            model_dir_path,
            device_map="cpu",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        logger.info(f"load over.")
        state_dict = model.state_dict()
    elif model_type == "qwen2":
        state_dict = load_all_safetensors_from_dir(model_dir_path)
        # Fail with a clear error instead of reaching the return statement
        # with ``meta`` / ``weights`` unbound.
        if not LlamaWeightsNaming.match(state_dict):
            raise ValueError("Unsupported weight naming")
    else:
        raise ValueError("Unsupported model architecture")

    # Taken on every path so the timing report below never hits an unbound name.
    load_statets_time = time.time()

    meta = JiugeMetaFromLlama(hf_config, max_tokens=max_tokens)
    weights = JiugeWeightsImpl(
        meta,
        LlamaWeightsNaming(),
        state_dict,
        ndev=ndev,
        transpose_weight=transpose_weight,
    )

    load_end_time = time.time()
    logger.info(
        f"Time overall used: {load_end_time - load_start_time:.3f}s, "
        f"load_states_time: {load_statets_time - load_start_time:.3f}s, "
        f"load_weights_impl_time: {load_end_time - load_statets_time:.3f}s"
    )
    logger.info(f"Creating model on {ndev} devices...")
    print("Weights loaded to CPU successfully.")
    return meta, weights
def load_model(config: Config, device: DeviceType):
    """Load the weights on the host and create the native Jiuge model.

    Args:
        config: Engine configuration; ``tensor_parallel_size`` gives the
            number of devices.
        device: Device type forwarded to ``create_jiuge_model``.

    Returns:
        A ``(model, meta)`` tuple: the value returned by
        ``create_jiuge_model`` and the metadata struct.
    """
    ndev = config.tensor_parallel_size
    meta, weights = load_weights_to_cpu(config, device)

    # Device ids are simply 0 .. ndev-1, packed into a C int array.
    dev_ids = (c_int * ndev)(*range(ndev))

    # NOTE(review): ``weights`` is not returned, so the host tensors it owns
    # may be released after this call — presumably the native side has copied
    # them by then; confirm.
    model = create_jiuge_model(byref(meta), byref(weights), device, ndev, dev_ids)
    return model, meta