
loss_mask not adjusted after target shift, causing loss on invalid padding #317

@oluxembourg

Description


When preparing data for training, the code shifts target (the logits) and input_ids one position to the left using padding(tensor, left=False):

target = outs.logits
target = padding(target, left=False)
input_ids = padding(input_ids, left=False)
if target is not None:
    target = target.to(device)
    loss_mask = loss_mask[..., None]
    loss_mask = loss_mask.to(device)

This left shift aligns each position with the next token for next-token prediction. After the shift, however, the last position holds a zero-padding value rather than a real target. If loss_mask is still 1 at that position, the loss is computed against a meaningless padded target, which can degrade training quality.
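To make the problem concrete, here is a minimal sketch of what padding(tensor, left=False) is described as doing. The body of padding below is an assumed reconstruction based on the behaviour stated in this issue (drop the first position, append a zero slot); the project's own helper may differ in detail.

```python
import torch


def padding(tensor: torch.Tensor, left: bool = True) -> torch.Tensor:
    """Shift `tensor` one step along dim 1 and fill the vacated slot with zeros.

    Assumed reconstruction for illustration, not copied from the repository.
    Works for (batch, seq) input_ids and (batch, seq, vocab) logits alike.
    """
    zero_slot = torch.zeros_like(tensor[:, -1:])
    if left:
        # Shift right: prepend a zero slot, drop the last position.
        return torch.cat((zero_slot, tensor[:, :-1]), dim=1)
    # Shift left: drop the first position, append a zero slot at the end.
    return torch.cat((tensor[:, 1:], zero_slot), dim=1)


# Toy batch: 1 sequence, 4 positions, vocabulary of 3.
logits = torch.arange(1, 13, dtype=torch.float32).reshape(1, 4, 3)
shifted = padding(logits, left=False)

print(torch.equal(shifted[:, :-1], logits[:, 1:]))  # True: position t now holds what was at t + 1
print(shifted[0, -1])                               # tensor([0., 0., 0.]): the last slot is padding
```

The shifted tensor keeps its original shape, so nothing downstream signals that the final slot is filler; only loss_mask can keep it out of the loss.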

Suggested fix

Exclude the last position from the loss computation in the dataprepare function:

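target = outs.logits
# Shift left by one: position t now holds the target for token t + 1,
# and the last position is filled with zero padding.
target = padding(target, left=False)
input_ids = padding(input_ids, left=False)

if target is not None:
    target = target.to(device)
    loss_mask[..., -1] = 0  # Exclude last position (now contains padding)
    loss_mask = loss_mask[..., None]  # add a trailing dimension: (batch, seq) -> (batch, seq, 1)
    loss_mask = loss_mask.to(device)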
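The effect of the fix can be checked by looking at the gradient that reaches the padded slot. The sketch below assumes a soft cross-entropy against the target logits (a common distillation-style choice; the project's actual loss may differ), and masked_soft_ce, prepare_loss_mask and padded_slot_grad are hypothetical helpers written for this illustration. prepare_loss_mask also calls clone() before zeroing the last position, which the suggested fix does not do; it avoids mutating the caller's mask in place. Note that zero logits soften to a uniform distribution, so without the fix the padded slot pulls the prediction toward a uniform target.

```python
import torch
import torch.nn.functional as F


def masked_soft_ce(pred_logits, target_logits, loss_mask):
    """Soft cross-entropy against target logits, averaged over masked-in positions.

    Shapes: pred_logits, target_logits -> (batch, seq, vocab); loss_mask -> (batch, seq, 1).
    Assumed loss for illustration only.
    """
    target_p = F.softmax(target_logits, dim=-1)
    log_q = F.log_softmax(pred_logits, dim=-1)
    per_pos = -(target_p * log_q).sum(dim=-1, keepdim=True)  # (batch, seq, 1)
    return (per_pos * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)


def prepare_loss_mask(loss_mask, exclude_last=True):
    """Hypothetical helper: the suggested fix, applied to a copy of the mask."""
    loss_mask = loss_mask.clone()  # avoid mutating the caller's tensor in place
    if exclude_last:
        loss_mask[..., -1] = 0     # last slot is zero padding after the left shift
    return loss_mask[..., None]    # (batch, seq) -> (batch, seq, 1)


def padded_slot_grad(exclude_last):
    """Total gradient magnitude that reaches the padded (last) position."""
    torch.manual_seed(0)
    batch, seq, vocab = 2, 5, 8
    target = torch.randn(batch, seq, vocab)
    target[:, -1] = 0.0  # what padding(target, left=False) leaves in the last slot
    pred = torch.randn(batch, seq, vocab, requires_grad=True)
    mask = prepare_loss_mask(torch.ones(batch, seq), exclude_last)
    masked_soft_ce(pred, target, mask).backward()
    return pred.grad[:, -1].abs().sum().item()


print(padded_slot_grad(exclude_last=False))  # > 0: the padded slot is trained toward a uniform target
print(padded_slot_grad(exclude_last=True))   # 0.0: the padded slot no longer contributes
```

With the last position masked out, the loss is averaged only over real targets, and the prediction at the padded slot receives no gradient at all.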

Metadata


Assignees

No one assigned

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
