Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions fastdeploy/worker/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,10 @@ def init_share_inputs(self):

# Initialize rotary position embedding
if not self.enable_mm:
rotary_percent = getattr(self.model_config, "rotary_percent", 1.0)
self.rotary_dim = int(self.model_config.head_dim * rotary_percent)
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 此处三目表达式冗余,可简化。

self.rotary_dim = int(head_dim * rotary_percent) 已在上一行计算完毕。当 rotary_percent == 1.0 时,int(head_dim * 1.0) == head_dim,两个分支等价。因此三目条件可直接简化为:

rotary_dim=self.rotary_dim,

下方三处(第 724、826、1052 行)存在同样的冗余,建议一并修改。

position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
base=self.model_config.rope_theta,
model_config=self.model_config,
Expand Down Expand Up @@ -716,8 +718,10 @@ def reset_share_inputs(self):
fill_paddle_tensor(self, "attn_mask_offsets_full", -1)
else:
# Reset non-multimodal rope_emb
rotary_percent = getattr(self.model_config, "rotary_percent", 1.0)
self.rotary_dim = int(self.model_config.head_dim * rotary_percent)
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)),
base=self.model_config.rope_theta,
model_config=self.model_config,
Expand Down Expand Up @@ -816,8 +820,10 @@ def init_share_inputs(self):

tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))

rotary_percent = getattr(self.model_config, "rotary_percent", 1.0)
self.rotary_dim = int(self.model_config.head_dim * rotary_percent)
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
Expand Down Expand Up @@ -1040,8 +1046,10 @@ def reset_model_inputs(self) -> None:

# Reset rope embedding by recreating with default position_ids
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
rotary_percent = getattr(self.model_config, "rotary_percent", 1.0)
self.rotary_dim = int(self.model_config.head_dim * rotary_percent)
self.rope_emb = get_rope(
rotary_dim=self.model_config.head_dim,
rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
Expand Down
Loading