
Restart training with new data, mid-epoch #436

@schopra8

🚀 Feature

Right now, you can change the datasets used in a CombinedStreamingDataset and resume training on epoch boundaries. It would be great if you could resume training with new datasets mid-epoch.

Motivation

When we're doing curriculum learning, we don't know in advance the right number of steps or epochs to train for. Once we reach a sufficient validation loss, we kill training and resume it with a new group of datasets (i.e., we adjust the curriculum). As a result, we often have to kill training mid-epoch and restart with new datasets.

Pitch

Suppose you are training with N datasets and kill training K steps into epoch E. If you then change the underlying datasets and resume from the checkpoint that was saved mid-epoch, the trainer should jump to epoch E + 1 with the new datasets, the old optimizer state, and the correct global batch index.
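To make the requested behavior concrete, here is a minimal Python sketch over a deliberately simplified checkpoint. The flat dict layout (keys "epoch", "global_step", "optimizer_states", "dataloader_state") and the function name resume_with_new_datasets are hypothetical; real Lightning checkpoints nest this information differently, and this is not an existing API.

```python
from copy import deepcopy
from typing import Any, Dict


def resume_with_new_datasets(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical sketch of the requested resume semantics.

    Assumes a simplified, flat checkpoint dict with "epoch", "global_step",
    "optimizer_states" and "dataloader_state" keys (not Lightning's real layout).
    """
    # Copy so the original checkpoint dict is left untouched.
    resumed = deepcopy(ckpt)
    # Skip the rest of epoch E and start epoch E + 1 with the new datasets.
    resumed["epoch"] = ckpt["epoch"] + 1
    # "global_step" and "optimizer_states" are carried over unchanged by the copy,
    # so LR schedulers keyed on the global step continue where they left off.
    # The mid-epoch dataloader position refers to the old datasets, so drop it.
    resumed["dataloader_state"] = None
    return resumed


if __name__ == "__main__":
    ckpt = {
        "epoch": 3,
        "global_step": 1250,
        "optimizer_states": [{"lr": 1e-4}],
        "dataloader_state": {"samples_yielded": 5000},
    }
    new = resume_with_new_datasets(ckpt)
    print(new["epoch"], new["global_step"], new["dataloader_state"])  # 4 1250 None
```

The key design point is which state survives: epoch and dataloader position are reset or advanced, while the global step and optimizer state are preserved.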

Alternatives

Right now, we have two workarounds:

  • Delete the "loops" entry from the last mid-epoch checkpoint before resuming training with different datasets. This isn't a great solution, because training resumes at epoch 0, which interferes with any learning rate schedulers we use.
  • Copy the "loops" entry from the last end-of-epoch checkpoint into the more recent mid-epoch checkpoint before resuming training with different datasets. This approximately mitigates the learning rate scheduler issue, but it isn't the cleanest solution and is a pain to do manually every time.
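Both workarounds amount to editing the checkpoint dictionary before resuming. A minimal sketch, assuming a Lightning-style checkpoint dict with a top-level "loops" entry; the helper names drop_loops and graft_loops are hypothetical:

```python
from copy import deepcopy
from typing import Any, Dict

Checkpoint = Dict[str, Any]


def drop_loops(ckpt: Checkpoint) -> Checkpoint:
    """Workaround 1: remove loop progress, so training restarts at epoch 0."""
    patched = deepcopy(ckpt)
    # Missing key is tolerated so the helper is safe to re-run.
    patched.pop("loops", None)
    return patched


def graft_loops(mid_epoch_ckpt: Checkpoint, epoch_end_ckpt: Checkpoint) -> Checkpoint:
    """Workaround 2: reuse loop progress from the last end-of-epoch checkpoint.

    Everything else (model weights, optimizer state, ...) still comes from
    the newer mid-epoch checkpoint.
    """
    patched = deepcopy(mid_epoch_ckpt)
    # Deep-copy so the patched checkpoint does not share objects with the source.
    patched["loops"] = deepcopy(epoch_end_ckpt["loops"])
    return patched
```

In practice each file would be loaded with torch.load(path, map_location="cpu", weights_only=False), passed through one of these helpers, and written back with torch.save(patched, new_path) before passing the new path as ckpt_path to Trainer.fit.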


Labels: enhancement (New feature or request)
