diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index aa34784b2a5..2ab3bb6823f 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -236,8 +236,10 @@ def init_share_inputs(self): # Initialize rotary position embedding if not self.enable_mm: + rotary_percent = getattr(self.model_config, "rotary_percent", 1.0) + self.rotary_dim = int(self.model_config.head_dim * rotary_percent) self.rope_emb = get_rope( - rotary_dim=self.model_config.head_dim, + rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim, position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)), base=self.model_config.rope_theta, model_config=self.model_config, @@ -716,8 +718,10 @@ def reset_share_inputs(self): fill_paddle_tensor(self, "attn_mask_offsets_full", -1) else: # Reset non-multimodal rope_emb + rotary_percent = getattr(self.model_config, "rotary_percent", 1.0) + self.rotary_dim = int(self.model_config.head_dim * rotary_percent) self.rope_emb = get_rope( - rotary_dim=self.model_config.head_dim, + rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim, position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)), base=self.model_config.rope_theta, model_config=self.model_config, @@ -816,8 +820,10 @@ def init_share_inputs(self): tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) + rotary_percent = getattr(self.model_config, "rotary_percent", 1.0) + self.rotary_dim = int(self.model_config.head_dim * rotary_percent) self.rope_emb = get_rope( - rotary_dim=self.model_config.head_dim, + rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim, position_ids=tmp_position_ids, base=self.model_config.rope_theta, model_config=self.model_config, @@ -1040,8 +1046,10 @@ def reset_model_inputs(self) -> None: # Reset rope embedding by recreating with default position_ids tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) + rotary_percent = getattr(self.model_config, "rotary_percent", 1.0) + self.rotary_dim = int(self.model_config.head_dim * rotary_percent) self.rope_emb = get_rope( - rotary_dim=self.model_config.head_dim, + rotary_dim=self.rotary_dim if rotary_percent < 1.0 else self.model_config.head_dim, position_ids=tmp_position_ids, base=self.model_config.rope_theta, model_config=self.model_config,