|
1 | 1 | from typing import List, Optional, Tuple, Union |
2 | 2 |
|
3 | 3 | import torch |
| 4 | +import torch.distributed as dist |
4 | 5 | from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( |
5 | 6 | Qwen3OmniMoeThinkerCausalLMOutputWithPast, |
6 | 7 | Qwen3OmniMoeThinkerForConditionalGeneration, |
|
11 | 12 | from lmms_engine.parallel.sequence_parallel.ulysses import ( |
12 | 13 | calculate_seq_len_per_rank, |
13 | 14 | gather_outputs_and_unpad, |
| 15 | + get_ulysses_sequence_parallel_group, |
14 | 16 | get_ulysses_sequence_parallel_world_size, |
15 | 17 | pad_to_max_across_ranks, |
16 | 18 | slice_input_tensor, |
@@ -266,7 +268,14 @@ def lce_forward( |
266 | 268 | # Pad to max size across ranks, then gather and unpad |
267 | 269 | loss, total_padding = pad_to_max_across_ranks(loss, dim=0) |
268 | 270 | loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding) |
269 | | - loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8) |
| 271 | + # Calculate the actual number of valid tokens (non-ignored labels) across all ranks |
| 272 | + # shift_labels shape is (num_tokens,) after flatten, -100 means ignore |
| 273 | + num_valid_tokens = (shift_labels != -100).sum().float() |
| 274 | + # Gather num_valid_tokens across all SP ranks to get the total count |
| 275 | + sp_group = get_ulysses_sequence_parallel_group() |
| 276 | + if sp_group is not None: |
| 277 | + dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group) |
| 278 | + loss = torch.sum(loss) / (num_valid_tokens + 1e-8) |
270 | 279 |
|
271 | 280 | if reduction == "sum": |
272 | 281 | loss /= kwargs["num_items_in_batch"] |
|
0 commit comments