Batchwise padding dataset #121

@mrghofrani

Hello,
I'm fairly new to PyTorch, so apologies if this question is too basic. Because of memory limits, I can't pad the whole dataset up front. What is the simplest way to move the pad_dataset function into the training process, so that each batch is padded on its own? For reference, I've included pad_dataset below.
Thanks.

def pad_dataset(dataset, padding=0):
    """ Pad the dataset. This could be optimized by defining a Dataset class and padding at the batch level, but this is simpler. """
    max_l = max(len(x) for x in dataset["input_ids"])
    for name in PADDED_INPUTS:
        dataset[name] = [x + [padding if name != "lm_labels" else -100] * (max_l - len(x)) for x in dataset[name]]
    return dataset
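
For comparison, here is a minimal sketch of what the same padding might look like when applied to one batch at a time. It assumes each example is a dict of plain Python lists and that PADDED_INPUTS holds the three field names shown; neither is confirmed by the snippet above. pad_batch is a hypothetical name.

```python
# Assumption: the fields that need padding. The snippet above references
# PADDED_INPUTS without showing its value.
PADDED_INPUTS = ["input_ids", "lm_labels", "token_type_ids"]


def pad_batch(batch, padding=0):
    """Pad the examples of a single batch to that batch's longest input_ids.

    `batch` is a list of per-example dicts (an assumed layout). Labels are
    padded with -100, every other field with `padding`, as in pad_dataset.
    """
    max_l = max(len(example["input_ids"]) for example in batch)
    padded = {}
    for name in PADDED_INPUTS:
        fill = -100 if name == "lm_labels" else padding
        padded[name] = [
            example[name] + [fill] * (max_l - len(example[name]))
            for example in batch
        ]
    return padded


# Two examples of different lengths, padded only to the longer of the two.
batch = [
    {"input_ids": [5, 6, 7], "lm_labels": [-100, 6, 7], "token_type_ids": [1, 1, 2]},
    {"input_ids": [8], "lm_labels": [8], "token_type_ids": [1]},
]
padded = pad_batch(batch)
```

A function with this shape could be passed as collate_fn to torch.utils.data.DataLoader, with each padded list converted via torch.tensor before it is returned. Padding then happens per batch, to that batch's own maximum length, so the full padded dataset never has to sit in memory.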
