Skip to content

Commit bc7152b

Browse files
committed
support xpu w4a8c8+skip quant
1 parent bc31f10 commit bc7152b

File tree

8 files changed

+183
-24
lines changed

8 files changed

+183
-24
lines changed

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,11 @@ def allocated_slots(self, request: Request):
235235
return len(request.block_tables) * self.config.cache_config.block_size
236236

237237
def get_new_block_nums(self, request: Request, num_new_tokens: int):
238+
# Account for preallocated blocks that haven't been added to block_tables yet
239+
preallocated_count = len(getattr(request, 'preallocated_blocks', []))
238240
block_num = (
239241
request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
240-
) // self.config.cache_config.block_size - len(request.block_tables)
242+
) // self.config.cache_config.block_size - len(request.block_tables) - preallocated_count
241243

242244
if self.config.speculative_config.method is not None:
243245
block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
@@ -800,8 +802,14 @@ def get_enough_request(request, scheduled_reqs):
800802
self.allocated_slots(request) - request.num_total_tokens
801803
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
802804
):
805+
# First, consume any preallocated blocks before allocating new ones
806+
preallocated = getattr(request, 'preallocated_blocks', [])
807+
if preallocated:
808+
request.block_tables.extend(preallocated)
809+
request.preallocated_blocks = []
810+
scheduled_reqs.append(self._prepare_decode_task(request))
803811
# Allocation for next decoding blocks
804-
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
812+
elif self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
805813
llm_logger.debug(
806814
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
807815
)
@@ -911,6 +919,12 @@ def _allocate_decode_and_extend():
911919
request.block_tables.extend(
912920
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
913921
)
922+
# Merge preallocated blocks (from PD disaggregation) into block_tables
923+
# so the attention kernel can access all reserved blocks.
924+
preallocated = getattr(request, 'preallocated_blocks', [])
925+
if preallocated:
926+
request.block_tables.extend(preallocated)
927+
request.preallocated_blocks = []
914928
# Prepare prefill task
915929
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
916930
else: # Not enough blocks to allocate, trigger preemption
@@ -920,6 +934,11 @@ def _allocate_decode_and_extend():
920934
request.block_tables.extend(
921935
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
922936
)
937+
# Merge preallocated blocks (from PD disaggregation) into block_tables
938+
preallocated = getattr(request, 'preallocated_blocks', [])
939+
if preallocated:
940+
request.block_tables.extend(preallocated)
941+
request.preallocated_blocks = []
923942
# Prepare prefill task
924943
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
925944
token_budget -= num_new_tokens
@@ -1403,9 +1422,10 @@ def preallocate_resource_in_d(self, request: Request):
14031422
"""
14041423
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
14051424
request.need_prefill_tokens = len(request.prompt_token_ids)
1406-
need_prealloc_prefill_blocks = (
1425+
actual_prefill_blocks = (
14071426
request.need_prefill_tokens + self.config.cache_config.block_size - 1
1408-
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num
1427+
) // self.config.cache_config.block_size
1428+
need_prealloc_prefill_blocks = actual_prefill_blocks + self.config.cache_config.enc_dec_block_num
14091429

14101430
with self.lock:
14111431
if len(self.waiting) > 0:
@@ -1416,11 +1436,16 @@ def preallocate_resource_in_d(self, request: Request):
14161436
if not self.cache_manager.can_allocate_gpu_blocks(total_need_blocks):
14171437
return False
14181438

1419-
request.block_tables = self.cache_manager.allocate_gpu_blocks(
1439+
all_blocks = self.cache_manager.allocate_gpu_blocks(
14201440
need_prealloc_prefill_blocks, request.request_id
14211441
)
1442+
# Only put the blocks that will actually contain prefilled KV data into block_tables.
1443+
# The extra enc_dec_block_num blocks are pre-reserved for future decode tokens and
1444+
# stored separately to avoid the attention kernel reading uninitialized KV cache data.
1445+
request.block_tables = all_blocks[:actual_prefill_blocks]
1446+
request.preallocated_blocks = all_blocks[actual_prefill_blocks:]
14221447
request.num_computed_tokens = request.need_prefill_tokens
1423-
request.disaggregate_info["block_tables"] = request.block_tables
1448+
request.disaggregate_info["block_tables"] = all_blocks
14241449
allocated_position = self.get_available_position()
14251450
request.idx = allocated_position
14261451
self.tasks_list[request.idx] = request
@@ -1470,6 +1495,12 @@ def add_prefilled_request(self, request_output: RequestOutput):
14701495
self.running.append(request)
14711496

14721497
def _free_blocks(self, request: Request):
1498+
# Also free any preallocated blocks that haven't been consumed yet
1499+
preallocated = getattr(request, 'preallocated_blocks', [])
1500+
if preallocated:
1501+
self.cache_manager.recycle_gpu_blocks(preallocated, request.request_id)
1502+
request.preallocated_blocks = []
1503+
14731504
if self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode":
14741505
self.cache_manager.release_block_ids(request)
14751506
self.cache_manager.recycle_gpu_blocks(

fastdeploy/model_executor/layers/backends/xpu/attention.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,9 @@ def forward_mixed(
181181
cache_v_scale = getattr(layer, "cache_v_scale", None)
182182
cache_k_out_scale = getattr(layer, "cache_k_out_scale", None)
183183
cache_v_out_scale = getattr(layer, "cache_v_out_scale", None)
184-
cache_k_zp = getattr(self, "cache_k_zp", None)
185-
cache_v_zp = getattr(self, "cache_v_zp", None)
184+
# todo: there are accuracy bugs in block_attn with zero_point
185+
cache_k_zp = getattr(layer, "cache_k_zp", None)
186+
cache_v_zp = getattr(layer, "cache_v_zp", None)
186187

187188
if layer.use_qk_norm:
188189
q_norm_weight = layer.q_norm_weight
@@ -220,8 +221,8 @@ def forward_mixed(
220221
cache_v_scale,
221222
cache_k_out_scale,
222223
cache_v_out_scale,
223-
cache_k_zp,
224-
cache_v_zp,
224+
cache_k_zp.astype("bfloat16") if cache_k_zp is not None else None, # for C8 with zero_point
225+
cache_v_zp.astype("bfloat16") if cache_v_zp is not None else None, # for C8 with zero_point
225226
None, # shift
226227
None, # smooth
227228
q_norm_weight,

fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,12 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
268268
default_initializer=paddle.nn.initializer.Constant(0),
269269
),
270270
)
271+
set_weight_attrs(
272+
getattr(layer, self.added_scale_attrs[0]),
273+
{
274+
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
275+
},
276+
)
271277
setattr(
272278
layer,
273279
self.added_scale_attrs[1],
@@ -277,6 +283,26 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
277283
default_initializer=paddle.nn.initializer.Constant(0),
278284
),
279285
)
286+
set_weight_attrs(
287+
getattr(layer, self.added_scale_attrs[1]),
288+
{
289+
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
290+
},
291+
)
292+
293+
set_weight_attrs(
294+
layer.up_gate_proj_weight,
295+
{
296+
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
297+
},
298+
)
299+
set_weight_attrs(
300+
layer.down_proj_weight,
301+
{
302+
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
303+
},
304+
)
305+
280306

281307
if self.moe_quant_type in ["w8a8", "w4a8"]:
282308
for in_scale_name in self.added_in_scale_attrs:
@@ -289,6 +315,21 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
289315
default_initializer=paddle.nn.initializer.Constant(0),
290316
),
291317
)
318+
set_weight_attrs(
319+
layer.down_proj_in_scale,
320+
{
321+
"SHARD_ID_TO_SHARDED_DIM": {"gate": None, "up": None, "down": None},
322+
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
323+
},
324+
)
325+
326+
set_weight_attrs(
327+
layer.up_gate_proj_in_scale,
328+
{
329+
"SHARD_ID_TO_SHARDED_DIM": {"gate": None, "up": None, "down": None},
330+
"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
331+
},
332+
)
292333

293334
def process_loaded_weights(self, layer: nn.Layer, state_dict):
294335
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)

fastdeploy/model_executor/layers/backends/xpu/quantization/kv_cache.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Optional
1818

19+
import os
1920
import paddle
2021
from paddle import nn
2122

@@ -42,6 +43,7 @@ def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_poi
4243
super().__init__()
4344
self.kv_cache_quant_type = kv_cache_quant_type
4445
self.is_channel_wise = is_channel_wise
46+
self.has_zero_point = has_zero_point
4547

4648
try:
4749
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
@@ -139,6 +141,48 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
139141
scale_shape = [layer.fd_config.model_config.num_key_value_heads]
140142
if self.cache_quant_config.is_channel_wise:
141143
scale_shape = [layer.kv_num_heads * layer.head_dim]
144+
# Custom weight_loader for C8+TP: the safetensors scale/zp shape is
145+
# [1, num_kv_heads, 1, head_dim]. We must split along the kv_heads
146+
# dimension (dim=1), not the last dimension. The default_weight_loader
147+
# treats output_dim as boolean and always splits along dim=-1, which
148+
# is incorrect for 4D tensors where we need to split along dim=1.
149+
fd_config = layer.fd_config
150+
total_kv_heads = fd_config.model_config.num_key_value_heads
151+
tp_size = fd_config.parallel_config.tensor_parallel_size
152+
tp_rank = fd_config.parallel_config.tensor_parallel_rank
153+
def _kv_scale_weight_loader(param, loaded_weight, shard_id=None,
154+
_total_kv_heads=total_kv_heads,
155+
_tp_size=tp_size, _tp_rank=tp_rank):
156+
loaded_weight = get_tensor(loaded_weight).cast("float32")
157+
# TP split along kv_heads dimension
158+
if _tp_size > 1 and not fd_config.load_config.is_pre_sharded:
159+
head_dim = loaded_weight.numel() // _total_kv_heads
160+
loaded_weight = loaded_weight.reshape([_total_kv_heads, head_dim])
161+
kv_heads_per_rank = _total_kv_heads // _tp_size
162+
start = _tp_rank * kv_heads_per_rank
163+
end = start + kv_heads_per_rank
164+
loaded_weight = loaded_weight[start:end, :]
165+
loaded_weight = (127/loaded_weight).reshape(param.shape).cast(param.dtype)
166+
param.copy_(loaded_weight, False)
167+
def _kv_zp_weight_loader(param, loaded_weight, shard_id=None,
168+
_total_kv_heads=total_kv_heads,
169+
_tp_size=tp_size, _tp_rank=tp_rank):
170+
loaded_weight = get_tensor(loaded_weight).cast(param.dtype)
171+
# TP split along kv_heads dimension
172+
if _tp_size > 1 and not fd_config.load_config.is_pre_sharded:
173+
head_dim = loaded_weight.numel() // _total_kv_heads
174+
loaded_weight = loaded_weight.reshape([_total_kv_heads, head_dim])
175+
kv_heads_per_rank = _total_kv_heads // _tp_size
176+
start = _tp_rank * kv_heads_per_rank
177+
end = start + kv_heads_per_rank
178+
loaded_weight = loaded_weight[start:end, :]
179+
loaded_weight = loaded_weight.reshape(param.shape)
180+
param.copy_(loaded_weight, False)
181+
scale_weight_attrs = {**extra_weight_attrs, "weight_loader": _kv_scale_weight_loader}
182+
zp_weight_attrs = {**extra_weight_attrs, "weight_loader": _kv_zp_weight_loader}
183+
else:
184+
scale_weight_attrs = extra_weight_attrs
185+
zp_weight_attrs = extra_weight_attrs
142186

143187
layer.cache_k_scale = layer.create_parameter(
144188
shape=scale_shape,
@@ -154,13 +198,13 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
154198
set_weight_attrs(
155199
layer.cache_k_scale,
156200
{
157-
**extra_weight_attrs,
201+
**scale_weight_attrs,
158202
},
159203
)
160204
set_weight_attrs(
161205
layer.cache_v_scale,
162206
{
163-
**extra_weight_attrs,
207+
**scale_weight_attrs,
164208
},
165209
)
166210

@@ -189,13 +233,13 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
189233
set_weight_attrs(
190234
layer.cache_k_zp,
191235
{
192-
**extra_weight_attrs,
236+
**zp_weight_attrs,
193237
},
194238
)
195239
set_weight_attrs(
196240
layer.cache_v_zp,
197241
{
198-
**extra_weight_attrs,
242+
**zp_weight_attrs,
199243
},
200244
)
201245

@@ -218,11 +262,18 @@ def process_weights_after_loading(self, layer: nn.Layer):
218262
"""
219263
use for loader v1
220264
"""
221-
# cache_k_out_scale is the reciprocal of cache_k_scale
222-
if layer.cache_k_scale._is_initialized():
223-
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale) # cache_k_out_scale
224-
if layer.cache_v_scale._is_initialized():
225-
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
265+
use_c8 = os.getenv("FD_XPU_YIYAN_MODEL", "0") == "1"
266+
if use_c8:
267+
if layer.cache_k_scale._is_initialized():
268+
layer.cache_k_out_scale.set_value(self.cache_quant_config.max_bound / layer.cache_k_scale.cast("float32").reshape_([-1])) # cache_k_out_scale
269+
if layer.cache_v_scale._is_initialized():
270+
layer.cache_v_out_scale.set_value(self.cache_quant_config.max_bound / layer.cache_v_scale.cast("float32").reshape_([-1]))
271+
else:
272+
# cache_k_out_scale is the reciprocal of cache_k_scale
273+
if layer.cache_k_scale._is_initialized():
274+
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale) # cache_k_out_scale
275+
if layer.cache_v_scale._is_initialized():
276+
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
226277

227278
def apply(self, layer):
228279
"""

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,19 @@ def __init__(
260260
tp_size={self.tp_size}."
261261
)
262262

263+
def _load_in_scale_weight(self, param, expert_id, loaded_weight):
264+
# only support ernie now
265+
expert_param = param[expert_id - self.expert_id_offset]
266+
loaded_weight = get_tensor(loaded_weight)
267+
if len(expert_param.shape) != len(loaded_weight.shape):
268+
loaded_weight = loaded_weight.reshape(expert_param.shape)
269+
assert expert_param.shape == loaded_weight.shape, (
270+
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
271+
)
272+
if expert_param.dtype != loaded_weight.dtype:
273+
loaded_weight = loaded_weight.cast(expert_param.dtype)
274+
param[expert_id - self.expert_id_offset].copy_(loaded_weight, False)
275+
263276
def weight_loader(
264277
self,
265278
param,
@@ -292,9 +305,15 @@ def weight_loader(
292305
if weight_need_transpose:
293306
loaded_weight = loaded_weight.transpose([1, 0])
294307

308+
if SHARD_ID_TO_SHARDED_DIM["gate"] is None and SHARD_ID_TO_SHARDED_DIM["up"] is None:
309+
# in scale
310+
self._load_in_scale_weight(param, expert_id, loaded_weight)
311+
return
312+
295313
if shard_id is None:
296314
# 1.gate up fused in disk
297-
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
315+
shard_param = param[expert_id - self.expert_id_offset]
316+
output_size = shard_param.shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
298317
shard_offsets = [
299318
# (shard_id, shard_offset, shard_size)
300319
("gate", 0, output_size // 2 * self.tp_size),

fastdeploy/model_executor/layers/quantization/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
173173
from .weight_only import WeightOnlyConfig, WINT4Config, WINT8Config
174174
from .wfp8afp8 import WFP8AFP8Config
175175
from .wint2 import WINT2Config
176+
from fastdeploy.platforms import current_platform
176177

177178
if envs.FD_MOE_MXFP4_BACKEND is not None:
178179
from .mxfp4 import MXFP4Config
@@ -196,4 +197,9 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
196197
if quantization == "modelopt_fp4":
197198
method_to_config["modelopt_fp4"] = ModelOptNvFp4Config
198199

200+
# For XPU platform, use XPUKvCacheQuantConfig instead of KvCacheQuantConfig
201+
if quantization == "kvcache" and current_platform.is_xpu():
202+
from ..backends.xpu.quantization.kv_cache import XPUKvCacheQuantConfig
203+
method_to_config["kvcache"] = XPUKvCacheQuantConfig
204+
199205
return method_to_config[quantization]

fastdeploy/model_executor/models/ernie4_5_moe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ def load_weights(self, weights_iterator) -> None:
597597
("attn.cache_k_scale", "cachek_matmul.in_scale", None, None),
598598
("attn.cache_v_scale", "cachev_matmul.in_scale", None, None),
599599
("up_gate_proj_in_scale", "up_gate_proj.in_scale", None, None),
600+
("down_proj_in_scale", "down_proj.in_scale", None, None),
600601
]
601602

602603
expert_params_mapping = []

0 commit comments

Comments
 (0)