Skip to content

Commit f4eda5a

Browse files
authored
[Models] Update SWA RoPE theta for MLA/GQA attention (#8077)
* update mla_gqa_swa_rope_theta * update mla_gqa_swa_rope_theta * update mla_gqa_swa_rope_theta1
1 parent 6d9a8f4 commit f4eda5a

5 files changed

Lines changed: 57 additions & 8 deletions

File tree

fastdeploy/model_executor/forward_meta.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class ForwardMeta:
6969
ids_remove_padding: paddle.Tensor
7070
# Rotation position embedding
7171
rotary_embs: Optional[paddle.Tensor] = None
72+
swa_rotary_embs: Optional[paddle.Tensor] = None
7273

7374
# Whether to use CUDA graph in this step. Used to avoid running CUDA graph during a dummy run or the prefill stage.
7475
step_use_cudagraph: bool = False

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def __init__(
185185
self.sink_size: int = getattr(fd_config.model_config, "sink_size", 0)
186186
self.window_attn_skip_freq: list = getattr(fd_config.model_config, "window_attn_skip_freq", [0])
187187
self.head_wise_swa_ratio: float = getattr(fd_config.model_config, "head_wise_swa_ratio", 0.0)
188+
self.swa_rope_theta = getattr(fd_config.model_config, "swa_rope_theta", None)
188189

189190
self.head_wise_full_hidden = 0
190191
if self.head_wise_swa_ratio > 0.0:
@@ -320,8 +321,11 @@ def forward_mixed(
320321
forward_meta.rotary_embs = self._get_identity_rotary_embs(forward_meta.rotary_embs)
321322

322323
sliding_window = 0
324+
rotary_embs = forward_meta.rotary_embs
323325
if len(self.window_attn_skip_freq) > 1 and self.window_attn_skip_freq[layer.layer_id] == 1:
324326
sliding_window = self.sliding_window if self.sliding_window > 0 else layer.sliding_window
327+
if self.swa_rope_theta is not None:
328+
rotary_embs = forward_meta.swa_rotary_embs
325329

326330
norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
327331
q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
@@ -401,8 +405,8 @@ def forward_mixed(
401405
assert forward_meta.rotary_embs.shape[0] == 2
402406
do_rope(
403407
qkv,
404-
forward_meta.rotary_embs[0],
405-
forward_meta.rotary_embs[1],
408+
rotary_embs[0],
409+
rotary_embs[1],
406410
forward_meta.cu_seqlens_q,
407411
forward_meta.seq_lens_decoder,
408412
forward_meta.batch_id_per_token,
@@ -476,7 +480,7 @@ def forward_mixed(
476480
forward_meta.decoder_num_blocks_cpu,
477481
forward_meta.max_len_tensor_cpu,
478482
res,
479-
forward_meta.rotary_embs,
483+
rotary_embs,
480484
forward_meta.attn_mask,
481485
layer.qkv_bias,
482486
layer.qkv_scale,
@@ -532,7 +536,7 @@ def forward_mixed(
532536
forward_meta.decoder_tile_ids_per_batch,
533537
forward_meta.decoder_num_blocks_cpu,
534538
forward_meta.max_len_tensor_cpu,
535-
forward_meta.rotary_embs,
539+
rotary_embs,
536540
forward_meta.attn_mask,
537541
layer.qkv_bias,
538542
layer.qkv_scale,

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,21 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
299299
self.kv_lora_rank = fd_config.model_config.kv_lora_rank
300300

301301
# swa
302-
self.swa_layer_list = getattr(fd_config.model_config, "window_attn_skip_freq", None)
302+
self.window_attn_skip_freq = getattr(fd_config.model_config, "window_attn_skip_freq", None)
303303
self.sliding_window = getattr(fd_config.model_config, "sliding_window", 0)
304+
self.swa_rope_theta = getattr(fd_config.model_config, "swa_rope_theta", None)
304305

305306
self.attn_softmax_scale = self.qk_head_dim**-0.5
306307

307308
if fd_config.model_config.model_type == "glm_moe_dsa":
308309
self.rope_theta = fd_config.model_config.rope_parameters["rope_theta"]
310+
311+
if (
312+
self.window_attn_skip_freq is not None
313+
and self.window_attn_skip_freq[self.layer_id] == 1
314+
and self.swa_rope_theta is not None
315+
):
316+
self.rope_theta = self.swa_rope_theta
309317
else:
310318
self.rope_theta = fd_config.model_config.rope_theta
311319

@@ -525,9 +533,7 @@ def forward(
525533
need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0
526534
need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0
527535

528-
window_attn_skip_freq = getattr(self.fd_config.model_config, "window_attn_skip_freq", None)
529-
530-
if window_attn_skip_freq is not None and window_attn_skip_freq[self.layer_id] == 1:
536+
if self.window_attn_skip_freq is not None and self.window_attn_skip_freq[self.layer_id] == 1:
531537
attn_out = self.forward_swa_static(
532538
forward_meta=forward_meta,
533539
query_nope=query_nope,

fastdeploy/worker/gpu_model_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,6 +1466,7 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False):
14661466
self.forward_meta = ForwardMeta(
14671467
ids_remove_padding=self.share_inputs["ids_remove_padding"],
14681468
rotary_embs=self.share_inputs["rope_emb"],
1469+
swa_rotary_embs=self.share_inputs["swa_rope_emb"],
14691470
attn_backend=self.attn_backends[0],
14701471
attn_backends=self.attn_backends,
14711472
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],

fastdeploy/worker/input_batch.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def __init__(self, fd_config: FDConfig) -> None:
108108
else:
109109
self.max_chunk_tokens = self.fd_config.get_max_chunk_tokens(self.model_config.mm_max_tokens_per_item)
110110

111+
self.swa_rope_theta = getattr(self.fd_config.model_config, "swa_rope_theta", None)
112+
self.swa_rope_emb = None
113+
111114
def init_share_inputs(self):
112115
max_num_seqs = self.scheduler_config.max_num_seqs
113116

@@ -245,6 +248,14 @@ def init_share_inputs(self):
245248
model_config=self.model_config,
246249
partial_rotary_factor=self.model_config.partial_rotary_factor,
247250
)
251+
if self.swa_rope_theta is not None:
252+
self.swa_rope_emb = get_rope(
253+
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
254+
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
255+
base=self.swa_rope_theta,
256+
model_config=self.model_config,
257+
partial_rotary_factor=self.model_config.partial_rotary_factor,
258+
)
248259
if self.is_mm_model:
249260
self.image_features = None
250261
self.image_grid_thws = None
@@ -727,6 +738,14 @@ def reset_share_inputs(self):
727738
model_config=self.model_config,
728739
partial_rotary_factor=self.model_config.partial_rotary_factor,
729740
)
741+
if self.swa_rope_theta is not None:
742+
self.swa_rope_emb = get_rope(
743+
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
744+
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
745+
base=self.swa_rope_theta,
746+
model_config=self.model_config,
747+
partial_rotary_factor=self.model_config.partial_rotary_factor,
748+
)
730749
if self.is_mm_model:
731750
self.image_features = None
732751
self.image_grid_thws = None
@@ -764,6 +783,8 @@ def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) ->
764783
self.cache_config: CacheConfig = fd_config.cache_config
765784
self.speculative_config: SpeculativeConfig = fd_config.speculative_config
766785
self.enable_pd_reorder: bool = False
786+
self.swa_rope_theta = getattr(self.fd_config.model_config, "swa_rope_theta", None)
787+
self.swa_rope_emb = None
767788

768789
def init_share_inputs(self):
769790
# share with target model
@@ -829,6 +850,14 @@ def init_share_inputs(self):
829850
model_config=self.model_config,
830851
partial_rotary_factor=self.model_config.partial_rotary_factor,
831852
)
853+
if self.swa_rope_theta is not None:
854+
self.swa_rope_emb = get_rope(
855+
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
856+
position_ids=tmp_position_ids,
857+
base=self.swa_rope_theta,
858+
model_config=self.model_config,
859+
partial_rotary_factor=self.model_config.partial_rotary_factor,
860+
)
832861

833862
# self.caches = self.cache_kvs
834863
# Inherit generation hyperparameters from the main model for consistency
@@ -1059,6 +1088,14 @@ def reset_model_inputs(self) -> None:
10591088
model_config=self.model_config,
10601089
partial_rotary_factor=self.model_config.partial_rotary_factor,
10611090
)
1091+
if self.swa_rope_theta is not None:
1092+
self.swa_rope_emb = get_rope(
1093+
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
1094+
position_ids=tmp_position_ids,
1095+
base=self.swa_rope_theta,
1096+
model_config=self.model_config,
1097+
partial_rotary_factor=self.model_config.partial_rotary_factor,
1098+
)
10621099

10631100
# Reset generation hyperparameters from the main model
10641101
self.top_p = self.target_model_input_batch["top_p"]

0 commit comments

Comments
 (0)