Skip to content

Commit 45c1944

Browse files
authored
[fix]: use valid labels for SP loss normalization (#130)
* fix(liger): use valid labels for SP loss normalization Replace attention_mask sum with valid tokens (non-ignored labels) count for proper loss normalization in sequence parallel mode. * Lint
1 parent 0beb0a0 commit 45c1944

7 files changed

Lines changed: 70 additions & 7 deletions

File tree

src/lmms_engine/models/qwen2/qwen2_liger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
except:
1313
print("Liger Kernel is not installed, pip install liger-kernel to use this patch")
1414
import torch
15+
import torch.distributed as dist
1516

1617
from lmms_engine.parallel.sequence_parallel.ulysses import (
1718
calculate_seq_len_per_rank,
1819
gather_outputs_and_unpad,
20+
get_ulysses_sequence_parallel_group,
1921
get_ulysses_sequence_parallel_world_size,
2022
pad_to_max_across_ranks,
2123
slice_input_tensor,
@@ -143,7 +145,14 @@ def qwen2_lce_forward(
143145
# Pad to max size across ranks, then gather and unpad
144146
loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
145147
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
146-
loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
148+
# Calculate the actual number of valid tokens (non-ignored labels) across all ranks
149+
# shift_labels shape is (num_tokens,) after flatten, -100 means ignore
150+
num_valid_tokens = (shift_labels != -100).sum().float()
151+
# Gather num_valid_tokens across all SP ranks to get the total count
152+
sp_group = get_ulysses_sequence_parallel_group()
153+
if sp_group is not None:
154+
dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
155+
loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
147156

148157
if reduction == "sum":
149158
loss /= loss_kwargs["num_items_in_batch"]

src/lmms_engine/models/qwen2_5_omni/qwen2_5_omni_liger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Optional, Tuple, Union
22

33
import torch
4+
import torch.distributed as dist
45
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
56
Qwen2_5OmniThinkerCausalLMOutputWithPast,
67
Qwen2_5OmniThinkerForConditionalGeneration,
@@ -10,6 +11,7 @@
1011
from lmms_engine.parallel.sequence_parallel.ulysses import (
1112
calculate_seq_len_per_rank,
1213
gather_outputs_and_unpad,
14+
get_ulysses_sequence_parallel_group,
1315
get_ulysses_sequence_parallel_world_size,
1416
pad_to_max_across_ranks,
1517
slice_input_tensor,
@@ -253,7 +255,14 @@ def lce_forward(
253255
# Pad to max size across ranks, then gather and unpad
254256
loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
255257
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
256-
loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
258+
# Calculate the actual number of valid tokens (non-ignored labels) across all ranks
259+
# shift_labels shape is (num_tokens,) after flatten, -100 means ignore
260+
num_valid_tokens = (shift_labels != -100).sum().float()
261+
# Gather num_valid_tokens across all SP ranks to get the total count
262+
sp_group = get_ulysses_sequence_parallel_group()
263+
if sp_group is not None:
264+
dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
265+
loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
257266

258267
if reduction == "sum":
259268
loss /= kwargs["num_items_in_batch"]

src/lmms_engine/models/qwen2_5_vl/qwen2_5_vl_liger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Optional, Tuple, Union
22

33
import torch
4+
import torch.distributed as dist
45
from transformers import Qwen2_5_VLForConditionalGeneration
56
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
67
Qwen2_5_VLCausalLMOutputWithPast,
@@ -9,6 +10,7 @@
910
from lmms_engine.parallel.sequence_parallel.ulysses import (
1011
calculate_seq_len_per_rank,
1112
gather_outputs_and_unpad,
13+
get_ulysses_sequence_parallel_group,
1214
get_ulysses_sequence_parallel_world_size,
1315
pad_to_max_across_ranks,
1416
slice_input_tensor,
@@ -125,7 +127,14 @@ def lce_forward(
125127
# Pad to max size across ranks, then gather and unpad
126128
loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
127129
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
128-
loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
130+
# Calculate the actual number of valid tokens (non-ignored labels) across all ranks
131+
# shift_labels shape is (num_tokens,) after flatten, -100 means ignore
132+
num_valid_tokens = (shift_labels != -100).sum().float()
133+
# Gather num_valid_tokens across all SP ranks to get the total count
134+
sp_group = get_ulysses_sequence_parallel_group()
135+
if sp_group is not None:
136+
dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
137+
loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
129138

130139
if reduction == "sum":
131140
loss /= kwargs["num_items_in_batch"]

src/lmms_engine/models/qwen3/qwen3_liger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
except:
1313
print("Liger Kernel is not installed, pip install liger-kernel to use this patch")
1414
import torch
15+
import torch.distributed as dist
1516

1617
from lmms_engine.parallel.sequence_parallel.ulysses import (
1718
calculate_seq_len_per_rank,
1819
gather_outputs_and_unpad,
20+
get_ulysses_sequence_parallel_group,
1921
get_ulysses_sequence_parallel_world_size,
2022
pad_to_max_across_ranks,
2123
slice_input_tensor,
@@ -143,7 +145,14 @@ def qwen3_lce_forward(
143145
# Pad to max size across ranks, then gather and unpad
144146
loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
145147
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
146-
loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
148+
# Calculate the actual number of valid tokens (non-ignored labels) across all ranks
149+
# shift_labels shape is (num_tokens,) after flatten, -100 means ignore
150+
num_valid_tokens = (shift_labels != -100).sum().float()
151+
# Gather num_valid_tokens across all SP ranks to get the total count
152+
sp_group = get_ulysses_sequence_parallel_group()
153+
if sp_group is not None:
154+
dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
155+
loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
147156

148157
if reduction == "sum":
149158
loss /= loss_kwargs["num_items_in_batch"]

src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_liger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Optional, Tuple, Union
22

33
import torch
4+
import torch.distributed as dist
45
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
56
Qwen3OmniMoeThinkerCausalLMOutputWithPast,
67
Qwen3OmniMoeThinkerForConditionalGeneration,
@@ -11,6 +12,7 @@
1112
from lmms_engine.parallel.sequence_parallel.ulysses import (
1213
calculate_seq_len_per_rank,
1314
gather_outputs_and_unpad,
15+
get_ulysses_sequence_parallel_group,
1416
get_ulysses_sequence_parallel_world_size,
1517
pad_to_max_across_ranks,
1618
slice_input_tensor,
@@ -266,7 +268,14 @@ def lce_forward(
266268
# Pad to max size across ranks, then gather and unpad
267269
loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
268270
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
269-
loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
271+
# Calculate the actual number of valid tokens (non-ignored labels) across all ranks
272+
# shift_labels shape is (num_tokens,) after flatten, -100 means ignore
273+
num_valid_tokens = (shift_labels != -100).sum().float()
274+
# Gather num_valid_tokens across all SP ranks to get the total count
275+
sp_group = get_ulysses_sequence_parallel_group()
276+
if sp_group is not None:
277+
dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
278+
loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
270279

271280
if reduction == "sum":
272281
loss /= kwargs["num_items_in_batch"]

src/lmms_engine/models/qwen3_vl/qwen3_vl_liger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Optional, Tuple, Union
22

33
import torch
4+
import torch.distributed as dist
45
from loguru import logger
56
from transformers import Qwen3VLForConditionalGeneration
67
from transformers.cache_utils import Cache
@@ -9,6 +10,7 @@
910
from lmms_engine.parallel.sequence_parallel.ulysses import (
1011
calculate_seq_len_per_rank,
1112
gather_outputs_and_unpad,
13+
get_ulysses_sequence_parallel_group,
1214
get_ulysses_sequence_parallel_world_size,
1315
pad_to_max_across_ranks,
1416
slice_input_tensor,
@@ -121,7 +123,14 @@ def qwen3_vl_lce_forward(
121123
# Pad to max size across ranks, then gather and unpad
122124
loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
123125
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
124-
loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
126+
# Calculate the actual number of valid tokens (non-ignored labels) across all ranks
127+
# shift_labels shape is (num_tokens,) after flatten, -100 means ignore
128+
num_valid_tokens = (shift_labels != -100).sum().float()
129+
# Gather num_valid_tokens across all SP ranks to get the total count
130+
sp_group = get_ulysses_sequence_parallel_group()
131+
if sp_group is not None:
132+
dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
133+
loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
125134

126135
if reduction == "sum":
127136
loss /= kwargs["num_items_in_batch"]

src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_liger.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Optional, Tuple, Union
22

33
import torch
4+
import torch.distributed as dist
45
from transformers.cache_utils import Cache
56
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
67
Qwen3VLMoeCausalLMOutputWithPast,
@@ -11,6 +12,7 @@
1112
from lmms_engine.parallel.sequence_parallel.ulysses import (
1213
calculate_seq_len_per_rank,
1314
gather_outputs_and_unpad,
15+
get_ulysses_sequence_parallel_group,
1416
get_ulysses_sequence_parallel_world_size,
1517
pad_to_max_across_ranks,
1618
slice_input_tensor,
@@ -112,7 +114,14 @@ def lce_forward(
112114
# Pad to max size across ranks, then gather and unpad
113115
loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
114116
loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
115-
loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
117+
# Calculate the actual number of valid tokens (non-ignored labels) across all ranks
118+
# shift_labels shape is (num_tokens,) after flatten, -100 means ignore
119+
num_valid_tokens = (shift_labels != -100).sum().float()
120+
# Gather num_valid_tokens across all SP ranks to get the total count
121+
sp_group = get_ulysses_sequence_parallel_group()
122+
if sp_group is not None:
123+
dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
124+
loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
116125

117126
if reduction == "sum":
118127
loss /= kwargs["num_items_in_batch"]

0 commit comments

Comments
 (0)