3232
3333def create_tokenized_dataset (
3434 distributed_config : DistributedConfig ,
35- tokenizer_path : str ,
35+ tokenizer_name_or_path : str ,
3636 load_dataset_kwargs : dict ,
3737 max_seq_length : int = 8192 ,
3838 stride : int = 200 ,
3939 buffer_size : int = 500_000 ,
4040 use_lazy_tokenization : bool = True ,
41- sequence_column : str = "sequence " ,
41+ text_column : str = "text " ,
4242):
4343 """Create a tokenized dataset with windowing.
4444
4545 Args:
4646 distributed_config: The distributed configuration.
47- tokenizer_path: Path to the nucleotide tokenizer directory.
47+ tokenizer_name_or_path: Name or path to the nucleotide tokenizer directory.
4848 load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
4949 max_seq_length: The maximum length of sequences (window size).
5050 stride: The stride for windowing (overlap = stride tokens).
5151 buffer_size: The buffer size for shuffle.
5252 use_lazy_tokenization: Whether to use datasets.set_transform for tokenization.
53- sequence_column : Name of the column containing genomic sequences (default: "sequence ").
53+ text_column : Name of the column containing genomic sequences (default: "text ").
5454
5555 Returns:
5656 Tuple of (tokenized_dataset, tokenizer).
@@ -67,13 +67,13 @@ def create_tokenized_dataset(
6767 )
6868 dataset = dataset .shuffle (seed = 42 , buffer_size = buffer_size )
6969
70- tokenizer = AutoTokenizer .from_pretrained (tokenizer_path )
70+ tokenizer = AutoTokenizer .from_pretrained (tokenizer_name_or_path )
7171
7272 def tokenize_with_windowing (examples ):
7373 """Tokenize nucleotide sequences with windowing (one-to-many mapping)."""
7474 # Tokenize with windowing using return_overflowing_tokens
7575 result = tokenizer (
76- examples [sequence_column ],
76+ examples [text_column ],
7777 max_length = max_seq_length ,
7878 stride = stride ,
7979 truncation = True ,
@@ -91,7 +91,7 @@ def tokenize_with_windowing(examples):
9191 # This causes dataset.column_names to be None for streaming IterableDataset.
9292 #
9393 # For IterableDataset with None column_names (OpenGenome2):
94- # - Must explicitly list columns to remove: [sequence_column , "record"]
94+ # - Must explicitly list columns to remove: [text_column , "record"]
9595 # - IterableDataset.map() handles missing columns gracefully
9696 #
9797 # For regular Dataset (non-streaming, or streaming with consistent schema like ESM2):
@@ -100,9 +100,9 @@ def tokenize_with_windowing(examples):
100100 #
101101 # TODO: Remove this workaround once Arc Institute fixes OpenGenome2 schema consistency.
102102 # When all shards have the same columns, dataset.column_names will work for both cases.
103- if isinstance (dataset , datasets .IterableDataset ):
103+ if isinstance (dataset , datasets .IterableDataset ) and dataset . column_names is None :
104104 # Streaming dataset: column_names may be None due to inconsistent schema
105- columns_to_remove = [sequence_column , "record" ]
105+ columns_to_remove = [text_column , "record" ]
106106 else :
107107 # Column names are known (non-streaming, or streaming with a consistent schema): use them
108108 columns_to_remove = dataset .column_names
@@ -120,7 +120,7 @@ def tokenize_with_windowing(examples):
120120
121121def create_bshd_dataloader (
122122 distributed_config : DistributedConfig ,
123- tokenizer_path : str ,
123+ tokenizer_name_or_path : str ,
124124 load_dataset_kwargs : dict ,
125125 micro_batch_size : int ,
126126 num_workers : int = 1 ,
@@ -130,15 +130,15 @@ def create_bshd_dataloader(
130130 buffer_size : int = 500_000 ,
131131 use_lazy_tokenization : bool = True ,
132132 use_stateful_dataloader : bool = False ,
133- sequence_column : str = "sequence " ,
133+ text_column : str = "text " ,
134134 uppercase_labels : bool = False ,
135135 mask_degenerate_bases : bool = True ,
136136):
137137 """Create a BSHD dataloader for genomic sequences using CLM (causal language modeling).
138138
139139 Args:
140140 distributed_config: The distributed configuration.
141- tokenizer_path: Path to the nucleotide tokenizer directory.
141+ tokenizer_name_or_path: Name or path to the nucleotide tokenizer directory.
142142 load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
143143 micro_batch_size: The batch size per device.
144144 num_workers: The number of workers to use for the dataloader.
@@ -148,7 +148,7 @@ def create_bshd_dataloader(
148148 buffer_size: The buffer size for shuffle.
149149 use_lazy_tokenization: Whether to use datasets.set_transform for tokenization.
150150 use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
151- sequence_column : Name of the column containing genomic sequences (default: "sequence ").
151+ text_column : Name of the column containing genomic sequences (default: "text ").
152152 uppercase_labels: Whether to uppercase labels (genomic masking). Default: False.
153153 mask_degenerate_bases: Whether to mask non-ACGT bases (genomic masking). Default: True.
154154
@@ -157,13 +157,13 @@ def create_bshd_dataloader(
157157 """
158158 tokenized_dataset , tokenizer = create_tokenized_dataset (
159159 distributed_config = distributed_config ,
160- tokenizer_path = tokenizer_path ,
160+ tokenizer_name_or_path = tokenizer_name_or_path ,
161161 load_dataset_kwargs = load_dataset_kwargs ,
162162 max_seq_length = max_seq_length ,
163163 stride = stride ,
164164 buffer_size = buffer_size ,
165165 use_lazy_tokenization = use_lazy_tokenization ,
166- sequence_column = sequence_column ,
166+ text_column = text_column ,
167167 )
168168
169169 if isinstance (tokenized_dataset , datasets .IterableDataset ):
@@ -214,7 +214,7 @@ def create_bshd_dataloader(
214214
215215def create_thd_dataloader (
216216 distributed_config : DistributedConfig ,
217- tokenizer_path : str ,
217+ tokenizer_name_or_path : str ,
218218 load_dataset_kwargs : dict ,
219219 micro_batch_size : int | None = None ,
220220 token_micro_batch_size : int | None = None ,
@@ -224,15 +224,15 @@ def create_thd_dataloader(
224224 buffer_size : int = 500_000 ,
225225 use_lazy_tokenization : bool = True ,
226226 use_stateful_dataloader : bool = False ,
227- sequence_column : str = "sequence " ,
227+ text_column : str = "text " ,
228228 uppercase_labels : bool = False ,
229229 mask_degenerate_bases : bool = True ,
230230):
231231 """Create a dataloader that packs up to the maximum number of tokens per batch.
232232
233233 Args:
234234 distributed_config: The distributed configuration.
235- tokenizer_path: Path to the nucleotide tokenizer directory.
235+ tokenizer_name_or_path: Name or path to the nucleotide tokenizer directory.
236236 load_dataset_kwargs: Keyword arguments to pass to `load_dataset`.
237237 micro_batch_size: The batch size per device.
238238 token_micro_batch_size: The maximum number of tokens per batch. If None, the micro_batch_size * max_seq_length
@@ -244,22 +244,22 @@ def create_thd_dataloader(
244244 buffer_size: The buffer size for shuffle.
245245 use_lazy_tokenization: Whether to use datasets.set_transform for tokenization.
246246 use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
247- sequence_column : Name of the column containing genomic sequences (default: "sequence ").
247+ text_column : Name of the column containing genomic sequences (default: "text ").
248248 uppercase_labels: Whether to uppercase labels (genomic masking). Default: False.
249249 mask_degenerate_bases: Whether to mask degenerate bases (genomic masking). Default: True.
250250
251251 Returns:
252252 A dataloader that can be used for training.
253253 """
254- tokenized_dataset , tokenizer = create_tokenized_dataset (
254+ tokenized_dataset , _ = create_tokenized_dataset (
255255 distributed_config = distributed_config ,
256- tokenizer_path = tokenizer_path ,
256+ tokenizer_name_or_path = tokenizer_name_or_path ,
257257 load_dataset_kwargs = load_dataset_kwargs ,
258258 max_seq_length = max_seq_length ,
259259 stride = stride ,
260260 buffer_size = buffer_size ,
261261 use_lazy_tokenization = use_lazy_tokenization ,
262- sequence_column = sequence_column ,
262+ text_column = text_column ,
263263 )
264264
265265 assert isinstance (tokenized_dataset , datasets .IterableDataset ), "THD token packing requires a streaming dataset."
0 commit comments