Commit 1ae5268

[release/1.1] Fix rope precision (PaddlePaddle#4121)
Co-authored-by: Xuxinyi <104957571+xuxinyi389@users.noreply.github.com>
1 parent 8821b39 commit 1ae5268

1 file changed

Lines changed: 7 additions & 2 deletions

paddleformers/transformers/qwen3_vl/modeling_fleet.py

@@ -748,8 +748,13 @@ def forward(
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         rotary_pos_emb = paddle.cat((rotary_pos_emb, rotary_pos_emb), axis=-1)
-        rotary_pos_cos = rotary_pos_emb.cos()
-        rotary_pos_sin = rotary_pos_emb.sin()
+        # Cast freqs to float32 and compute cos/sin inside auto_cast(False) to match the
+        # precision of _apply_rotary_pos_emb_bshd_fp32, which computes cos/sin on the same
+        # bf16 freqs but under auto_cast(False) using a float32 kernel.
+        with paddle.amp.auto_cast(False):
+            _freqs_f32 = rotary_pos_emb.astype("float32")
+            rotary_pos_cos = paddle.cos(_freqs_f32)
+            rotary_pos_sin = paddle.sin(_freqs_f32)
         rotary_pos_emb = rotary_pos_emb[:, None, None, :]
         rotary_pos_emb = rotary_pos_emb.transpose([1, 0, 2, 3])
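
To illustrate why the change above computes cos/sin in float32, here is a minimal standard-library sketch. It is not Paddle code: `to_bf16` and `to_f32` are hypothetical helpers that emulate rounding to bfloat16 and float32, and the sample angles are arbitrary. It assumes, as the code comment states, that the angles are already stored in bf16, and it models a bf16 kernel as one whose result is rounded to bf16.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Emulate bfloat16: keep the top 16 bits of the float32 encoding,
    using round-to-nearest-even on the discarded low 16 bits.
    (Hypothetical helper for illustration; finite inputs only.)"""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Arbitrary sample angles, stored in bf16 like the freqs in the diff.
angles = [to_bf16(0.37 * i) for i in range(1, 200)]

# Path A: result rounded to bf16 (models cos() evaluated on a bf16 tensor).
cos_bf16 = [to_bf16(math.cos(a)) for a in angles]

# Path B: same bf16 angles, but the result is kept in float32
# (models astype("float32") followed by a float32 cos kernel).
cos_f32 = [to_f32(math.cos(a)) for a in angles]

# Compare both paths against a double-precision reference on the same angles.
max_err_bf16 = max(abs(c - math.cos(a)) for c, a in zip(cos_bf16, angles))
max_err_f32 = max(abs(c - math.cos(a)) for c, a in zip(cos_f32, angles))

print(f"max |error| with bf16 result:    {max_err_bf16:.3e}")
print(f"max |error| with float32 result: {max_err_f32:.3e}")
```

The inputs are identical in both paths; only the precision of the cos result differs. That is the gap the commit closes so that this forward pass matches `_apply_rotary_pos_emb_bshd_fp32`.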
