Skip to content

Commit e985338

Browse files
committed
dataset refactor, add 350M config
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent ec7f70a commit e985338

7 files changed

Lines changed: 100 additions & 34 deletions

File tree

recipes/esm2_accelerate/dataset.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,59 +16,70 @@
1616
# Create the dataset -- here, we just use a simple parquet file with some raw protein sequences
1717
# stored in the repo itself to avoid external dependencies.
1818

19-
from pathlib import Path
20-
21-
from datasets import load_dataset
19+
from datasets import IterableDataset, load_dataset
2220
from transformers import AutoTokenizer
2321
from transformers.data.data_collator import DataCollatorForLanguageModeling
2422

2523

26-
def infinite_dataloader(dataloader, sampler):
27-
"""Create an infinite iterator that automatically restarts at the end of each epoch."""
28-
epoch = 0
29-
while True:
30-
sampler.set_epoch(epoch) # Update epoch for proper shuffling
31-
for batch in dataloader:
32-
yield batch
33-
epoch += 1 # Increment epoch counter after completing one full pass
34-
35-
36-
def create_datasets_and_collator(tokenizer_name: str, max_length: int = 1024):
37-
"""Create a dataloader for the dataset.
24+
def create_datasets_and_collator(
25+
tokenizer_name: str,
26+
train_load_dataset_kwargs: dict,
27+
eval_load_dataset_kwargs: dict,
28+
max_seq_length: int = 1024,
29+
truncate_eval_dataset: int | None = None,
30+
):
31+
"""Create datasets and a data collator to pass to the huggingface trainer.
3832
3933
Args:
4034
tokenizer_name: The name of the tokenizer to pull from the HuggingFace Hub.
41-
max_length: The maximum length of the protein sequences.
35+
train_load_dataset_kwargs: Keyword arguments to pass to `load_dataset` for the train dataset.
36+
eval_load_dataset_kwargs: Keyword arguments to pass to `load_dataset` for the eval dataset.
37+
max_seq_length: The maximum length of the protein sequences.
38+
truncate_eval_dataset: If not `None`, the eval dataset will be truncated to this number of examples.
39+
40+
This assumes that the dataset has a "sequence" column that will be tokenized.
4241
4342
Returns:
4443
Tuple of (train_dataset, eval_dataset, data_collator).
4544
"""
46-
# We copy this parquet file to the container to avoid external dependencies, modify if you're
47-
# using a local dataset. If you're reading this and scaling up the dataset to a larger size,
48-
# look into `set_transform` and other streaming options from the `datasets` library.
49-
data_path = Path(__file__).parent / "train.parquet"
50-
train_dataset = load_dataset("parquet", data_files=data_path.as_posix(), split="train")
51-
eval_dataset = train_dataset.select(range(10))
45+
train_dataset = load_dataset(**train_load_dataset_kwargs)
46+
eval_dataset = load_dataset(**eval_load_dataset_kwargs)
47+
if truncate_eval_dataset is not None:
48+
if isinstance(eval_dataset, IterableDataset):
49+
raise ValueError(
50+
"Cannot truncate an IterableDataset, don't use streaming datasets for eval if you want to truncate."
51+
)
52+
eval_dataset = eval_dataset.select(range(truncate_eval_dataset))
5253

5354
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
5455

55-
def tokenize_function(examples):
56+
def tokenize_function(sequence):
5657
"""Tokenize the protein sequences."""
5758
return tokenizer(
58-
examples["sequence"],
59+
sequence,
5960
truncation=True,
6061
padding="max_length",
61-
max_length=max_length,
62+
max_length=max_seq_length,
6263
return_tensors="pt",
6364
)
6465

65-
for dataset in [train_dataset, eval_dataset]:
66-
dataset.set_transform(tokenize_function)
66+
train_dataset = train_dataset.map(
67+
tokenize_function,
68+
batched=True,
69+
input_columns=["sequence"],
70+
remove_columns=train_dataset.column_names,
71+
)
72+
eval_dataset = eval_dataset.map(
73+
tokenize_function,
74+
batched=True,
75+
input_columns=["sequence"],
76+
remove_columns=eval_dataset.column_names,
77+
)
6778

6879
data_collator = DataCollatorForLanguageModeling(
6980
tokenizer=tokenizer,
7081
mlm_probability=0.15,
71-
pad_to_multiple_of=max_length,
82+
pad_to_multiple_of=max_seq_length,
7283
)
7384

7485
return train_dataset, eval_dataset, data_collator

recipes/esm2_accelerate/hydra_config/L0_sanity.yaml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@ defaults:
55
model_tag: "nvidia/esm2_t6_8M_UR50D"
66
stop_after_n_steps: 250
77

8+
dataset:
9+
tokenizer_name: ${model_tag}
10+
max_seq_length: 1024
11+
train_load_dataset_kwargs:
12+
path: "parquet"
13+
split: "train"
14+
data_files: "train.parquet"
15+
streaming: True
16+
eval_load_dataset_kwargs:
17+
path: "parquet"
18+
split: "train"
19+
data_files: "train.parquet"
20+
truncate_eval_dataset: 10
21+
822
trainer:
923
run_name: "esm2_t6_8M_UR50D_sanity"
1024
per_device_train_batch_size: 2
@@ -13,5 +27,5 @@ trainer:
1327
eval_steps: 1000
1428
logging_steps: 10
1529
report_to: "none"
16-
dataloader_num_workers: 4
30+
dataloader_num_workers: 1
1731
warmup_steps: 0

recipes/esm2_accelerate/hydra_config/L0_sanity_amplify.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@ defaults:
55
model_tag: "nvidia/AMPLIFY_120M"
66
stop_after_n_steps: 250
77

8+
dataset:
9+
tokenizer_name: ${model_tag}
10+
max_seq_length: 1024
11+
train_load_dataset_kwargs:
12+
path: "parquet"
13+
split: "train"
14+
data_files: "train.parquet"
15+
streaming: True
16+
eval_load_dataset_kwargs:
17+
path: "parquet"
18+
split: "train"
19+
data_files: "train.parquet"
20+
truncate_eval_dataset: 10
21+
822
trainer:
923
run_name: "amplify_120M_sanity"
1024
per_device_train_batch_size: 2
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
defaults:
2+
- defaults_amplify
3+
- _self_
4+
5+
stop_after_n_steps: 20_000
6+
trainer:
7+
run_name: "L1-350M-partial-conv"
8+
eval_steps: 1_000
9+
save_steps: 1_000
10+
logging_steps: 10
11+
report_to: "wandb"
12+
per_device_train_batch_size: 128
13+
per_device_eval_batch_size: 256

recipes/esm2_accelerate/hydra_config/L1_15B_perf_test.yaml renamed to recipes/esm2_accelerate/hydra_config/L1_esm2_15B_perf_test.yaml

File renamed without changes.

recipes/esm2_accelerate/hydra_config/defaults.yaml

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
model_tag: "nvidia/esm2_t6_8M_UR50D"
22
stop_after_n_steps: 500_000
3-
max_seq_length: 1024
3+
4+
dataset:
5+
tokenizer_name: ${model_tag}
6+
max_seq_length: 1024
7+
# TODO(BIONEMO-2783): Replace this with our ESM-2 parquet dataset when it's ready.
8+
train_load_dataset_kwargs:
9+
path: "chandar-lab/UR100P"
10+
split: "train"
11+
revision: "refs/convert/parquet"
12+
streaming: True
13+
eval_load_dataset_kwargs:
14+
path: "chandar-lab/UR100P"
15+
split: "test"
16+
revision: "refs/convert/parquet"
17+
# Whether to truncate the eval dataset; HF Trainer will run the full eval dataset each eval step.
18+
# If set to an integer, the eval dataset will be truncated to that number of examples.
19+
truncate_eval_dataset: null
20+
421
trainer:
522
output_dir: "results"
623
run_name: ???

recipes/esm2_accelerate/train.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ def main(args: DictConfig):
4848
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
4949
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, dtype=torch.bfloat16)
5050

51-
train_dataset, eval_dataset, data_collator = create_datasets_and_collator(
52-
tokenizer_name=args.model_tag,
53-
max_length=args.max_seq_length,
54-
)
51+
train_dataset, eval_dataset, data_collator = create_datasets_and_collator(**args.dataset)
5552

5653
training_args = TrainingArguments(**args.trainer)
5754

0 commit comments

Comments
 (0)