Skip to content

Commit 07d6f33

Browse files
authored
Llama3 pre-context parallel dataloader changes (#1400)
Makes a number of updates in preparation for llama3 context-parallel training. It's still not currently working: we need to further update the model to handle the `cu_seq_lens_q_padded` kwargs, and would like to add a single-GPU CP test that uses BSHD inputs to at least exercise this code in CI. This PR: * Only materializes the dataloader on cp_rank=0, and returns None on other ranks. * Uses the scatter operation in the dataloader to synchronize `StopIteration` exceptions. * Adds tests for the CP dataloader on 1- and 2-GPU machines. * Moves llama3 to use DLCM data as the sanity dataset and turns off some genome collation options by default. This is larger than the dummy sequences currently used in training, and will make sure we can fill a few batches in CP testing. We may want to revert this eventually once we're done bringing up llama3, since it does trigger the tokenizer download during testing. * Removes `lazy tokenization` from llama3; this won't work. See https://nvidia.slack.com/archives/C074Z808N05/p1767818883160949 * Starts adding CP files for llama3. Closes BIO-11 --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 6321b0c commit 07d6f33

21 files changed

Lines changed: 763 additions & 253 deletions

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import logging
2222
from dataclasses import dataclass
23-
from typing import Any
23+
from typing import Any, TypedDict
2424

2525
import datasets
2626
import torch
@@ -334,7 +334,7 @@ class ContextParallelDataLoaderWrapper:
334334

335335
def __init__(
336336
self,
337-
dataloader: torch.utils.data.DataLoader,
337+
dataloader: torch.utils.data.DataLoader | None,
338338
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
339339
):
340340
"""A dataloader wrapper that distributes the data across the context parallelism group.
@@ -348,15 +348,28 @@ def __init__(
348348
cp_mesh: The context parallel mesh.
349349
cp_rank: The rank of the current context parallel process.
350350
"""
351-
self.dataloader = dataloader
351+
if cp_mesh.get_local_rank() == 0:
352+
assert dataloader is not None, "dataloader must be provided on rank 0"
353+
self.dataloader = dataloader
354+
355+
else:
356+
assert dataloader is None, "Dataloader on non-rank 0 will not be used"
357+
352358
self.cp_rank = cp_mesh.get_local_rank()
353359
self.cp_group = cp_mesh.get_group()
354360
self.num_cp_ranks = cp_mesh.size()
355361
self._iterator = None
356362

363+
logger.debug(
364+
"Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s",
365+
torch.distributed.get_rank() if torch.distributed.is_initialized() else "<not initialized>",
366+
self.cp_rank,
367+
)
368+
357369
def __iter__(self):
358370
"""Make the dataloader iterable."""
359-
self._iterator = iter(self.dataloader) # < --- collator output.
371+
if self.cp_rank == 0:
372+
self._iterator = iter(self.dataloader) # < --- collator output.
360373
return self
361374

362375
def __next__(self):
@@ -385,24 +398,19 @@ def _send_data_to_cp_ranks(self):
385398
batch: The batch for the current CP rank.
386399
387400
"""
388-
if self.cp_rank == 0:
389-
# Get data once, then make copies for each rank.
390-
if self._iterator is None:
391-
self._iterator = iter(self.dataloader)
392-
combined_batch = next(self._iterator)
401+
try:
402+
combined_batch = next(self._iterator) if self.cp_rank == 0 else None
403+
except StopIteration as ex:
404+
# If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so
405+
# that the dataloader can be restarted.
406+
combined_batch = [ex] * self.num_cp_ranks
393407

394-
else:
395-
combined_batch = None
396-
397-
scatter_object_output_list = [None]
398-
# Note: This does not provide an async_op handle. Thus its blocking.
399-
torch.distributed.scatter_object_list(
400-
scatter_object_output_list=scatter_object_output_list,
401-
scatter_object_input_list=combined_batch,
402-
group=self.cp_group,
403-
group_src=0,
404-
)
405-
return scatter_object_output_list[0]
408+
batch_on_this_rank = _scatter_batch_to_cp_ranks(combined_batch, self.cp_group)
409+
410+
if isinstance(batch_on_this_rank, StopIteration):
411+
raise batch_on_this_rank
412+
413+
return batch_on_this_rank
406414

407415

408416
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -670,3 +678,31 @@ def _get_group_local_rank(group: torch.distributed.ProcessGroup | None = None) -
670678
return torch.distributed.get_rank()
671679
global_rank = torch.distributed.get_rank()
672680
return torch.distributed.get_group_rank(group, global_rank)
681+
682+
683+
class BatchType(TypedDict):
684+
"""The fields in the batch dictionary for context parallel."""
685+
686+
input_ids: torch.Tensor
687+
labels: torch.Tensor
688+
cu_seq_lens_q: torch.Tensor
689+
cu_seq_lens_k: torch.Tensor
690+
cu_seq_lens_q_padded: torch.Tensor
691+
cu_seq_lens_k_padded: torch.Tensor
692+
max_length_q: int
693+
max_length_k: int
694+
695+
696+
def _scatter_batch_to_cp_ranks(
697+
batch: list[BatchType] | list[StopIteration], cp_group: torch.distributed.ProcessGroup | None = None
698+
) -> BatchType | StopIteration:
699+
"""Scatter a batch to all the CP ranks."""
700+
scatter_object_output_list = [None]
701+
# Note: This does not provide an async_op handle. Thus it's blocking.
702+
torch.distributed.scatter_object_list(
703+
scatter_object_output_list=scatter_object_output_list,
704+
scatter_object_input_list=batch,
705+
group=cp_group,
706+
group_src=0,
707+
)
708+
return scatter_object_output_list[0]

bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py

Lines changed: 2 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import copy
1717
import unittest
18-
from itertools import pairwise
1918
from typing import Dict, Iterator, List
2019
from unittest import mock
2120

@@ -229,57 +228,6 @@ def size(self) -> int:
229228
return self._size
230229

231230

232-
def _fake_get_batch(
233-
cu_seqlens_padded,
234-
input_ids_padded,
235-
labels_padded,
236-
cp_size,
237-
qvk_format,
238-
cp_rank,
239-
):
240-
total_slices = 2 * cp_size
241-
seq_tokens = input_ids_padded.view(-1)
242-
seq_labels = labels_padded.view(-1)
243-
shard_tokens: List[torch.Tensor] = []
244-
shard_labels: List[torch.Tensor] = []
245-
246-
for start, end in pairwise(cu_seqlens_padded):
247-
start_idx = int(start)
248-
end_idx = int(end)
249-
slice_size = (end_idx - start_idx) // total_slices
250-
251-
first_start = start_idx + (cp_rank * slice_size)
252-
first_end = first_start + slice_size
253-
second_start = start_idx + ((total_slices - cp_rank - 1) * slice_size)
254-
second_end = second_start + slice_size
255-
256-
shard_tokens.append(torch.cat([seq_tokens[first_start:first_end], seq_tokens[second_start:second_end]]))
257-
shard_labels.append(torch.cat([seq_labels[first_start:first_end], seq_labels[second_start:second_end]]))
258-
259-
return (
260-
torch.cat(shard_tokens).unsqueeze(0),
261-
torch.cat(shard_labels).unsqueeze(0),
262-
)
263-
264-
265-
def _make_cp_shards(base_batch: Dict[str, torch.Tensor], cp_size: int):
266-
combined_batch = []
267-
for cp_rank in range(cp_size):
268-
input_ids_sharded, labels_sharded = _fake_get_batch(
269-
cu_seqlens_padded=base_batch["cu_seq_lens_q_padded"],
270-
input_ids_padded=base_batch["input_ids"],
271-
labels_padded=base_batch["labels"],
272-
cp_size=cp_size,
273-
qvk_format="thd",
274-
cp_rank=cp_rank,
275-
)
276-
batch_shard = dict(base_batch)
277-
batch_shard["input_ids"] = input_ids_sharded
278-
batch_shard["labels"] = labels_sharded
279-
combined_batch.append(batch_shard)
280-
return combined_batch
281-
282-
283231
def test_pad_thd_sequences_for_cp():
284232
pid = 1 # The pad token id.
285233
label_pad = -100 # The label pad id.
@@ -410,7 +358,7 @@ def run_roundtrip(base_batch):
410358
cp_mesh_rank0 = _DummyDeviceMesh(size=cp_size, rank=0)
411359
cp_mesh_rank1 = _DummyDeviceMesh(size=cp_size, rank=1)
412360
loader_rank0 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank0)
413-
loader_rank1 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank1)
361+
loader_rank1 = ContextParallelDataLoaderWrapper(None, cp_mesh_rank1)
414362

415363
scatter_payload: Dict[str, List[Dict[str, torch.Tensor]]] = {}
416364
current_rank = {"value": None}
@@ -499,7 +447,7 @@ def run_roundtrip(base_batch):
499447
cp_mesh_rank0 = _DummyDeviceMesh(size=cp_size, rank=0)
500448
cp_mesh_rank1 = _DummyDeviceMesh(size=cp_size, rank=1)
501449
loader_rank0 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank0)
502-
loader_rank1 = ContextParallelDataLoaderWrapper(_DummyLoader(combined_batch), cp_mesh_rank1)
450+
loader_rank1 = ContextParallelDataLoaderWrapper(None, cp_mesh_rank1)
503451

504452
scatter_payload: Dict[str, List[Dict[str, torch.Tensor]]] = {}
505453
current_rank = {"value": None}

bionemo-recipes/recipes/esm2_native_te/collator.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import logging
2222
from dataclasses import dataclass
23-
from typing import Any
23+
from typing import Any, TypedDict
2424

2525
import datasets
2626
import torch
@@ -334,7 +334,7 @@ class ContextParallelDataLoaderWrapper:
334334

335335
def __init__(
336336
self,
337-
dataloader: torch.utils.data.DataLoader,
337+
dataloader: torch.utils.data.DataLoader | None,
338338
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
339339
):
340340
"""A dataloader wrapper that distributes the data across the context parallelism group.
@@ -348,15 +348,28 @@ def __init__(
348348
cp_mesh: The context parallel mesh.
349349
cp_rank: The rank of the current context parallel process.
350350
"""
351-
self.dataloader = dataloader
351+
if cp_mesh.get_local_rank() == 0:
352+
assert dataloader is not None, "dataloader must be provided on rank 0"
353+
self.dataloader = dataloader
354+
355+
else:
356+
assert dataloader is None, "Dataloader on non-rank 0 will not be used"
357+
352358
self.cp_rank = cp_mesh.get_local_rank()
353359
self.cp_group = cp_mesh.get_group()
354360
self.num_cp_ranks = cp_mesh.size()
355361
self._iterator = None
356362

363+
logger.debug(
364+
"Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s",
365+
torch.distributed.get_rank() if torch.distributed.is_initialized() else "<not initialized>",
366+
self.cp_rank,
367+
)
368+
357369
def __iter__(self):
358370
"""Make the dataloader iterable."""
359-
self._iterator = iter(self.dataloader) # < --- collator output.
371+
if self.cp_rank == 0:
372+
self._iterator = iter(self.dataloader) # < --- collator output.
360373
return self
361374

362375
def __next__(self):
@@ -385,24 +398,19 @@ def _send_data_to_cp_ranks(self):
385398
batch: The batch for the current CP rank.
386399
387400
"""
388-
if self.cp_rank == 0:
389-
# Get data once, then make copies for each rank.
390-
if self._iterator is None:
391-
self._iterator = iter(self.dataloader)
392-
combined_batch = next(self._iterator)
401+
try:
402+
combined_batch = next(self._iterator) if self.cp_rank == 0 else None
403+
except StopIteration as ex:
404+
# If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so
405+
# that the dataloader can be restarted.
406+
combined_batch = [ex] * self.num_cp_ranks
393407

394-
else:
395-
combined_batch = None
396-
397-
scatter_object_output_list = [None]
398-
# Note: This does not provide an async_op handle. Thus its blocking.
399-
torch.distributed.scatter_object_list(
400-
scatter_object_output_list=scatter_object_output_list,
401-
scatter_object_input_list=combined_batch,
402-
group=self.cp_group,
403-
group_src=0,
404-
)
405-
return scatter_object_output_list[0]
408+
batch_on_this_rank = _scatter_batch_to_cp_ranks(combined_batch, self.cp_group)
409+
410+
if isinstance(batch_on_this_rank, StopIteration):
411+
raise batch_on_this_rank
412+
413+
return batch_on_this_rank
406414

407415

408416
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -670,3 +678,31 @@ def _get_group_local_rank(group: torch.distributed.ProcessGroup | None = None) -
670678
return torch.distributed.get_rank()
671679
global_rank = torch.distributed.get_rank()
672680
return torch.distributed.get_group_rank(group, global_rank)
681+
682+
683+
class BatchType(TypedDict):
684+
"""The fields in the batch dictionary for context parallel."""
685+
686+
input_ids: torch.Tensor
687+
labels: torch.Tensor
688+
cu_seq_lens_q: torch.Tensor
689+
cu_seq_lens_k: torch.Tensor
690+
cu_seq_lens_q_padded: torch.Tensor
691+
cu_seq_lens_k_padded: torch.Tensor
692+
max_length_q: int
693+
max_length_k: int
694+
695+
696+
def _scatter_batch_to_cp_ranks(
697+
batch: list[BatchType] | list[StopIteration], cp_group: torch.distributed.ProcessGroup | None = None
698+
) -> BatchType | StopIteration:
699+
"""Scatter a batch to all the CP ranks."""
700+
scatter_object_output_list = [None]
701+
# Note: This does not provide an async_op handle. Thus it's blocking.
702+
torch.distributed.scatter_object_list(
703+
scatter_object_output_list=scatter_object_output_list,
704+
scatter_object_input_list=batch,
705+
group=cp_group,
706+
group_src=0,
707+
)
708+
return scatter_object_output_list[0]

bionemo-recipes/recipes/esm2_native_te/dataset.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,16 @@ def create_cp_dataloader(
252252
logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
253253
kwargs["pad_sequences_to_be_divisible_by"] = cp_mesh.size() * 2
254254

255-
train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)
255+
if cp_mesh.get_local_rank() == 0:
256+
train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)
256257

257-
train_dataloader.collate_fn = DataCollatorForContextParallel(
258-
collator=train_dataloader.collate_fn,
259-
cp_world_size=cp_mesh.size(),
260-
)
258+
train_dataloader.collate_fn = DataCollatorForContextParallel(
259+
collator=train_dataloader.collate_fn,
260+
cp_world_size=cp_mesh.size(),
261+
)
262+
263+
else:
264+
train_dataloader = None
265+
tokenized_dataset = None
261266

262267
return ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh), tokenized_dataset

bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ def main(args: DictConfig) -> float | None:
197197

198198
# Dataloader exhausted, incrementing epoch
199199
epoch += 1
200-
dataset_or_sampler.set_epoch(epoch)
200+
if dataset_or_sampler is not None: # The dataset only exists on rank 0
201+
dataset_or_sampler.set_epoch(epoch)
201202

202203
# Save final model to a .safetensors file.
203204
if args.checkpoint.save_final_model and ckpt_path:

bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ def main(args: DictConfig) -> float | None:
214214

215215
# Dataloader exhausted, incrementing epoch
216216
epoch += 1
217-
dataset_or_sampler.set_epoch(epoch)
217+
if dataset_or_sampler is not None: # The dataset only exists on rank 0
218+
dataset_or_sampler.set_epoch(epoch)
218219

219220
# Save final model to a .safetensors file.
220221
if args.checkpoint.save_final_model and ckpt_path:

bionemo-recipes/recipes/llama3_native_te/checkpoint.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,13 +328,15 @@ def save_checkpoint_fsdp2(
328328
)
329329
logger.info(f"Saved FSDP2 dataloader to {ckpt_path}")
330330

331-
# If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time.
332-
if async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None:
333-
_ckpt_futures["fsdp2"].result()
334-
335331
state_dict = {"app": AppState(model=model, optimizer=optimizer, scheduler=scheduler, step=step, epoch=epoch)}
336-
ckpt_save_func = dcp_async_save if async_save else dcp_save
337-
_ckpt_futures["fsdp2"] = ckpt_save_func(state_dict, checkpoint_id=checkpoint_path, process_group=process_group)
332+
if async_save:
333+
# If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time.
334+
if "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None:
335+
_ckpt_futures["fsdp2"].result()
336+
337+
_ckpt_futures["fsdp2"] = dcp_async_save(state_dict, checkpoint_id=checkpoint_path, process_group=process_group)
338+
else:
339+
dcp_save(state_dict, checkpoint_id=checkpoint_path, process_group=process_group)
338340

339341
if max_checkpoints is not None and dist_config.is_main_process():
340342
prune_checkpoints(ckpt_path, max_checkpoints)

0 commit comments

Comments
 (0)