
Pad each batch, not the whole dataset#30

Open
sshleifer wants to merge 16 commits into huggingface:master from sshleifer:batch-padding

Conversation


@sshleifer sshleifer commented Sep 23, 2019

Previously, each sequence was padded to the length of the longest sequence in the dataset.
In this PR, each batch is padded to the length of the longest sequence in the batch. This results in a 30% speedup with negligible impact on metrics.
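To see where the speedup comes from, here is a toy illustration (not from the PR) that counts the pad tokens each scheme introduces for some hypothetical sequence lengths:

```python
# Hedged illustration of why per-batch padding helps: count the pad tokens
# needed under each scheme. The batch lengths below are made up.

lengths = [[10, 12], [38, 40]]  # two batches of sequence lengths

dataset_max = max(l for batch in lengths for l in batch)  # 40
pad_dataset = sum(dataset_max - l for batch in lengths for l in batch)
pad_batch = sum(max(batch) - l for batch in lengths for l in batch)

# Padding to the dataset max: (40-10)+(40-12)+(40-38)+(40-40) = 60 pad tokens.
# Padding to each batch max:  (12-10)+(12-12)+(40-38)+(40-40) = 4 pad tokens.
```

When short and long dialogs land in different batches, almost all of the dataset-level padding disappears.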

Code Changes

  • ChatDataset yields example dicts like {'input_ids': [[hist + cand_1], ..., [hist + cand_n]], ...} for the PADDED_INPUTS, with mc_token_ids and mc_labels in the same format as previously.
  • ChatDataset().collate_fn(examples: list) turns a list of example dicts into the five model-input tensors by batching them and padding them.
  • As a result, get_dataloaders does much less.
  • To facilitate this, the data format changes in the part of the process where we build lists of examples.
  • convai_evaluation.py still calls the old pad_dataset

1 Epoch Sanity Check

Before Change: 85 minutes
Validation: {'accuracy': 0.7483655941545956,
'average_accuracy': 0.7483655941545956,
'average_nll': 2.6815188920676687,
'average_ppl': 14.607263311061963,
'nll': 2.6815188920676687}

After Change: 60 minutes
Validation: {'accuracy': 0.7466991411357519,
'average_accuracy': 0.7466991411357519,
'average_nll': 2.6821035040007972,
'average_ppl': 14.615805388160778,
'nll': 2.6821035040007972}

Command:

python train.py --model_checkpoint openai-gpt --dataset_cache dataset_cache --fp16 O1 --n_epochs 1 --train_batch_size 4

@sshleifer sshleifer changed the title (WIP) Pad each batch, not the whole dataset Pad each batch, not the whole dataset Sep 29, 2019
Comment thread train.py
return train_loader, valid_loader, train_sampler, valid_sampler


def make_data_lists(args, personachat, tokenizer):
Contributor Author


docstring

Comment thread train.py
for utterance in dialog["utterances"]:
history = utterance["history"][-(2*args.max_history+1):]
candidate_instances = defaultdict(list)
history = utterance["history"][-(2 * args.max_history + 1):]
Contributor Author


could add assert len(utterance['candidates']) >= num_candidates

Comment thread train.py
return instance, sequence # TODO: second arg is never used, delete it


def pad_and_tensorize(batch_dict, padding):
Contributor Author


this and ChatDataset should be easy to unit test
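A minimal sketch of what such a unit test could look like. The real pad_and_tensorize returns torch tensors; the stand-in below returns nested lists so the padding logic itself can be checked without a torch dependency, and its body is an assumption based only on the signature visible in the diff:

```python
# Hedged sketch: a dependency-free stand-in for pad_and_tensorize plus a
# unit test of its padding behavior. Signature taken from the diff; the
# implementation here is illustrative, not the PR's code.

def pad_and_tensorize(batch_dict, padding):
    """Pad every field in batch_dict to this batch's longest sequence."""
    padded = {}
    for name, seqs in batch_dict.items():
        max_len = max(len(s) for s in seqs)
        padded[name] = [s + [padding] * (max_len - len(s)) for s in seqs]
    return padded

def test_pads_to_batch_max():
    out = pad_and_tensorize({"input_ids": [[1, 2], [3, 4, 5]]}, padding=0)
    assert out["input_ids"] == [[1, 2, 0], [3, 4, 5]]
    # every row in the padded batch has the same length
    assert len({len(row) for row in out["input_ids"]}) == 1

test_pads_to_batch_max()
```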

Comment thread train.py
valid_dataset = ChatDataset(datasets['valid'], pad_id)

logger.info("Build train and validation dataloaders")
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
Contributor Author


(maybe) put this in ChatDataset.to_loader(self, args, shuffle) -> sampler, loader
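The suggested helper could look something like the sketch below. The class body, the method name to_loader, and the sampler/loader wiring are all assumptions for illustration, not the PR's final code; the real ChatDataset carries all five model inputs, not just input_ids:

```python
# Hedged sketch of the suggested ChatDataset.to_loader(self, args, shuffle)
# -> sampler, loader refactor. Names and wiring are illustrative.
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

class ChatDataset(Dataset):
    """Minimal stand-in: the real ChatDataset yields all five model inputs."""
    def __init__(self, examples, pad_id):
        self.examples, self.pad_id = examples, pad_id

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def collate_fn(self, batch):
        # pad this batch's sequences to the batch max, not the dataset max
        seqs = [ex["input_ids"] for ex in batch]
        width = max(len(s) for s in seqs)
        return torch.tensor([s + [self.pad_id] * (width - len(s)) for s in seqs])

    def to_loader(self, batch_size, distributed=False, shuffle=True):
        sampler = DistributedSampler(self) if distributed else None
        loader = DataLoader(self, batch_size=batch_size, sampler=sampler,
                            collate_fn=self.collate_fn,
                            shuffle=shuffle and sampler is None)
        return sampler, loader

dataset = ChatDataset([{"input_ids": [1, 2]}, {"input_ids": [3, 4, 5]}], pad_id=0)
sampler, loader = dataset.to_loader(batch_size=2, shuffle=False)
batch = next(iter(loader))  # padded to this batch's max length
```

This keeps the sampler/loader construction next to the collate_fn it depends on, which is the point of the suggestion.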

Contributor Author


at some point might also want to document which tensors are 3D
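For reference, a hedged sketch of the 3D convention being alluded to: with the multiple-choice head, each example pairs one history with several candidates, so the padded input_ids batch is (batch_size, num_candidates, seq_len) rather than 2D. The helper and toy sizes below are illustrative, not from the PR:

```python
# Hedged illustration of the 3D tensor shape. `shape` and the toy sizes
# are made up for this example.

def shape(nested):
    """Return the dimensions of a uniformly nested, non-empty list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)

batch_size, num_candidates, seq_len = 4, 2, 7
input_ids = [[[0] * seq_len for _ in range(num_candidates)]
             for _ in range(batch_size)]
# mc_token_ids and mc_labels, by contrast, are 2D: (batch_size, num_candidates)
```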

Comment thread train.py
for input_name, input_array in instance.items():
datasets[dataset_name][input_name].append(input_array)
candidate_instances[input_name].append(input_array)
for k in candidate_instances.keys():
Contributor Author


.items() will save some chars

Comment thread train.py Outdated
for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
lm_labels = bool(j == num_candidates-1)
instance, _ = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
lm_labels = bool(j == num_candidates - 1)
Contributor Author


better varname?
