Skip to content

Commit f93cbdb

Browse files
authored
[npu] npu qwen3.5 megatron padding_free fix (#50)
1 parent bb0dc5b commit f93cbdb

2 files changed

Lines changed: 15 additions & 7 deletions

File tree

src/mcore_bridge/model/gpts/qwen3_next.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
1818
from megatron.core.utils import deprecate_inference_params, is_fa_min_version
1919
from packaging import version
20+
from transformers.utils import is_torch_npu_available
2021
from typing import Optional, Tuple, Union
2122

2223
from mcore_bridge.bridge import GPTBridge
@@ -58,6 +59,17 @@
5859
logger = get_logger()
5960

6061

62+
def resolve_gdn_attention_mask(kwargs) -> Optional[torch.Tensor]:
    """Resolve the boolean attention mask for the GDN (gated-delta-net) path.

    On NPU devices a pre-computed 2-D mask (``attention_mask_2d``) is
    preferred when the caller supplies one; otherwise the regular
    ``attention_mask`` is reduced to a per-token boolean vector.

    Args:
        kwargs: forward-pass keyword arguments; may contain
            ``attention_mask_2d`` and/or ``attention_mask`` tensors.

    Returns:
        A boolean tensor marking valid token positions, or ``None`` when no
        mask is available in *kwargs*.
    """
    # NPU fast path: use the padding-free 2-D mask when it was provided.
    if is_torch_npu_available():
        npu_mask = kwargs.get('attention_mask_2d')
        if npu_mask is not None:
            return npu_mask.to(torch.bool)

    generic_mask = kwargs.get('attention_mask')
    if generic_mask is None:
        return None

    # Collapse dims 1 and 2 of the (presumably 4-D, inverted-boolean)
    # attention mask: a position counts as valid when at least one entry
    # is unmasked there. TODO(review): confirm expected mask rank/layout
    # against the Megatron caller.
    return (~generic_mask).sum(dim=(1, 2)) > 0
71+
72+
6173
class Qwen3NextRMSNorm(torch.nn.Module):
6274
"""
6375
Zero-Centered RMSNorm for Qwen3-Next.
@@ -485,9 +497,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
485497
hidden_states = new_hidden_states
486498
else:
487499
hidden_states = hidden_states.transpose(0, 1)
488-
attention_mask = kwargs.get('attention_mask')
489-
if attention_mask is not None:
490-
attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0
500+
attention_mask = resolve_gdn_attention_mask(kwargs)
491501
res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
492502
if thd_format:
493503
res = res[attention_mask][:, None]

src/mcore_bridge/model/mm_gpts/qwen3_5.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mcore_bridge.utils import get_env_args
1111

1212
from ..constant import ModelType
13-
from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader
13+
from ..gpts.qwen3_next import Qwen3NextBridge, Qwen3NextLoader, resolve_gdn_attention_mask
1414
from ..register import ModelMeta, register_model
1515
from .utils import HuggingFaceVit
1616

@@ -52,9 +52,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
5252
hidden_states = new_hidden_states
5353
else:
5454
hidden_states = hidden_states.transpose(0, 1)
55-
attention_mask = kwargs.get('attention_mask')
56-
if attention_mask is not None:
57-
attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0
55+
attention_mask = resolve_gdn_attention_mask(kwargs)
5856
res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask)
5957
if thd_format:
6058
res = res[attention_mask][:, None]

0 commit comments

Comments
 (0)