Skip to content

Commit de1c8bd

Browse files
authored
[Feature] Support partial rotary embedding (rotary_percent) for fleet-gqa-latent (#7955)
* support fleet-gqa-latent * update
1 parent 5b7ebb7 commit de1c8bd

1 file changed

Lines changed: 12 additions & 4 deletions

File tree

fastdeploy/worker/input_batch.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,10 @@ def init_share_inputs(self):
236236

237237
# Initialize rotary position embedding
238238
if not self.enable_mm:
239+
rotary_percent = getattr(self.model_config, "rotary_percent", 1.0)
240+
self.rotary_dim = int(self.model_config.head_dim * rotary_percent)
239241
self.rope_emb = get_rope(
240-
rotary_dim=self.model_config.head_dim,
242+
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
241243
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
242244
base=self.model_config.rope_theta,
243245
model_config=self.model_config,
@@ -716,8 +718,10 @@ def reset_share_inputs(self):
716718
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
717719
else:
718720
# Reset non-multimodal rope_emb
721+
rotary_percent = getattr(self.model_config, "rotary_percent", 1.0)
722+
self.rotary_dim = int(self.model_config.head_dim * rotary_percent)
719723
self.rope_emb = get_rope(
720-
rotary_dim=self.model_config.head_dim,
724+
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
721725
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
722726
base=self.model_config.rope_theta,
723727
model_config=self.model_config,
@@ -816,8 +820,10 @@ def init_share_inputs(self):
816820

817821
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
818822

823+
rotary_percent = getattr(self.model_config, "rotary_percent", 1.0)
824+
self.rotary_dim = int(self.model_config.head_dim * rotary_percent)
819825
self.rope_emb = get_rope(
820-
rotary_dim=self.model_config.head_dim,
826+
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
821827
position_ids=tmp_position_ids,
822828
base=self.model_config.rope_theta,
823829
model_config=self.model_config,
@@ -1040,8 +1046,10 @@ def reset_model_inputs(self) -> None:
10401046

10411047
# Reset rope embedding by recreating with default position_ids
10421048
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
1049+
rotary_percent = getattr(self.model_config, "rotary_percent", 1.0)
1050+
self.rotary_dim = int(self.model_config.head_dim * rotary_percent)
10431051
self.rope_emb = get_rope(
1044-
rotary_dim=self.model_config.head_dim,
1052+
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
10451053
position_ids=tmp_position_ids,
10461054
base=self.model_config.rope_theta,
10471055
model_config=self.model_config,

0 commit comments

Comments
 (0)