
MASPlugin errors if SGD Loss is zero #1676

@man2machine

Description


In MASPlugin, the before_backward callback checks whether the loss has been produced by the training_epoch function of the SGDUpdate class:

if not strategy.loss:
    raise ValueError("Loss is not available")

However, when an experience contains only a small amount of data and the classifier's logits are nearly perfect, nn.CrossEntropyLoss() may return tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>) as its loss, because the loss rounds to exactly zero at floating-point precision. Since a one-element tensor equal to zero evaluates to False when converted to a boolean, the check above raises the error in this case, even though the loss was computed correctly by the SGD update.
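A minimal reproduction sketch of the underflow, assuming only that PyTorch is installed (it does not use Avalanche or MASPlugin). The logits below are an illustrative choice: a margin of 200 between the correct and incorrect class makes exp(-200) underflow in float32, so the cross-entropy of the target class becomes exactly zero while the tensor still carries its autograd graph.

```python
import torch
import torch.nn as nn

# One sample whose logits strongly favor the correct class (index 0).
# exp(-200) underflows to 0 in float32, so log-softmax of class 0 is exactly 0.
logits = torch.tensor([[100.0, -100.0]], requires_grad=True)
target = torch.tensor([0])

loss = nn.CrossEntropyLoss()(logits, target)

print(loss.item())         # zero
print(bool(loss))          # False: a one-element zero tensor is falsy
print(loss.requires_grad)  # True: the loss still has a grad_fn

# So `if not loss:` takes the error branch even though the loss is valid.
# (Calling bool() on a tensor with more than one element raises a
# RuntimeError instead of returning False.)
```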

Here are a few solutions:

  • Use an if not strategy.loss.requires_grad check instead of an if not strategy.loss check
  • Initialize strategy.loss to None instead of using self._make_empty_loss(), and in MASPlugin check whether strategy.loss is None
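The two proposed checks can be compared in isolation with a small sketch. It does not use the real Avalanche classes: the helper names (loss_is_available_requires_grad, loss_is_available_not_none, before_backward_check) are hypothetical, and the placeholder loss is assumed to be a zero tensor with no autograd history, which is an assumption about what self._make_empty_loss() returns rather than a confirmed detail.

```python
from typing import Optional

import torch


def loss_is_available_requires_grad(loss: torch.Tensor) -> bool:
    # Option 1 (hypothetical helper): a loss computed from trainable
    # parameters carries an autograd graph, so requires_grad is True
    # even when its value is exactly 0.
    return loss.requires_grad


def loss_is_available_not_none(loss: Optional[torch.Tensor]) -> bool:
    # Option 2 (hypothetical helper): the strategy starts with loss = None,
    # so any tensor, including tensor(0.), counts as available.
    return loss is not None


def before_backward_check(loss: Optional[torch.Tensor]) -> None:
    # Hypothetical stand-in for the check in MASPlugin.before_backward,
    # written with option 2.
    if not loss_is_available_not_none(loss):
        raise ValueError("Loss is not available")


# A zero-valued loss that still has a grad_fn, like the one in the report.
logits = torch.tensor([[100.0, -100.0]], requires_grad=True)
zero_loss = torch.nn.functional.cross_entropy(logits, torch.tensor([0]))

# Assumed placeholder: a zero tensor with no autograd history.
placeholder = torch.zeros(1)

print(bool(zero_loss))                               # False: the original check would raise
print(loss_is_available_requires_grad(zero_loss))    # True
print(loss_is_available_requires_grad(placeholder))  # False
print(loss_is_available_not_none(zero_loss))         # True
print(loss_is_available_not_none(None))              # False
before_backward_check(zero_loss)                     # no error
```

Option 2 is the more explicit of the two: option 1 relies on the placeholder never picking up gradient history (for example, by having a real loss added to it) and on the loss never being computed with autograd disabled, whereas an explicit None sentinel does not depend on the tensor's contents or flags at all.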

Labels: bug (Something isn't working)
