|
16 | 16 | # Create the dataset -- here, we just use a simple parquet file with some raw protein sequences |
17 | 17 | # stored in the repo itself to avoid external dependencies. |
18 | 18 |
|
19 | | -from pathlib import Path |
20 | | - |
21 | | -from datasets import load_dataset |
| 19 | +from datasets import IterableDataset, load_dataset |
22 | 20 | from transformers import AutoTokenizer |
23 | 21 | from transformers.data.data_collator import DataCollatorForLanguageModeling |
24 | 22 |
|
25 | 23 |
|
26 | | -def infinite_dataloader(dataloader, sampler): |
27 | | - """Create an infinite iterator that automatically restarts at the end of each epoch.""" |
28 | | - epoch = 0 |
29 | | - while True: |
30 | | - sampler.set_epoch(epoch) # Update epoch for proper shuffling |
31 | | - for batch in dataloader: |
32 | | - yield batch |
33 | | - epoch += 1 # Increment epoch counter after completing one full pass |
34 | | - |
35 | | - |
36 | | -def create_datasets_and_collator(tokenizer_name: str, max_length: int = 1024): |
37 | | - """Create a dataloader for the dataset. |
| 24 | +def create_datasets_and_collator( |
| 25 | + tokenizer_name: str, |
| 26 | + train_load_dataset_kwargs: dict, |
| 27 | + eval_load_dataset_kwargs: dict, |
| 28 | + max_seq_length: int = 1024, |
| 29 | + truncate_eval_dataset: int | None = None, |
| 30 | +): |
| 31 | + """Create datasets and a data collator to pass to the huggingface trainer. |
38 | 32 |
|
39 | 33 | Args: |
40 | 34 | tokenizer_name: The name of the tokenizer to pull from the HuggingFace Hub. |
41 | | - max_length: The maximum length of the protein sequences. |
| 35 | + train_load_dataset_kwargs: Keyword arguments to pass to `load_dataset` for the train dataset. |
| 36 | + eval_load_dataset_kwargs: Keyword arguments to pass to `load_dataset` for the eval dataset. |
| 37 | + max_seq_length: The length, in tokens, that every protein sequence is truncated or padded to.
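| 38 | + truncate_eval_dataset: If not `None`, only the first `truncate_eval_dataset` examples of the eval dataset are kept. A streaming (`IterableDataset`) eval dataset cannot be truncated and raises a `ValueError`.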
| 39 | +
|
| 40 | + This assumes that both the train and eval datasets have a "sequence" column, which is the column that will be tokenized.
42 | 41 |
|
43 | 42 | Returns: |
44 | 43 | Tuple of (train_dataset, eval_dataset, data_collator). |
45 | 44 | """ |
46 | | - # We copy this parquet file to the container to avoid external dependencies, modify if you're |
47 | | - # using a local dataset. If you're reading this and scaling up the dataset to a larger size, |
48 | | - # look into `set_transform` and other streaming options from the `datasets` library. |
49 | | - data_path = Path(__file__).parent / "train.parquet" |
50 | | - train_dataset = load_dataset("parquet", data_files=data_path.as_posix(), split="train") |
51 | | - eval_dataset = train_dataset.select(range(10)) |
| 45 | + train_dataset = load_dataset(**train_load_dataset_kwargs) |
| 46 | + eval_dataset = load_dataset(**eval_load_dataset_kwargs) |
| 47 | + if truncate_eval_dataset is not None: |
| 48 | + if isinstance(eval_dataset, IterableDataset): |
| 49 | + raise ValueError( |
| 50 | + "Cannot truncate an IterableDataset, don't use streaming datasets for eval if you want to truncate." |
| 51 | + ) |
| 52 | + eval_dataset = eval_dataset.select(range(truncate_eval_dataset)) |
52 | 53 |
|
53 | 54 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) |
54 | 55 |
|
55 | | - def tokenize_function(examples): |
| 56 | + def tokenize_function(sequence): |
56 | 57 | """Tokenize the protein sequences.""" |
57 | 58 | return tokenizer( |
58 | | - examples["sequence"], |
| 59 | + sequence, |
59 | 60 | truncation=True, |
60 | 61 | padding="max_length", |
61 | | - max_length=max_length, |
| 62 | + max_length=max_seq_length, |
62 | 63 | return_tensors="pt", |
63 | 64 | ) |
64 | 65 |
|
65 | | - for dataset in [train_dataset, eval_dataset]: |
66 | | - dataset.set_transform(tokenize_function) |
| 66 | + train_dataset = train_dataset.map( |
| 67 | + tokenize_function, |
| 68 | + batched=True, |
| 69 | + input_columns=["sequence"], |
| 70 | + remove_columns=train_dataset.column_names, |
| 71 | + ) |
| 72 | + eval_dataset = eval_dataset.map( |
| 73 | + tokenize_function, |
| 74 | + batched=True, |
| 75 | + input_columns=["sequence"], |
| 76 | + remove_columns=eval_dataset.column_names, |
| 77 | + ) |
67 | 78 |
|
68 | 79 | data_collator = DataCollatorForLanguageModeling( |
69 | 80 | tokenizer=tokenizer, |
70 | 81 | mlm_probability=0.15, |
71 | | - pad_to_multiple_of=max_length, |
| 82 | + pad_to_multiple_of=max_seq_length, |
72 | 83 | ) |
73 | 84 |
|
74 | 85 | return train_dataset, eval_dataset, data_collator |
0 commit comments