diff --git a/docs/components/components.md b/docs/components/components.md index 81af02d34..904023241 100644 --- a/docs/components/components.md +++ b/docs/components/components.md @@ -2,130 +2,131 @@ ## Models -|Component type | Component Version | Implementation | Configuration | Component Interface | Description | -|---------------|--------------------|----------------|---------------|---------------------|-------------| -| model | gpt2 | [GPT2LLM](../../src/modalities/models/gpt2/gpt2_model.py)| [GPT2LLMConfig](../../src/modalities/models/gpt2/gpt2_model.py) | [NNModel](../../src/modalities/models/model.py) | GPT2 model for language modeling | -| model | huggingface_pretrained_model | [HuggingFacePretrainedModel](../../src/modalities/models/huggingface/huggingface_model.py)| [HuggingFacePretrainedModelConfig](../../src/modalities/models/huggingface/huggingface_model.py) | [NNModel](../../src/modalities/models/model.py) | HuggingFace pretrained model for language modeling | -| model | checkpointed | [ModelFactory.get_checkpointed_model](../../src/modalities/models/model_factory.py)| [CheckpointedModelConfig](../../src/modalities/config/config.py) | [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) | Checkpointed Model instance | -| model | fsdp_wrapped | [ModelFactory.get_fsdp_wrapped_model](../../src/modalities/models/model_factory.py)| [FSDPWrappedModelConfig](../../src/modalities/config/config.py) | [NNModel](../../src/modalities/models/model.py) | Model that has been sharded via FSDP | -| model | model_initialized | [ModelFactory.get_weight_initialized_model](../../src/modalities/models/model_factory.py)| [WeightInitializedModelConfig](../../src/modalities/config/config.py) | [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) | Model with initialized weights | -| model | coca | [CoCa](../../src/modalities/models/coca/coca_model.py)| [CoCaConfig](../../src/modalities/models/coca/coca_model.py) | [NNModel](../../src/modalities/models/model.py) |[CoCa Model (Contrastive Captioners) ](https://arxiv.org/abs/2205.01917) | +| Component type | Component Version | Implementation | Configuration | Component Interface | Description | +|----------------|------------------------------|--------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------|--------------------------------------------------------------------------| +| model | gpt2 | [GPT2LLM](../../src/modalities/models/gpt2/gpt2_model.py) | [GPT2LLMConfig](../../src/modalities/models/gpt2/gpt2_model.py) | [NNModel](../../src/modalities/models/model.py) | GPT2 model for language modeling | +| model | huggingface_pretrained_model | [HuggingFacePretrainedModel](../../src/modalities/models/huggingface/huggingface_model.py) | [HuggingFacePretrainedModelConfig](../../src/modalities/models/huggingface/huggingface_model.py) | [NNModel](../../src/modalities/models/model.py) | HuggingFace pretrained model for language modeling | +| model | checkpointed | [ModelFactory.get_checkpointed_model](../../src/modalities/models/model_factory.py) | [CheckpointedModelConfig](../../src/modalities/config/config.py) | [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) | Checkpointed Model instance | +| model | fsdp_wrapped | 
[ModelFactory.get_fsdp_wrapped_model](../../src/modalities/models/model_factory.py)        | [FSDPWrappedModelConfig](../../src/modalities/config/config.py)                                   | [NNModel](../../src/modalities/models/model.py)                             | Model that has been sharded via FSDP                                    |
+| model          | model_initialized            | [ModelFactory.get_weight_initialized_model](../../src/modalities/models/model_factory.py)  | [WeightInitializedModelConfig](../../src/modalities/config/config.py)                             | [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) | Model with initialized weights                                          |
+| model          | coca                         | [CoCa](../../src/modalities/models/coca/coca_model.py)                                     | [CoCaConfig](../../src/modalities/models/coca/coca_model.py)                                      | [NNModel](../../src/modalities/models/model.py)                             | [CoCa Model (Contrastive Captioners)](https://arxiv.org/abs/2205.01917) |

## Weight Initialization

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| model_initialization | composed | [ComposedInitializationRoutines.get_composed_model_initializer](../../src/modalities/nn/model_initialization/composed_initialization.py)| [ComposedModelInitializationConfig](../../src/modalities/nn/model_initialization/composed_initialization.py) | [ModelInitializationIF](../../src/modalities/nn/model_initialization/initialization_if.py) | Component for initializing model weights in place |
+| Component type       | Component Version | Implementation                                                                                                                            | Configuration                                                                                                 | Component Interface                                                                         | Description                                        |
+|----------------------|-------------------|-------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------|------------------------------------------------------|
+| model_initialization | composed          | [ComposedInitializationRoutines.get_composed_model_initializer](../../src/modalities/nn/model_initialization/composed_initialization.py) | [ComposedModelInitializationConfig](../../src/modalities/nn/model_initialization/composed_initialization.py) | [ModelInitializationIF](../../src/modalities/nn/model_initialization/initialization_if.py) | Component for initializing model weights in place |

The composed initializer supports seeded weight initialization for reproducibility within a fixed topology. When pipeline parallelism is active, Modalities offsets the initialization seed by pipeline stage rank to avoid identical stage-local weights. As a result, the same seed can produce different initialized weights for different pipeline-parallel topologies. For topology-independent reproducibility, create and reuse a distributed checkpoint directly after weight initialization.
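
As a rough sketch of the seed-offset idea (the helper names below are hypothetical, not the actual Modalities implementation):

```python
import torch

def stage_local_seed(base_seed: int, pp_rank: int) -> int:
    # Each pipeline stage derives its own seed, so no two stages draw their
    # stage-local parameters from the same RNG stream.
    return base_seed + pp_rank

def init_stage_weights(module: torch.nn.Module, base_seed: int, pp_rank: int) -> None:
    # Seed the global RNG with the stage-dependent seed before initializing
    # the parameters that live on this pipeline stage.
    torch.manual_seed(stage_local_seed(base_seed, pp_rank))
    for param in module.parameters():
        torch.nn.init.normal_(param, mean=0.0, std=0.02)
```

Because the offset depends on the stage rank, re-sharding the same model over a different number of pipeline stages changes which seed initializes which parameter, which is why only a checkpoint taken after initialization is topology-independent.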
## Losses

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| loss | clm_cross_entropy_loss | [CLMCrossEntropyLoss](../../src/modalities/loss_functions.py)| [CLMCrossEntropyLossConfig](../../src/modalities/config/config.py) | [Loss](../../src/modalities/loss_functions.py) | Cross-entropy loss function |
+| Component type | Component Version      | Implementation                                                | Configuration                                                       | Component Interface                            | Description                 |
+|----------------|------------------------|---------------------------------------------------------------|---------------------------------------------------------------------|------------------------------------------------|-----------------------------|
+| loss           | clm_cross_entropy_loss | [CLMCrossEntropyLoss](../../src/modalities/loss_functions.py) | [CLMCrossEntropyLossConfig](../../src/modalities/config/config.py) | [Loss](../../src/modalities/loss_functions.py) | Cross-entropy loss function |

## Optimizers

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| optimizer | adam | [OptimizerFactory.get_adam](../../src/modalities/optimizers/optimizer_factory.py)| [AdamOptimizerConfig](../../src/modalities/config/config.py) | [Optimizer](../../src/modalities/models/model.py) | ADAM optimizer |
-| optimizer | adam_w | [OptimizerFactory.get_adam_w](../../src/modalities/optimizers/optimizer_factory.py)| [AdamWOptimizerConfig](../../src/modalities/config/config.py) | [Optimizer](../../src/modalities/models/model.py) | ADAMW Optimizer |
-| optimizer | checkpointed | [OptimizerFactory.get_checkpointed_optimizer](../../src/modalities/optimizers/optimizer_factory.py)| [CheckpointedOptimizerConfig](../../src/modalities/config/config.py) | [Optimizer](../../src/modalities/models/model.py) | Optimizer instantiated from checkpoint |
+| Component type | Component Version | Implementation                                                                                       | Configuration                                                         | Component Interface                                | Description                            |
+|----------------|-------------------|-------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------|------------------------------------------------------|------------------------------------------|
+| optimizer      | adam              | [OptimizerFactory.get_adam](../../src/modalities/optimizers/optimizer_factory.py)                   | [AdamOptimizerConfig](../../src/modalities/config/config.py)         | [Optimizer](../../src/modalities/models/model.py) | Adam optimizer                         |
+| optimizer      | adam_w            | [OptimizerFactory.get_adam_w](../../src/modalities/optimizers/optimizer_factory.py)                 | [AdamWOptimizerConfig](../../src/modalities/config/config.py)        | [Optimizer](../../src/modalities/models/model.py) | AdamW optimizer                        |
+| optimizer      | checkpointed      | [OptimizerFactory.get_checkpointed_optimizer](../../src/modalities/optimizers/optimizer_factory.py) | [CheckpointedOptimizerConfig](../../src/modalities/config/config.py) | [Optimizer](../../src/modalities/models/model.py) | Optimizer instantiated from checkpoint |

## LR Scheduling

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| scheduler | dummy_lr | 
[DummyLRScheduler](../../src/modalities/optimizers/lr_schedulers.py)| [DummyLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Fake lr scheduler not adapting the lr rate |
-| scheduler | step_lr | [StepLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html)| [StepLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Decays the learning rate of each parameter group by gamma every step_size steps |
-| scheduler | constant_lr | [ConstantLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ConstantLR.html#torch.optim.lr_scheduler.ConstantLR)| [ConstantLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Multiplies the learning rate of each parameter group by a small constant factor until the number of steps reaches a pre-defined milestone |
-| scheduler | onecycle_lr | [OneCycleLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR)| [OneCycleLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Sets the learning rate of each parameter group according to the 1cycle learning rate policy. |
-| scheduler | cosine_annealing_lr | [CosineAnnealingLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR)| [CosineAnnealingLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Set the learning rate of each parameter group using a cosine annealing schedule |
-| scheduler | linear_warmup_cosine_annealing_lr | [LinearWarmupCosineAnnealingLRScheduler](../../src/modalities/optimizers/lr_schedulers.py) | [LinearWarmupCosineAnnealingLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Linearly warms up to the base learning rate, then decays with cosine annealing for the remaining training steps |
+| Component type | Component Version                 | Implementation                                                                                                                                             | Configuration                                                                          | Component Interface                                                                    | Description                                                                                                                                |
+|----------------|-----------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
+| scheduler      | dummy_lr                          | [DummyLRScheduler](../../src/modalities/optimizers/lr_schedulers.py)                                                                                      | [DummyLRSchedulerConfig](../../src/modalities/config/config.py)                       | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Dummy LR scheduler that leaves the learning rate unchanged                                                                                 |
+| scheduler      | step_lr                           | [StepLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html)                                                                  | [StepLRSchedulerConfig](../../src/modalities/config/config.py)                        | 
[LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Decays the learning rate of each parameter group by gamma every step_size steps                                                           |
+| scheduler      | constant_lr                       | [ConstantLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ConstantLR.html#torch.optim.lr_scheduler.ConstantLR)                      | [ConstantLRSchedulerConfig](../../src/modalities/config/config.py)                    | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Multiplies the learning rate of each parameter group by a small constant factor until the number of steps reaches a pre-defined milestone |
+| scheduler      | onecycle_lr                       | [OneCycleLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR)                      | [OneCycleLRSchedulerConfig](../../src/modalities/config/config.py)                    | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Sets the learning rate of each parameter group according to the 1cycle learning rate policy                                               |
+| scheduler      | cosine_annealing_lr               | [CosineAnnealingLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR) | [CosineAnnealingLRSchedulerConfig](../../src/modalities/config/config.py)             | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Sets the learning rate of each parameter group using a cosine annealing schedule                                                          |
+| scheduler      | linear_warmup_cosine_annealing_lr | [LinearWarmupCosineAnnealingLRScheduler](../../src/modalities/optimizers/lr_schedulers.py)                                                                | [LinearWarmupCosineAnnealingLRSchedulerConfig](../../src/modalities/config/config.py) | [LRScheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) | Linearly warms up to the base learning rate, then decays with cosine annealing for the remaining training steps                           |

## Tokenization

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| tokenizer | pretrained_hf_tokenizer | [PreTrainedHFTokenizer](../../src/modalities/tokenization/tokenizer_wrapper.py) | [PreTrainedHFTokenizerConfig](../../src/modalities/config/config.py) | [TokenizerWrapper](../../src/modalities/tokenization/tokenizer_wrapper.py) | Pretrained Huggingface tokenizer |
-| tokenizer | pretrained_sp_tokenizer | [PreTrainedSPTokenizer](../../src/modalities/tokenization/tokenizer_wrapper.py) | [PreTrainedSPTokenizerConfig](../../src/modalities/config/config.py) | [TokenizerWrapper](../../src/modalities/tokenization/tokenizer_wrapper.py) | Pretrained SentencePiece tokenizer |
+| Component type | Component Version       | Implementation                                                                   | Configuration                                                         | Component Interface                                                         | Description                        |
+|----------------|-------------------------|-------------------------------------------------------------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|---------------------------------------|
+| tokenizer      | pretrained_hf_tokenizer | [PreTrainedHFTokenizer](../../src/modalities/tokenization/tokenizer_wrapper.py) | [PreTrainedHFTokenizerConfig](../../src/modalities/config/config.py) | [TokenizerWrapper](../../src/modalities/tokenization/tokenizer_wrapper.py) | Pretrained HuggingFace tokenizer   |
+| tokenizer      | pretrained_sp_tokenizer | 
[PreTrainedSPTokenizer](../../src/modalities/tokenization/tokenizer_wrapper.py) | [PreTrainedSPTokenizerConfig](../../src/modalities/config/config.py) | [TokenizerWrapper](../../src/modalities/tokenization/tokenizer_wrapper.py) | Pretrained SentencePiece tokenizer | ## Datasets -|Component type | Component Version | Implementation | Configuration | Component Interface | Description | -|---------------|--------------------|----------------|---------------|---------------------|-------------| -| dataset | mem_map_dataset | [DatasetFactory.get_mem_map_dataset](../../src/modalities/dataloader/dataset_factory.py)| [MemMapDatasetConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | MemMap Dataset | -| dataset | packed_mem_map_dataset_continuous | [DatasetFactory.get_packed_mem_map_dataset_continuous](../../src/modalities/dataloader/dataset_factory.py)| [PackedMemMapDatasetContinuousConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Packed Memory Mapped Dataset Continuous | -| dataset | dummy_dataset | [DatasetFactory.get_dummy_dataset](../../src/modalities/dataloader/dataset_factory.py)| [DummyDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dummy dataset creating random samples of specified shape | -| dataset | combined | [DatasetFactory.get_combined_dataset](../../src/modalities/dataloader/dataset_factory.py)| [CombinedDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dataset implementation combining multiple datasets into one. | +| Component type | Component Version | Implementation | Configuration | Component Interface | Description | +|----------------|-----------------------------------|------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------|-------------------------------------------------------|--------------------------------------------------------------| +| dataset | mem_map_dataset | [DatasetFactory.get_mem_map_dataset](../../src/modalities/dataloader/dataset_factory.py) | [MemMapDatasetConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | MemMap Dataset | +| dataset | packed_mem_map_dataset_continuous | [DatasetFactory.get_packed_mem_map_dataset_continuous](../../src/modalities/dataloader/dataset_factory.py) | [PackedMemMapDatasetContinuousConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Packed Memory Mapped Dataset Continuous | +| dataset | dummy_dataset | [DatasetFactory.get_dummy_dataset](../../src/modalities/dataloader/dataset_factory.py) | [DummyDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dummy dataset creating random samples of specified shape | +| dataset | combined | [DatasetFactory.get_combined_dataset](../../src/modalities/dataloader/dataset_factory.py) | [CombinedDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dataset implementation combining multiple datasets into one. 
| ## Data sampling -|Component type | Component Version | Implementation | Configuration | Component Interface | Description | -|---------------|--------------------|----------------|---------------|---------------------|-------------| -| sampler | distributed_sampler | [DistributedSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler)| [DistributedSamplerConfig](../../src/modalities/config/config.py) | [Sampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler) | Sampler that restricts data loading to a subset of the dataset for distributed training | -| batch_sampler | default | [BatchSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.BatchSampler) | [BatchSamplerConfig](../../src/modalities/config/config.py) | [Sampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler) | Wraps another sampler to yield a mini-batch of indices. | +| Component type | Component Version | Implementation | Configuration | Component Interface | Description | +|----------------|---------------------|-----------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------|-------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------| +| sampler | distributed_sampler | [DistributedSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler) | [DistributedSamplerConfig](../../src/modalities/config/config.py) | [Sampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler) | Sampler that restricts data loading to a subset of the dataset for distributed training | +| batch_sampler | default | [BatchSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.BatchSampler) | [BatchSamplerConfig](../../src/modalities/config/config.py) | [Sampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler) | Wraps another sampler to yield a mini-batch of indices. 
|

## Data collation

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| collate_fn | gpt_2_llm_collator | [GPT2LLMCollateFn](../../src/modalities/models/gpt2/collator.py)| [GPT2LLMCollateFnConfig](../../src/modalities/config/config.py) | [CollateFnIF](../../src/modalities/models/gpt2/collator.py) | Data collator for the GPT2 model |
-| collate_fn | coca_collator | [CoCaCollatorFn](../../src/modalities/models/gpt2/collator.py)| [CoCaCollateFnConfig](../../src/modalities/config/config.py) | [CollateFnIF](../../src/modalities/models/gpt2/collator.py) | Data collator for the CoCa model |
+| Component type | Component Version  | Implementation                                                   | Configuration                                                    | Component Interface                                          | Description                      |
+|----------------|--------------------|-------------------------------------------------------------------|--------------------------------------------------------------------|------------------------------------------------------------------|--------------------------------------|
+| collate_fn     | gpt_2_llm_collator | [GPT2LLMCollateFn](../../src/modalities/models/gpt2/collator.py) | [GPT2LLMCollateFnConfig](../../src/modalities/config/config.py) | [CollateFnIF](../../src/modalities/models/gpt2/collator.py) | Data collator for the GPT2 model |
+| collate_fn     | coca_collator      | [CoCaCollatorFn](../../src/modalities/models/gpt2/collator.py)   | [CoCaCollateFnConfig](../../src/modalities/config/config.py)    | [CollateFnIF](../../src/modalities/models/gpt2/collator.py) | Data collator for the CoCa model |

## Data loaders

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| data_loader | default | [DataloaderFactory.get_dataloader](../../src/modalities/dataloader/dataloader_factory.py)| [LLMDataLoaderConfig](s../../src/modalities/config/config.py) | [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | LLM Data loader extending pytorch data loader functionality |
+| Component type | Component Version | Implementation                                                                             | Configuration                                                 | Component Interface                                                                  | Description                                                  |
+|----------------|-------------------|-----------------------------------------------------------------------------------------------|-------------------------------------------------------------------|------------------------------------------------------------------------------------------|------------------------------------------------------------------|
+| data_loader    | default           | [DataloaderFactory.get_dataloader](../../src/modalities/dataloader/dataloader_factory.py) | [LLMDataLoaderConfig](../../src/modalities/config/config.py) | [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | LLM data loader extending PyTorch data loader functionality |

## Checkpointing

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| checkpoint_saving | default | [CheckpointSaving](../../src/modalities/checkpointing/checkpoint_saving.py)| [CheckpointSavingConfig](s../../src/modalities/config/config.py) | -- | Component for saving checkpoints based on a savig and execution strategy. 
|
-| checkpoint_saving_strategy | save_every_k_steps_checkpointing_strategy | [SaveEveryKStepsCheckpointingStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py)| [SaveEveryKStepsCheckpointingStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving a checkpoint every k steps |
-| checkpoint_saving_strategy | save_k_most_recent_checkpoints_strategy | [SaveKMostRecentCheckpointsStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py)| [SaveKMostRecentCheckpointsStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving only the last k checkpoints and deleting the previous ones |
-| checkpoint_saving_execution | fsdp | [FSDPCheckpointSaving](../../src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py)| [FSDPCheckpointSavingConfig](../../src/modalities/config/config.py) | [CheckpointSavingExecutionABC](../../src/modalities/checkpointing/checkpoint_saving_execution.py) | FSDPCheckpointSaving class for saving checkpoints of FSDP models and optimizers. |
-| checkpoint_loading | fsdp | [FSDPCheckpointLoading](../../src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py)| [FSDPCheckpointLoadingConfig](../../src/modalities/config/config.py) | [CheckpointLoadingIF](../../src/modalities/checkpointing/checkpoint_loading.py) | Component for loading FSDP checkpoints|
-| checkpoint_loading | torch | [TorchCheckpointLoading](../../src/modalities/checkpointing/torch/torch_checkpoint_loading.py)| [TorchCheckpointLoadingConfig](../../src/modalities/config/config.py) | [CheckpointLoadingIF](../../src/modalities/checkpointing/checkpoint_loading.py) | Component for loading PyTorch checkpoints|
+| Component type              | Component Version                                            | Implementation                                                                                                            | Configuration                                                                                      | Component Interface                                                                                | Description                                                                                                                                            |
+|-----------------------------|--------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| checkpoint_saving           | default                                                      | [CheckpointSaving](../../src/modalities/checkpointing/checkpoint_saving.py)                                              | [CheckpointSavingConfig](../../src/modalities/config/config.py)                                   | --                                                                                                 | Component for saving checkpoints based on a saving and execution strategy. 
| +| checkpoint_saving_strategy | save_every_k_steps_checkpointing_strategy | [SaveEveryKStepsCheckpointingStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | [SaveEveryKStepsCheckpointingStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving a checkpoint every k steps | +| checkpoint_saving_strategy | save_k_most_recent_checkpoints_strategy | [SaveKMostRecentCheckpointsStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | [SaveKMostRecentCheckpointsStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving only the last k checkpoints and deleting the previous ones | +| checkpoint_saving_strategy | keep_every_k_steps_and_m_most_recent_checkpointing_strategy | [KeepEveryKStepsAndMMostRecentCheckpointingStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | [KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy that always saves the current checkpoint, keeps the m most recent checkpoints, and retains older checkpoints at multiples of k | +| checkpoint_saving_execution | fsdp | [FSDPCheckpointSaving](../../src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py) | [FSDPCheckpointSavingConfig](../../src/modalities/config/config.py) | [CheckpointSavingExecutionABC](../../src/modalities/checkpointing/checkpoint_saving_execution.py) | FSDPCheckpointSaving class for saving checkpoints of FSDP models and optimizers. | +| checkpoint_loading | fsdp | [FSDPCheckpointLoading](../../src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py) | [FSDPCheckpointLoadingConfig](../../src/modalities/config/config.py) | [CheckpointLoadingIF](../../src/modalities/checkpointing/checkpoint_loading.py) | Component for loading FSDP checkpoints | +| checkpoint_loading | torch | [TorchCheckpointLoading](../../src/modalities/checkpointing/torch/torch_checkpoint_loading.py) | [TorchCheckpointLoadingConfig](../../src/modalities/config/config.py) | [CheckpointLoadingIF](../../src/modalities/checkpointing/checkpoint_loading.py) | Component for loading PyTorch checkpoints | ## Logging -|Component type | Component Version | Implementation | Configuration | Component Interface | Description | -|---------------|--------------------|----------------|---------------|---------------------|-------------| -| progress_subscriber | dummy | [ProgressSubscriberFactory.get_dummy_progress_subscriber](../../src/modalities/logging_broker/subscriber_impl/subscriber_factory.py)| [DummyProgressSubscriberConfig](../../src/modalities/config/config.py) | [MessageSubscriberIF](../../src/modalities/logging_broker/subscriber.py) | Dummy Progress subscriber not consuming any messages| -| progress_subscriber | rich | [ProgressSubscriberFactory.get_rich_progress_subscriber](../../src/modalities/logging_broker/subscriber_impl/subscriber_factory.py)| [RichProgressSubscriberConfig](../../src/modalities/config/config.py) | [MessageSubscriberIF](../../src/modalities/logging_broker/subscriber.py) | Subscriber for writing out rich-formatted console outputs w.r.t. 
to training and evaluation progress |
-| results_subscriber | wandb | [ProgressSubscriberFactory.get_wandb_result_subscriber](../../src/modalities/logging_broker/subscriber_impl/subscriber_factory.py)| [WandBEvaluationResultSubscriberConfig](../../src/modalities/config/config.py) | [MessageSubscriberIF](../../src/modalities/logging_broker/subscriber.py) | Subscriber for logging evaluation results to Weights and Biases |
+| Component type      | Component Version | Implementation                                                                                                                        | Configuration                                                                  | Component Interface                                                       | Description                                                                                           |
+|---------------------|-------------------|-------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------|-------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------|
+| progress_subscriber | dummy             | [ProgressSubscriberFactory.get_dummy_progress_subscriber](../../src/modalities/logging_broker/subscriber_impl/subscriber_factory.py) | [DummyProgressSubscriberConfig](../../src/modalities/config/config.py)        | [MessageSubscriberIF](../../src/modalities/logging_broker/subscriber.py) | Dummy progress subscriber not consuming any messages                                                  |
+| progress_subscriber | rich              | [ProgressSubscriberFactory.get_rich_progress_subscriber](../../src/modalities/logging_broker/subscriber_impl/subscriber_factory.py)  | [RichProgressSubscriberConfig](../../src/modalities/config/config.py)         | [MessageSubscriberIF](../../src/modalities/logging_broker/subscriber.py) | Subscriber for writing out rich-formatted console outputs w.r.t. training and evaluation progress     |
+| results_subscriber  | wandb             | [ProgressSubscriberFactory.get_wandb_result_subscriber](../../src/modalities/logging_broker/subscriber_impl/subscriber_factory.py)   | [WandBEvaluationResultSubscriberConfig](../../src/modalities/config/config.py) | [MessageSubscriberIF](../../src/modalities/logging_broker/subscriber.py) | Subscriber for logging evaluation results to Weights and Biases                                       |

## Layer Norms

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| layer_norm | rms_norm | [RMSLayerNorm](../../src/modalities/models/components/layer_norms.py)| [RMSLayerNormConfig](../../src/modalities/models/components/layer_norms.py) | [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) | RMS Layer norm |
-| layer_norm | layer_norm | [nn.LayerNorm](../../src/modalities/models/components/layer_norms.py)| [LayerNormConfig](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) | [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) | Layer norm |
+| Component type | Component Version | Implementation                                                                     | Configuration                                                                        | Component Interface                                                          | Description    |
+|----------------|-------------------|-----------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------|--------------------|
+| layer_norm     | rms_norm          | [RMSLayerNorm](../../src/modalities/models/components/layer_norms.py)             | [RMSLayerNormConfig](../../src/modalities/models/components/layer_norms.py)         | 
[nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) | RMS Layer norm |
+| layer_norm     | layer_norm        | [nn.LayerNorm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) | [LayerNormConfig](../../src/modalities/models/components/layer_norms.py)            | [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) | Layer norm     |

## Gradient Clipping

-|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
-|---------------|--------------------|----------------|---------------|---------------------|-------------|
-| gradient_clipper | fsdp | [FSDPGradientClipper](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py)| [FSDPGradientClipperConfig](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py) | [GradientClipperIF](../../src/modalities/training/gradient_clipping/gradient_clipper.py) | FSDP Gradient Clipper |
-| gradient_clipper | fsdp_logging_only | [FSDPLoggingOnlyGradientClipper](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py)| [FSDPGradientClipperConfig](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py) | [GradientClipperIF](../../src/modalities/training/gradient_clipping/gradient_clipper.py) | Clipper that is responsible for logging the gradient norms without actually clipping the gradients |
-| gradient_clipper | dummy | [DummyGradientClipper](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py)| [DummyGradientClipperConfig](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py) | [GradientClipperIF](../../src/modalities/training/gradient_clipping/gradient_clipper.py) | Dummy clipper that does not apply any gradient clipping. 
| +| Component type | Component Version | Implementation | Configuration | Component Interface | Description | +|------------------|-------------------|------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------| +| gradient_clipper | fsdp | [FSDPGradientClipper](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py) | [FSDPGradientClipperConfig](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py) | [GradientClipperIF](../../src/modalities/training/gradient_clipping/gradient_clipper.py) | FSDP Gradient Clipper | +| gradient_clipper | fsdp_logging_only | [FSDPLoggingOnlyGradientClipper](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py) | [FSDPGradientClipperConfig](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py) | [GradientClipperIF](../../src/modalities/training/gradient_clipping/gradient_clipper.py) | Clipper that is responsible for logging the gradient norms without actually clipping the gradients | +| gradient_clipper | dummy | [DummyGradientClipper](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py) | [DummyGradientClipperConfig](../../src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py) | [GradientClipperIF](../../src/modalities/training/gradient_clipping/gradient_clipper.py) | Dummy clipper that does not apply any gradient clipping. | ## Number conversions -|Component type | Component Version | Implementation | Configuration | Component Interface | Description | -|---------------|--------------------|----------------|---------------|---------------------|-------------| -| number_conversion | local_num_batches_from_num_samples | [NumberConversion.get_local_num_batches_from_num_samples](../../src/modalities/utils/number_conversion.py)| [LocalNumBatchesFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of samples and number of ranks. | -| number_conversion | local_num_batches_from_num_tokens | [NumberConversion.get_local_num_batches_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [LocalNumBatchesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of tokens and number of ranks. | -| number_conversion | local_num_batches_from_num_tokens | [NumberConversion.get_num_samples_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [NumSamplesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of global samples, given the global number of tokens and sequence length | -| number_conversion | num_steps_from_num_samples | [NumberConversion.get_num_steps_from_num_samples](../../src/modalities/utils/number_conversion.py)| [NumStepsFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of steps given the global number of samples, local micro batch size and number of ranks. 
| -| number_conversion | num_steps_from_num_tokens | [NumberConversion.get_num_steps_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [NumStepsFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of steps given the global number of tokens, local micro batch size and number of ranks. | -| number_conversion | num_tokens_from_num_steps | [NumberConversion.get_num_tokens_from_num_steps](../../src/modalities/utils/number_conversion.py)| [NumTokensFromNumStepsConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of tokens from the number of steps, number of ranks, local micro batch size, global number of tokens, squence length and gradient accumulation steps | -| number_conversion | last_step_from_checkpoint_path | [NumberConversion.get_num_seen_steps_from_checkpoint_path](../../src/modalities/utils/number_conversion.py)| [NumberConversionFromCheckpointPathConfig](../../src/modalities/utils/number_conversion.py) | -- | Get the last step id from a model or checkpoint file path. | -| number_conversion | global_num_target_tokens_from_checkpoint_path | [NumberConversion.get_global_num_target_tokens_from_checkpoint_path](../../src/modalities/utils/number_conversion.py)| [NumberConversionFromCheckpointPathConfig](../../src/modalities/utils/number_conversion.py) | -- | Get the number of target tokens from a model or checkpoint file path. | -| number_conversion | num_tokens_from_packed_mem_map_dataset_continuous | [NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous](../../src/modalities/utils/number_conversion.py)| [NumTokensFromPackedMemMapDatasetContinuousConfig](../../src/modalities/utils/number_conversion.py) | -- | Get the number of tokens stored in a [packed mem map continuous dataset](../../src/modalities/dataloader/dataset.py) from the respective dataset file path. | -| number_conversion | num_steps_from_raw_dataset_index | [NumberConversion.get_num_steps_from_raw_dataset_index](../../src/modalities/utils/number_conversion.py)| [NumStepsFromRawDatasetIndexConfig](../../src/modalities/utils/number_conversion.py) | -- | Get the number of steps partially from the raw index of a raw JSONL dataset. Requires the file path to index, number of ranks, local micro batch size and gardient accumulation steps. | \ No newline at end of file +| Component type | Component Version | Implementation | Configuration | Component Interface | Description | +|-------------------|---------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|---------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| number_conversion | local_num_batches_from_num_samples | [NumberConversion.get_local_num_batches_from_num_samples](../../src/modalities/utils/number_conversion.py) | [LocalNumBatchesFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of samples and number of ranks. 
|
+| number_conversion | num_samples_from_num_tokens                        | [NumberConversion.get_num_samples_from_num_tokens](../../src/modalities/utils/number_conversion.py)                       | [NumSamplesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py)                    | --                  | Calculates the number of global samples, given the global number of tokens and sequence length                                                                                            |
+| number_conversion | num_steps_from_num_samples                         | [NumberConversion.get_num_steps_from_num_samples](../../src/modalities/utils/number_conversion.py)                        | [NumStepsFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py)                     | --                  | Calculates the number of steps given the global number of samples, local micro batch size and number of ranks.                                                                            |
+| number_conversion | num_steps_from_num_tokens                          | [NumberConversion.get_num_steps_from_num_tokens](../../src/modalities/utils/number_conversion.py)                         | [NumStepsFromNumTokensConfig](../../src/modalities/utils/number_conversion.py)                      | --                  | Calculates the number of steps given the global number of tokens, local micro batch size and number of ranks.                                                                             |
+| number_conversion | num_tokens_from_num_steps                          | [NumberConversion.get_num_tokens_from_num_steps](../../src/modalities/utils/number_conversion.py)                         | [NumTokensFromNumStepsConfig](../../src/modalities/utils/number_conversion.py)                      | --                  | Calculates the number of tokens from the number of steps, number of ranks, local micro batch size, global number of tokens, sequence length and gradient accumulation steps               |
+| number_conversion | last_step_from_checkpoint_path                     | [NumberConversion.get_num_seen_steps_from_checkpoint_path](../../src/modalities/utils/number_conversion.py)               | [NumberConversionFromCheckpointPathConfig](../../src/modalities/utils/number_conversion.py)         | --                  | Get the last step id from a model or checkpoint file path.                                                                                                                                |
+| number_conversion | global_num_target_tokens_from_checkpoint_path      | [NumberConversion.get_global_num_target_tokens_from_checkpoint_path](../../src/modalities/utils/number_conversion.py)     | [NumberConversionFromCheckpointPathConfig](../../src/modalities/utils/number_conversion.py)         | --                  | Get the number of target tokens from a model or checkpoint file path.                                                                                                                     |
+| number_conversion | num_tokens_from_packed_mem_map_dataset_continuous  | [NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous](../../src/modalities/utils/number_conversion.py) | [NumTokensFromPackedMemMapDatasetContinuousConfig](../../src/modalities/utils/number_conversion.py) | --                  | Get the number of tokens stored in a [packed mem map continuous dataset](../../src/modalities/dataloader/dataset.py) from the respective dataset file path.                               |
+| number_conversion | num_steps_from_raw_dataset_index                   | [NumberConversion.get_num_steps_from_raw_dataset_index](../../src/modalities/utils/number_conversion.py)                  | [NumStepsFromRawDatasetIndexConfig](../../src/modalities/utils/number_conversion.py)                | --                  | Get the number of steps from the raw index of a raw JSONL dataset, given the file path to the index, the number of ranks, the local micro batch size and the gradient accumulation steps. 
|
\ No newline at end of file
diff --git a/src/modalities/checkpointing/checkpoint_saving_strategies.py b/src/modalities/checkpointing/checkpoint_saving_strategies.py
index 1e5d5e16e..9a810b0e2 100644
--- a/src/modalities/checkpointing/checkpoint_saving_strategies.py
+++ b/src/modalities/checkpointing/checkpoint_saving_strategies.py
@@ -119,3 +119,62 @@ def get_checkpoint_instruction(
         """
         save_current = training_progress.num_seen_steps_total % self.k == 0
         return CheckpointingInstruction(save_current=save_current, checkpoints_to_delete=[])
+
+
+class KeepEveryKStepsAndMMostRecentCheckpointingStrategy(CheckpointSavingStrategyIF):
+    """Checkpointing strategy that always saves the current checkpoint,
+    keeps the m most recent checkpoints, and retains older checkpoints at multiples of k.
+    For example, with k=5 and num_recent_checkpoints_to_keep=2, the checkpoints retained
+    after step 12 are those of steps 5, 10, 11 and 12.
+    """
+
+    def __init__(self, k: int, num_recent_checkpoints_to_keep: int = 2):
+        """
+        Initializes the KeepEveryKStepsAndMMostRecentCheckpointingStrategy object.
+
+        Args:
+            k (int): Checkpoints whose total step count is a multiple of k are kept permanently.
+            num_recent_checkpoints_to_keep (int, optional): The number of most recent checkpoints to keep.
+                Every checkpoint counts toward this window; once a checkpoint falls out of it,
+                it is deleted unless its total step count is divisible by k.
+                Defaults to 2.
+
+        Returns:
+            None
+        """
+        super().__init__()
+        self._k = k
+        self._num_recent_checkpoints_to_keep = num_recent_checkpoints_to_keep
+        self._saved_recent_checkpoints: list[TrainingProgress] = []
+        assert self._k > 0, "k must be greater than 0"
+        assert self._num_recent_checkpoints_to_keep >= 1, "num_recent_checkpoints_to_keep must be at least 1"
+
+    def get_checkpoint_instruction(
+        self,
+        training_progress: TrainingProgress,
+        evaluation_result: dict[str, EvaluationResultBatch] | None = None,
+        early_stopping_criterion_fulfilled: bool = False,
+    ) -> CheckpointingInstruction:
+        """
+        Returns a CheckpointingInstruction object.
+
+        Args:
+            training_progress (TrainingProgress): The training progress.
+            evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
+                The evaluation result. Defaults to None.
+            early_stopping_criterion_fulfilled (bool, optional):
+                Whether the early stopping criterion is fulfilled. Defaults to False.
+
+        Returns:
+            CheckpointingInstruction: The checkpointing instruction object.
+        """
+        self._saved_recent_checkpoints.append(dataclasses.replace(training_progress))
+        checkpoints_to_delete, self._saved_recent_checkpoints = (
+            (
+                self._saved_recent_checkpoints[: -self._num_recent_checkpoints_to_keep],
+                self._saved_recent_checkpoints[-self._num_recent_checkpoints_to_keep :],
+            )
+            if len(self._saved_recent_checkpoints) > self._num_recent_checkpoints_to_keep
+            else ([], self._saved_recent_checkpoints)
+        )
+        # Never delete checkpoints whose total step count is divisible by k.
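+        # e.g. with k=3 and num_recent_checkpoints_to_keep=2, at step 7 the deletion candidate is
+        # step 5 (5 % 3 != 0, so it is deleted), while at step 8 the candidate step 6 survives (6 % 3 == 0).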
+ checkpoints_to_delete = [cp for cp in checkpoints_to_delete if cp.num_seen_steps_total % self._k != 0] + + return CheckpointingInstruction(save_current=True, checkpoints_to_delete=checkpoints_to_delete) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 42a19b99a..947949a55 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -92,6 +92,11 @@ class SaveKMostRecentCheckpointsStrategyConfig(BaseModel): k: Annotated[int, Field(strict=True, ge=-1)] +class KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig(BaseModel): + k: Annotated[int, Field(strict=True, gt=0)] + num_recent_checkpoints_to_keep: Annotated[int, Field(strict=True, ge=1)] = 2 + + class TorchCheckpointLoadingConfig(BaseModel): device: PydanticPytorchDeviceType precision: Optional[PrecisionEnum] = None diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 26df9b432..77b24f07f 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -10,6 +10,7 @@ from modalities.checkpointing.checkpoint_saving import CheckpointSaving from modalities.checkpointing.checkpoint_saving_strategies import ( + KeepEveryKStepsAndMMostRecentCheckpointingStrategy, SaveEveryKStepsCheckpointingStrategy, SaveKMostRecentCheckpointsStrategy, ) @@ -47,6 +48,7 @@ GPT2LLMCollateFnConfig, GPT2MFUCalculatorConfig, GPT2ModelTPConfig, + KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig, LinearLRSchedulerConfig, LinearWarmupCosineAnnealingLRSchedulerConfig, LLMDataLoaderConfig, @@ -353,6 +355,12 @@ class ComponentEntity: SaveKMostRecentCheckpointsStrategy, SaveKMostRecentCheckpointsStrategyConfig, ), + ComponentEntity( + "checkpoint_saving_strategy", + "keep_every_k_steps_and_m_most_recent_checkpointing_strategy", + KeepEveryKStepsAndMMostRecentCheckpointingStrategy, + KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig, + ), # checkpoint saving execution ComponentEntity("checkpoint_saving_execution", "fsdp1", FSDP1CheckpointSaving, FSDP1CheckpointSavingConfig), ComponentEntity("checkpoint_saving_execution", "dcp", DCPCheckpointSaving, DCPCheckpointSavingConfig), diff --git a/tests/checkpointing/test_checkpoint_strategies.py b/tests/checkpointing/test_checkpoint_strategies.py index ddf8a5ba0..2b3f43d59 100644 --- a/tests/checkpointing/test_checkpoint_strategies.py +++ b/tests/checkpointing/test_checkpoint_strategies.py @@ -1,6 +1,12 @@ +import dataclasses + import pytest -from modalities.checkpointing.checkpoint_saving_strategies import SaveKMostRecentCheckpointsStrategy +from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction +from modalities.checkpointing.checkpoint_saving_strategies import ( + KeepEveryKStepsAndMMostRecentCheckpointingStrategy, + SaveKMostRecentCheckpointsStrategy, +) from modalities.training.training_progress import TrainingProgress @@ -43,3 +49,112 @@ def test_checkpoint_strategy_k( if k != 0 and save_current: training_progress.num_seen_steps_current_run = 100 assert checkpoint_strategy.saved_step_checkpoints[0].num_seen_steps_current_run == num_seen_steps_current_run + + +@pytest.mark.parametrize( + "k, num_recent_checkpoints_to_keep, num_steps, num_seen_steps_previous_run, num_seen_tokens_previous_run", + [ + (3, 2, 11, 0, 0), + (2, 1, 10, 2, 4), + (4, 3, 15, 3, 6), + ], +) +def test_keep_every_k_strategy_has_no_unexpected_checkpoints( + k: int, + num_recent_checkpoints_to_keep: int, + num_steps: int, + num_seen_steps_previous_run: int, + 
num_seen_tokens_previous_run: int, +) -> None: + checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy( + k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep + ) + training_progress = TrainingProgress( + num_seen_steps_current_run=0, + num_seen_tokens_current_run=0, + num_target_steps=20, + num_target_tokens=40, + num_seen_steps_previous_run=num_seen_steps_previous_run, + num_seen_tokens_previous_run=num_seen_tokens_previous_run, + ) + + # Simulate training progress and checkpointing + simulator = _CheckpointSavingSimulator() + for step in range(num_steps + 1): + training_progress.num_seen_steps_current_run = step + checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress) + simulator.simulate_training_step(training_progress, checkpoint_instruction) + + for ckpt in simulator.saved_checkpoints: + # Check that only checkpoints that are divisible by k or the most recent ones are kept. + last_checkpoints = set( + range( + num_seen_steps_previous_run + num_steps - num_recent_checkpoints_to_keep + 1, + num_seen_steps_previous_run + num_steps + 1, + ) + ) + assert ckpt.num_seen_steps_total % k == 0 or ckpt.num_seen_steps_total in last_checkpoints + + +@pytest.mark.parametrize( + "k, num_recent_checkpoints_to_keep, num_steps, num_seen_steps_previous_run, num_seen_tokens_previous_run", + [ + (3, 2, 11, 0, 0), + (2, 1, 10, 2, 4), + (4, 3, 15, 3, 6), + ], +) +def test_keep_every_k_strategy_has_no_unexpected_deletions( + k: int, + num_recent_checkpoints_to_keep: int, + num_steps: int, + num_seen_steps_previous_run: int, + num_seen_tokens_previous_run: int, +) -> None: + checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy( + k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep + ) + training_progress = TrainingProgress( + num_seen_steps_current_run=0, + num_seen_tokens_current_run=0, + num_target_steps=20, + num_target_tokens=40, + num_seen_steps_previous_run=num_seen_steps_previous_run, + num_seen_tokens_previous_run=num_seen_tokens_previous_run, + ) + + # Simulate training progress and checkpointing + simulator = _CheckpointSavingSimulator() + for step in range(1, num_steps + 1): + training_progress.num_seen_steps_current_run = step + checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress) + simulator.simulate_training_step(training_progress, checkpoint_instruction) + + for i in range(num_seen_steps_previous_run + 1, num_seen_steps_previous_run + num_steps + 1): + # Check that checkpoints that are divisible by k or the most recent ones are not deleted. 
+ if i % k == 0 or i > num_seen_steps_previous_run + num_steps - num_recent_checkpoints_to_keep: + assert any(ckpt.num_seen_steps_total == i for ckpt in simulator.saved_checkpoints) + + +def test_keep_every_k_steps_checkpointing_strategy_invalid_arguments() -> None: + with pytest.raises(AssertionError): + KeepEveryKStepsAndMMostRecentCheckpointingStrategy(k=0, num_recent_checkpoints_to_keep=1) + with pytest.raises(AssertionError): + KeepEveryKStepsAndMMostRecentCheckpointingStrategy(k=-1, num_recent_checkpoints_to_keep=1) + with pytest.raises(AssertionError): + KeepEveryKStepsAndMMostRecentCheckpointingStrategy(k=2, num_recent_checkpoints_to_keep=0) + with pytest.raises(AssertionError): + KeepEveryKStepsAndMMostRecentCheckpointingStrategy(k=2, num_recent_checkpoints_to_keep=-1) + + +class _CheckpointSavingSimulator: + def __init__(self): + self.saved_checkpoints: list[TrainingProgress] = [] + + def simulate_training_step( + self, training_progress: TrainingProgress, ckpt_instruction: CheckpointingInstruction + ) -> None: + if ckpt_instruction.save_current: + self.saved_checkpoints.append(dataclasses.replace(training_progress)) + for checkpoint_to_delete in ckpt_instruction.checkpoints_to_delete: + self.saved_checkpoints = [cp for cp in self.saved_checkpoints if cp != checkpoint_to_delete]
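
A quick end-to-end sketch of the retention pattern implemented above (not part of the diff; it uses only the class and `TrainingProgress` fields shown in this change, with placeholder token counts):

```python
from modalities.checkpointing.checkpoint_saving_strategies import (
    KeepEveryKStepsAndMMostRecentCheckpointingStrategy,
)
from modalities.training.training_progress import TrainingProgress

strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy(k=5, num_recent_checkpoints_to_keep=2)
kept: list[int] = []
for step in range(1, 13):
    progress = TrainingProgress(
        num_seen_steps_current_run=step,
        num_seen_tokens_current_run=2 * step,  # placeholder token counts
        num_target_steps=12,
        num_target_tokens=24,
        num_seen_steps_previous_run=0,
        num_seen_tokens_previous_run=0,
    )
    instruction = strategy.get_checkpoint_instruction(training_progress=progress)
    kept.append(step)  # save_current is always True for this strategy
    deleted = {cp.num_seen_steps_total for cp in instruction.checkpoints_to_delete}
    kept = [s for s in kept if s not in deleted]

print(kept)  # [5, 10, 11, 12]: the multiples of k plus the 2 most recent steps
```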