|
17 | 17 | from megatron.core.transformer.transformer_block import TransformerBlockSubmodules |
18 | 18 | from megatron.core.utils import deprecate_inference_params, is_fa_min_version |
19 | 19 | from packaging import version |
| 20 | +from transformers.utils import is_torch_npu_available |
20 | 21 | from typing import Optional, Tuple, Union |
21 | 22 |
|
22 | 23 | from mcore_bridge.bridge import GPTBridge |
|
58 | 59 | logger = get_logger() |
59 | 60 |
|
60 | 61 |
|
| 62 | +def resolve_gdn_attention_mask(kwargs) -> Optional[torch.Tensor]: |
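| | +    """Resolve the attention mask for the GDN layers from HF-style forward kwargs. |
| | + |
| | +    On NPU, prefer the ``attention_mask_2d`` kwarg, cast to bool. Otherwise fall |
| | +    back to ``attention_mask``, collapsed to a per-token bool mask; returns |
| | +    None when neither mask is provided. |
| | +    """ |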
| 63 | +    if is_torch_npu_available(): |
| 64 | +        attention_mask = kwargs.get('attention_mask_2d') |
| 65 | +        if attention_mask is not None: |
| 66 | +            return attention_mask.to(torch.bool) |
| 67 | +    attention_mask = kwargs.get('attention_mask') |
| 68 | +    if attention_mask is None: |
| 69 | +        return None |
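| | +    # Assumes a (batch, 1, seq, seq) mask where True marks masked-out pairs: |
| | +    # keep key positions that at least one query can attend to. |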
| 70 | +    return (~attention_mask).sum(dim=(1, 2)) > 0 |
| 71 | + |
| 72 | + |
61 | 73 | class Qwen3NextRMSNorm(torch.nn.Module): |
62 | 74 | """ |
63 | 75 | Zero-Centered RMSNorm for Qwen3-Next. |
@@ -485,9 +497,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): |
485 | 497 |             hidden_states = new_hidden_states |
486 | 498 |         else: |
487 | 499 |             hidden_states = hidden_states.transpose(0, 1) |
488 | | -        attention_mask = kwargs.get('attention_mask') |
489 | | -        if attention_mask is not None: |
490 | | -            attention_mask = (~attention_mask).sum(dim=(1, 2)) > 0 |
| 500 | +        attention_mask = resolve_gdn_attention_mask(kwargs) |
491 | 501 |         res = super().forward(hidden_states=hidden_states, attention_mask=attention_mask) |
492 | 502 |         if thd_format: |
493 | 503 |             res = res[attention_mask][:, None] |
|