Skip to content

Commit 78958bc

Browse files
committed
[v1loader]Reduce EB300B model loading time (PaddlePaddle#3700)
* speed up eb45 * update
1 parent d1d063e commit 78958bc

4 files changed

Lines changed: 45 additions & 36 deletions

File tree

fastdeploy/model_executor/layers/linear.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ def __init__(
415415
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
416416
model_format = getattr(param, "model_format", "")
417417
if model_format == "torch":
418+
loaded_weight = get_tensor(loaded_weight)
418419
loaded_weight = loaded_weight.transpose([1, 0])
419420
output_dim = getattr(param, "output_dim", None)
420421
assert output_dim is not None
@@ -446,7 +447,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
446447
shard_offset = self.local_rank * block_size
447448
shard_size = (self.local_rank + 1) * block_size
448449
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
449-
450+
loaded_weight = get_tensor(loaded_weight)
450451
if not param._is_initialized():
451452
param.initialize()
452453
param_shard_size = output_size // 2
@@ -548,6 +549,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
548549
head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
549550
model_format = getattr(param, "model_format", "")
550551
if model_format == "torch":
552+
loaded_weight = get_tensor(loaded_weight)
551553
loaded_weight = loaded_weight.transpose([1, 0])
552554
if loaded_shard_id is None:
553555
# Loaded weight is already fused on disk
@@ -568,12 +570,13 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
568570
# Tensor parallelism splits the weight along the output_dim
569571
if self.nranks != 1:
570572
block_size = self._get_shard_size_mapping(loaded_shard_id)
571-
dim = -1 if output_dim else 0
572573
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
573574
shard_offset = shard_id * block_size
574575
shard_size = (shard_id + 1) * block_size
575576
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
576577

578+
loaded_weight = get_tensor(loaded_weight)
579+
577580
if not param._is_initialized():
578581
param.initialize()
579582

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -176,17 +176,24 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
176176

177177
if shard_id is None:
178178
# 1.gate up fused in disk
179+
model_format = getattr(param, "model_format", "")
180+
is_torch_model = model_format == "torch"
179181
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
180-
shard_offsets = [
181-
# (shard_id, shard_offset, shard_size)
182-
("gate", 0, output_size // 2 * self.tp_size),
183-
("up", output_size // 2 * self.tp_size, output_size // 2 * self.tp_size),
184-
]
185-
for shard_id, shard_offset, shard_size in shard_offsets:
186-
loaded_weight_shard = slice_fn(
187-
loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size
188-
)
189-
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
182+
per_rank = output_size // 2
183+
start = self.tp_rank * per_rank
184+
loaded_weight_shard_gate = slice_fn(
185+
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
186+
)
187+
self._load_gate_up_weight(
188+
param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
189+
)
190+
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
191+
loaded_weight_shard_up = slice_fn(
192+
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
193+
)
194+
self._load_gate_up_weight(
195+
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
196+
)
190197
else:
191198
# 2.gate up splited in disk
192199
assert shard_id in ["gate", "down", "up"]
@@ -198,22 +205,23 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
198205
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
199206
)
200207

201-
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
208+
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
202209
model_format = getattr(param, "model_format", "")
203-
if model_format == "torch":
204-
loaded_weight = loaded_weight.transpose([1, 0])
205-
dim = -1 if shard_dim else 0
206-
if self.tp_size > 1:
210+
is_torch_model = model_format == "torch"
211+
if self.tp_size > 1 and not is_sharded:
212+
tp_shard_dim = is_torch_model ^ shard_dim
213+
weight_dim = -1 if tp_shard_dim else 0
207214
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
208-
size = loaded_weight.shape[dim]
215+
size = loaded_weight.shape[weight_dim]
209216
else:
210-
size = loaded_weight.get_shape()[dim]
217+
size = loaded_weight.get_shape()[weight_dim]
211218
block_size = size // self.tp_size
212219
shard_offset = self.tp_rank * block_size
213220
shard_size = (self.tp_rank + 1) * block_size
214-
loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
215-
221+
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
222+
loaded_weight = get_tensor(loaded_weight)
216223
expert_param = param[expert_id - self.expert_id_offset]
224+
dim = -1 if shard_dim else 0
217225
param_shard_size = expert_param.shape[dim] // 2
218226
if shard_id == "gate":
219227
param_shard_offset = 0
@@ -232,36 +240,35 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_
232240
)
233241

234242
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
235-
if current_platform.is_xpu() or current_platform.is_gcu():
236-
if expert_param.shape != loaded_weight.shape:
237-
loaded_weight = loaded_weight.transpose([1, 0])
243+
if expert_param.shape != loaded_weight.shape:
244+
loaded_weight = loaded_weight.transpose([1, 0])
238245
assert expert_param.shape == loaded_weight.shape, (
239246
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
240247
)
241248
expert_param.copy_(loaded_weight, False)
242249

243250
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
244251
model_format = getattr(param, "model_format", "")
245-
if model_format == "torch":
246-
loaded_weight = loaded_weight.transpose([1, 0])
252+
is_torch_model = model_format == "torch"
247253
if self.tp_size > 1 and shard_dim is not None:
248-
dim = -1 if shard_dim else 0
249-
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
254+
tp_shard_dim = is_torch_model ^ shard_dim
255+
dim = -1 if tp_shard_dim else 0
256+
if isinstance(loaded_weight, paddle.Tensor):
250257
size = loaded_weight.shape[dim]
251258
else:
252259
size = loaded_weight.get_shape()[dim]
253260
block_size = size // self.tp_size
254261
shard_offset = self.tp_rank * block_size
255262
shard_size = (self.tp_rank + 1) * block_size
256-
loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
263+
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
264+
loaded_weight = get_tensor(loaded_weight)
257265
expert_param = param[expert_id - self.expert_id_offset]
258266
if hasattr(param, "tensor_track"):
259267
# for dyn quant
260268
param.tensor_track.mark(start=0, batch_id=expert_id - self.expert_id_offset)
261-
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
262-
if current_platform.is_xpu or current_platform.is_gcu():
263-
if expert_param.shape != loaded_weight.shape:
264-
loaded_weight = loaded_weight.transpose([1, 0])
269+
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU and opensource weight
270+
if expert_param.shape != loaded_weight.shape:
271+
loaded_weight = loaded_weight.transpose([1, 0])
265272
assert expert_param.shape == loaded_weight.shape, (
266273
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
267274
)

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from tqdm import tqdm
3030

3131
from fastdeploy.config import FDConfig
32-
from fastdeploy.model_executor.layers.utils import get_tensor
3332
from fastdeploy.model_executor.models.tp_utils import (
3433
check_tensor_parallel_prerequisites,
3534
)
@@ -186,8 +185,7 @@ def fast_weights_iterator(safe_tensor_list: list[str]):
186185
with fast_safe_open(st_file, framework="np") as f:
187186
for name in f.keys():
188187
param_slice = f.get_slice(name)
189-
paddle_tensor = get_tensor(param_slice)
190-
yield name, paddle_tensor
188+
yield name, param_slice
191189

192190

193191
def fastsafetensors_weights_iterator(

fastdeploy/model_executor/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
160160
output_dim = getattr(param, "output_dim", None)
161161
model_format = getattr(param, "model_format", "")
162162
if model_format == "torch":
163+
loaded_weight = get_tensor(loaded_weight)
163164
loaded_weight = loaded_weight.transpose([1, 0])
164165
# Tensor parallelism splits the weight along the output_dim
165166
if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:

0 commit comments

Comments
 (0)