Skip to content

Commit 7c4ecc7

Browse files
authored
Demonstrate dcp checkpoint save resume with context parallel (#1421)
Makes it easier to launch BSHD context-parallel runs in the llama3 recipe for local testing, and adds checkpoint save/resume checks to the llama3 recipe. (BIO-8)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added checkpoint save/restore capabilities and worker configuration querying for context-parallel data loading
  * Introduced new configuration parameter for sequence padding control
* **Improvements**
  * Enhanced distributed training support with improved checkpoint integration for context parallelism
* **Tests**
  * Added comprehensive integration tests for distributed checkpointing with context parallelism
  * Added multi-GPU training tests with different attention format configurations

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent b28d4e2 commit 7c4ecc7

12 files changed

Lines changed: 476 additions & 84 deletions

File tree

bionemo-recipes/models/esm2/src/esm/collator.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):
415415

416416
return batch_on_this_rank
417417

418+
def state_dict(self):
    """Return the checkpointable state of the wrapped dataloader.

    Returns:
        An empty dict when ``self.cp_rank`` is not 0. Otherwise a dict with a single ``"dataloader"`` key that
        holds the wrapped dataloader's own ``state_dict()``, or an empty dict (after a warning is logged) when
        the dataloader has no ``state_dict`` attribute.
    """
    # Ranks other than CP rank 0 contribute nothing to the saved state.
    if self.cp_rank != 0:
        return {}

    if not hasattr(self.dataloader, "state_dict"):
        logger.warning(
            "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
            "returning empty dict"
        )
        return {"dataloader": {}}

    return {"dataloader": self.dataloader.state_dict()}
430+
431+
def load_state_dict(self, state_dict):
    """Restore the wrapped dataloader from a previously saved state.

    Args:
        state_dict: Mapping whose ``"dataloader"`` entry is forwarded to the wrapped dataloader's
            ``load_state_dict``. It is not read when ``self.cp_rank`` is not 0 or when the dataloader has no
            ``load_state_dict`` attribute.
    """
    # Ranks other than CP rank 0 have nothing to restore.
    if self.cp_rank != 0:
        return

    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning(
            "Attempting to load the state dict of the dataloader, but the dataloader does not support "
            "load_state_dict, returning without loading the state dict."
        )
        return

    self.dataloader.load_state_dict(state_dict["dataloader"])
443+
444+
@property
def num_workers(self):
    """Worker count reported by the wrapped dataloader, or 0 when ``self.cp_rank`` is not 0."""
    return self.dataloader.num_workers if self.cp_rank == 0 else 0
451+
418452

419453
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
420454
"""Split a sample dictionary at a specified number of tokens.

bionemo-recipes/models/llama3/collator.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):
415415

416416
return batch_on_this_rank
417417

418+
def state_dict(self):
    """Return the checkpointable state of the wrapped dataloader.

    Returns:
        An empty dict when ``self.cp_rank`` is not 0. Otherwise a dict with a single ``"dataloader"`` key that
        holds the wrapped dataloader's own ``state_dict()``, or an empty dict (after a warning is logged) when
        the dataloader has no ``state_dict`` attribute.
    """
    # Ranks other than CP rank 0 contribute nothing to the saved state.
    if self.cp_rank != 0:
        return {}

    if not hasattr(self.dataloader, "state_dict"):
        logger.warning(
            "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
            "returning empty dict"
        )
        return {"dataloader": {}}

    return {"dataloader": self.dataloader.state_dict()}
430+
431+
def load_state_dict(self, state_dict):
    """Restore the wrapped dataloader from a previously saved state.

    Args:
        state_dict: Mapping whose ``"dataloader"`` entry is forwarded to the wrapped dataloader's
            ``load_state_dict``. It is not read when ``self.cp_rank`` is not 0 or when the dataloader has no
            ``load_state_dict`` attribute.
    """
    # Ranks other than CP rank 0 have nothing to restore.
    if self.cp_rank != 0:
        return

    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning(
            "Attempting to load the state dict of the dataloader, but the dataloader does not support "
            "load_state_dict, returning without loading the state dict."
        )
        return

    self.dataloader.load_state_dict(state_dict["dataloader"])
443+
444+
@property
def num_workers(self):
    """Worker count reported by the wrapped dataloader, or 0 when ``self.cp_rank`` is not 0."""
    return self.dataloader.num_workers if self.cp_rank == 0 else 0
451+
418452

419453
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
420454
"""Split a sample dictionary at a specified number of tokens.

bionemo-recipes/recipes/esm2_native_te/collator.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):
415415

416416
return batch_on_this_rank
417417

418+
def state_dict(self):
    """Return the checkpointable state of the wrapped dataloader.

    Returns:
        An empty dict when ``self.cp_rank`` is not 0. Otherwise a dict with a single ``"dataloader"`` key that
        holds the wrapped dataloader's own ``state_dict()``, or an empty dict (after a warning is logged) when
        the dataloader has no ``state_dict`` attribute.
    """
    # Ranks other than CP rank 0 contribute nothing to the saved state.
    if self.cp_rank != 0:
        return {}

    if not hasattr(self.dataloader, "state_dict"):
        logger.warning(
            "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
            "returning empty dict"
        )
        return {"dataloader": {}}

    return {"dataloader": self.dataloader.state_dict()}
430+
431+
def load_state_dict(self, state_dict):
    """Restore the wrapped dataloader from a previously saved state.

    Args:
        state_dict: Mapping whose ``"dataloader"`` entry is forwarded to the wrapped dataloader's
            ``load_state_dict``. It is not read when ``self.cp_rank`` is not 0 or when the dataloader has no
            ``load_state_dict`` attribute.
    """
    # Ranks other than CP rank 0 have nothing to restore.
    if self.cp_rank != 0:
        return

    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning(
            "Attempting to load the state dict of the dataloader, but the dataloader does not support "
            "load_state_dict, returning without loading the state dict."
        )
        return

    self.dataloader.load_state_dict(state_dict["dataloader"])
443+
444+
@property
def num_workers(self):
    """Worker count reported by the wrapped dataloader, or 0 when ``self.cp_rank`` is not 0."""
    return self.dataloader.num_workers if self.cp_rank == 0 else 0
451+
418452

419453
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
420454
"""Split a sample dictionary at a specified number of tokens.

bionemo-recipes/recipes/llama3_native_te/collator.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,40 @@ def _send_data_to_cp_ranks(self):
415415

416416
return batch_on_this_rank
417417

418+
def state_dict(self):
    """Return the checkpointable state of the wrapped dataloader.

    Returns:
        An empty dict when ``self.cp_rank`` is not 0. Otherwise a dict with a single ``"dataloader"`` key that
        holds the wrapped dataloader's own ``state_dict()``, or an empty dict (after a warning is logged) when
        the dataloader has no ``state_dict`` attribute.
    """
    # Ranks other than CP rank 0 contribute nothing to the saved state.
    if self.cp_rank != 0:
        return {}

    if not hasattr(self.dataloader, "state_dict"):
        logger.warning(
            "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, "
            "returning empty dict"
        )
        return {"dataloader": {}}

    return {"dataloader": self.dataloader.state_dict()}
430+
431+
def load_state_dict(self, state_dict):
    """Restore the wrapped dataloader from a previously saved state.

    Args:
        state_dict: Mapping whose ``"dataloader"`` entry is forwarded to the wrapped dataloader's
            ``load_state_dict``. It is not read when ``self.cp_rank`` is not 0 or when the dataloader has no
            ``load_state_dict`` attribute.
    """
    # Ranks other than CP rank 0 have nothing to restore.
    if self.cp_rank != 0:
        return

    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning(
            "Attempting to load the state dict of the dataloader, but the dataloader does not support "
            "load_state_dict, returning without loading the state dict."
        )
        return

    self.dataloader.load_state_dict(state_dict["dataloader"])
443+
444+
@property
def num_workers(self):
    """Worker count reported by the wrapped dataloader, or 0 when ``self.cp_rank`` is not 0."""
    return self.dataloader.num_workers if self.cp_rank == 0 else 0
451+
418452

419453
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
420454
"""Split a sample dictionary at a specified number of tokens.

bionemo-recipes/recipes/llama3_native_te/dataset.py

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,12 @@
1717

1818
import datasets
1919
import datasets.distributed
20-
import torch
2120
from torch.utils.data import DataLoader, DistributedSampler
2221
from torchdata.stateful_dataloader import StatefulDataLoader
2322
from transformers import AutoTokenizer
2423
from transformers.data.data_collator import DataCollatorForLanguageModeling
2524

2625
from collator import (
27-
ContextParallelDataLoaderWrapper,
28-
DataCollatorForContextParallel,
2926
DataCollatorWithFlattening,
3027
TokenPackingDataset,
3128
)
@@ -102,6 +99,11 @@ def tokenize_with_windowing(examples):
10299
remove_columns=[text_column],
103100
)
104101

102+
# Even in THD mode, we use a base MLM collator that requires a padding token to be set.
103+
if tokenizer.pad_token is None:
104+
logger.warning(f"Tokenizer does not have a padding token. Setting it to the EOS token: {tokenizer.eos_token}")
105+
tokenizer.pad_token = tokenizer.eos_token
106+
105107
return tokenized_dataset, tokenizer
106108

107109

@@ -120,7 +122,7 @@ def create_bshd_dataloader(
120122
text_column: str = "text",
121123
uppercase_labels: bool = False,
122124
mask_degenerate_bases: bool = False,
123-
pad_to_multiple_of: int | None = None,
125+
pad_sequences_to_be_divisible_by: int | None = None,
124126
):
125127
"""Create a BSHD dataloader for llama3 pre-training.
126128
@@ -139,7 +141,8 @@ def create_bshd_dataloader(
139141
text_column: Name of the column containing text sequences (default: "text").
140142
uppercase_labels: Whether to uppercase labels (genomic masking). Default: False.
141143
mask_degenerate_bases: Whether to mask non-ACGT bases (genomic masking). Default: False.
142-
pad_to_multiple_of: The number to pad sequences to be divisible by, required for FP8 training. Default: 16.
144+
pad_sequences_to_be_divisible_by: The number to pad sequences to be divisible by, required for FP8 training.
145+
Default: None.
143146
144147
Returns:
145148
A tuple of (dataloader, dataset_or_sampler).
@@ -169,7 +172,7 @@ def create_bshd_dataloader(
169172
base_collator = DataCollatorForLanguageModeling(
170173
tokenizer=tokenizer,
171174
mlm=False, # Causal language modeling
172-
pad_to_multiple_of=pad_to_multiple_of,
175+
pad_to_multiple_of=pad_sequences_to_be_divisible_by,
173176
)
174177

175178
# Wrap with genomic collator if masking options are enabled
@@ -300,40 +303,3 @@ def create_thd_dataloader(
300303
)
301304

302305
return train_dataloader, tokenized_dataset
303-
304-
305-
def create_cp_dataloader(
    *args,
    cp_mesh: torch.distributed.device_mesh.DeviceMesh,
    **kwargs,
):
    """Build a context-parallel dataloader on top of `create_thd_dataloader`.

    The THD dataloader is only created on the rank whose local rank within `cp_mesh` is 0; its collate function
    is wrapped in a `DataCollatorForContextParallel`. Every rank then receives a
    `ContextParallelDataLoaderWrapper` around that dataloader (or around `None` on the other ranks).

    Args:
        *args: Positional arguments forwarded to `create_thd_dataloader`.
        cp_mesh: The context parallel mesh.
        **kwargs: Keyword arguments forwarded to `create_thd_dataloader`. When
            `pad_sequences_to_be_divisible_by` is missing or None it is set to `cp_mesh.size() * 2`.

    Returns:
        A tuple of (dataloader, dataset_or_sampler); the second element is None on ranks that did not build the
        underlying dataloader.
    """
    # Always hand create_thd_dataloader an explicit padding multiple.
    if kwargs.get("pad_sequences_to_be_divisible_by") is None:
        logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
        kwargs["pad_sequences_to_be_divisible_by"] = cp_mesh.size() * 2

    # Defaults for ranks that do not build the underlying dataloader.
    train_dataloader = None
    tokenized_dataset = None

    if cp_mesh.get_local_rank() == 0:
        train_dataloader, tokenized_dataset = create_thd_dataloader(*args, **kwargs)
        inner_collator = train_dataloader.collate_fn
        train_dataloader.collate_fn = DataCollatorForContextParallel(
            collator=inner_collator,
            cp_world_size=cp_mesh.size(),
        )

    wrapped_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh)
    return wrapped_dataloader, tokenized_dataset

bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ defaults:
44

55
cp_size: 1
66

7+
use_sequence_packing: false
8+
79
config_kwargs:
8-
attn_input_format: "thd"
9-
self_attn_mask_type: "padding_causal"
10+
attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
11+
self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.

bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dataset:
2222
stride: 200 # Overlap for windowing
2323
buffer_size: 500_000 # Shuffle buffer size
2424
use_stateful_dataloader: false # Until https://github.com/pytorch/pytorch/pull/163102 is resolved with torchdata.
25+
pad_sequences_to_be_divisible_by: null
2526
load_dataset_kwargs:
2627
path: ???
2728
split: "train"

bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from hydra import compose, initialize_config_dir
2525
from torch.distributed.device_mesh import init_device_mesh
2626

27-
from dataset import create_bshd_dataloader, create_cp_dataloader, create_thd_dataloader, create_tokenized_dataset
27+
from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel
28+
from dataset import create_bshd_dataloader, create_thd_dataloader, create_tokenized_dataset
2829
from distributed_config import DistributedConfig
2930

3031

@@ -703,15 +704,28 @@ def test_cp_dataloader(tokenizer_path):
703704
torch.cuda.set_device(dist_config.local_rank)
704705
device_mesh = init_device_mesh("cuda", mesh_shape=(1, 1), mesh_dim_names=("dp", "cp"))
705706

706-
dataloader, _ = create_cp_dataloader(
707-
distributed_config=dist_config,
708-
cp_mesh=device_mesh["cp"],
709-
tokenizer_name_or_path=tokenizer_path,
710-
load_dataset_kwargs=load_dataset_kwargs,
711-
text_column="text",
712-
micro_batch_size=1,
713-
max_seq_length=1024,
714-
)
707+
cp_mesh = device_mesh["cp"]
708+
709+
# Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py
710+
if cp_mesh.get_local_rank() == 0:
711+
train_dataloader, _ = create_thd_dataloader(
712+
distributed_config=dist_config,
713+
tokenizer_name_or_path=tokenizer_path,
714+
load_dataset_kwargs=load_dataset_kwargs,
715+
text_column="text",
716+
micro_batch_size=1,
717+
max_seq_length=1024,
718+
pad_sequences_to_be_divisible_by=cp_mesh.size() * 2,
719+
)
720+
721+
train_dataloader.collate_fn = DataCollatorForContextParallel(
722+
collator=train_dataloader.collate_fn,
723+
cp_world_size=cp_mesh.size(),
724+
)
725+
else:
726+
train_dataloader = None
727+
728+
dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh)
715729

716730
batches = list(dataloader)
717731
assert len(batches) > 1
@@ -775,30 +789,39 @@ def test_cp_dataloader_multi_gpu(recipe_path, dataset_path):
775789
parser.add_argument("--dataset_path", type=str, default="dlcm_sanity_dataset.parquet")
776790
args = parser.parse_args()
777791

778-
from torch.distributed.device_mesh import init_device_mesh
779-
780-
from dataset import create_cp_dataloader
781-
782792
dist_config = DistributedConfig()
783793
device = torch.device(f"cuda:{dist_config.local_rank}")
784794
torch.distributed.init_process_group(backend="nccl", device_id=device)
785795
torch.cuda.set_device(dist_config.local_rank)
786796
device_mesh = init_device_mesh("cuda", mesh_shape=(1, 2), mesh_dim_names=("dp", "cp"))
787797

788-
dataloader, _ = create_cp_dataloader(
789-
distributed_config=dist_config,
790-
cp_mesh=device_mesh["cp"],
791-
tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8",
792-
micro_batch_size=1,
793-
text_column="text" if args.dataset_path == "dlcm_sanity_dataset.parquet" else "sequence",
794-
load_dataset_kwargs={
795-
"path": "parquet",
796-
"split": "train",
797-
"data_files": args.dataset_path,
798-
"streaming": True,
799-
},
800-
num_workers=1,
801-
)
798+
cp_mesh = device_mesh["cp"]
799+
800+
# Create the context-parallel dataloader directly following the pattern in train_fsdp2_cp.py
801+
if cp_mesh.get_local_rank() == 0:
802+
train_dataloader, _ = create_thd_dataloader(
803+
distributed_config=dist_config,
804+
tokenizer_name_or_path="nvidia/Llama-3.1-8B-Instruct-FP8",
805+
micro_batch_size=1,
806+
text_column="text" if args.dataset_path == "dlcm_sanity_dataset.parquet" else "sequence",
807+
load_dataset_kwargs={
808+
"path": "parquet",
809+
"split": "train",
810+
"data_files": args.dataset_path,
811+
"streaming": True,
812+
},
813+
num_workers=1,
814+
pad_sequences_to_be_divisible_by=cp_mesh.size() * 2,
815+
)
816+
817+
train_dataloader.collate_fn = DataCollatorForContextParallel(
818+
collator=train_dataloader.collate_fn,
819+
cp_world_size=cp_mesh.size(),
820+
)
821+
else:
822+
train_dataloader = None
823+
824+
dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_mesh)
802825

803826
batches = list(itertools.islice(dataloader, 10))
804827

0 commit comments

Comments
 (0)