
Commit 1fa6085

Merge pull request #421 from Modalities/modalities_profiling
Distributed and single process profiling / tracing
2 parents bacc0b9 + fdda0fb

26 files changed: 974 additions & 25 deletions

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,7 @@
 [project]
 name = "modalities"
 version = "0.4.0"
-requires-python = ">=3.10,<3.13"
+requires-python = ">=3.10,<=3.13"
 description = "Modalities, a PyTorch-native framework for distributed and reproducible foundation model training."
 readme = "README.md"
 dependencies = [
@@ -21,8 +21,10 @@ dependencies = [
     "click_pathlib",
     "jq",
     "class_resolver",
+    "matplotlib",
     "wandb",
     "einops>=0.7.0",
+    "debugpy", # For VSCode debugging support
 ]

 [project.urls]

src/modalities/__main__.py

Lines changed: 74 additions & 0 deletions
@@ -35,6 +35,10 @@
 from modalities.utils.benchmarking.benchmarking_utils import SweepSets, get_updated_sweep_status
 from modalities.utils.benchmarking.sweep_utils import SweepGenerator
 from modalities.utils.communication_test import run_communication_test
+from modalities.utils.logger_utils import get_logger
+from modalities.utils.profilers.modalities_profiler import ModalitiesProfilerStarter
+
+logger = get_logger("__main__")


 @click.group()
@@ -680,5 +684,75 @@ def CMD_entry_point_list_remaining_runs(
         f.write(f"{cfg}\n")


+@main.group(name="profile")
+def profile():
+    """
+    Collection of utilities to profile modalities.
+    """
+    pass
+
+
+@profile.command(name="distributed")
+@click.option(
+    "--config_file_path",
+    type=click_pathlib.Path(exists=True),
+    required=True,
+    help="Path to the YAML training config file.",
+)
+@click.option(
+    "--experiment_root_path",
+    type=click_pathlib.Path(file_okay=False),
+    required=True,
+    help="Path to the experiment output directory.",
+)
+@click.option(
+    "--num_wait_steps",
+    type=int,
+    default=1,
+    show_default=True,
+    help="Number of wait steps to skip in profiling.",
+)
+@click.option(
+    "--num_warmup_steps",
+    type=int,
+    default=1,
+    show_default=True,
+    help="Number of warmup steps to skip in profiling; data is recorded but discarded.",
+)
+@click.option(
+    "--num_measurement_steps",
+    type=int,
+    default=3,
+    show_default=True,
+    help="Number of steps to measure during profiling.",
+)
+@click.option(
+    "--profiled_ranks",
+    type=str,
+    default="0",
+    help="Comma-separated list of profiled ranks (must not contain spaces), e.g. --profiled_ranks '2,4,8'",
+)
+def CMD_entry_point_run_train_step_profiler(
+    config_file_path: Path,
+    experiment_root_path: Path,
+    num_wait_steps: int,
+    num_warmup_steps: int,
+    num_measurement_steps: int,
+    profiled_ranks: str,
+):
+    """Run the train step profiler and write the result to JSON if RANK=0."""
+    profiled_ranks_list = [int(i) for i in profiled_ranks.split(",")] if profiled_ranks != "" else [0]
+    logger.info(f"Running distributed profiling on ranks {profiled_ranks_list}")
+
+    ModalitiesProfilerStarter.run_distributed(
+        config_file_path=config_file_path,
+        num_measurement_steps=num_measurement_steps,
+        num_wait_steps=num_wait_steps,
+        num_warmup_steps=num_warmup_steps,
+        experiment_root_path=experiment_root_path,
+        profiled_ranks=profiled_ranks_list,
+    )
+
+
 if __name__ == "__main__":
     main()
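
For illustration, a sketch of driving the same profiler programmatically, mirroring the CLI wiring above; the paths are placeholders, and launching under a distributed launcher such as torchrun is an assumption based on the command's purpose:

from pathlib import Path

from modalities.utils.profilers.modalities_profiler import ModalitiesProfilerStarter

# Equivalent to the `profile distributed` CLI command with default step counts.
ModalitiesProfilerStarter.run_distributed(
    config_file_path=Path("configs/train_config.yaml"),      # placeholder path
    experiment_root_path=Path("experiments/profiling_run"),  # placeholder path
    num_wait_steps=1,         # steps skipped entirely before recording starts
    num_warmup_steps=1,       # steps recorded but discarded
    num_measurement_steps=3,  # steps that enter the measurement
    profiled_ranks=[0],       # rank 0 writes the JSON result
)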

src/modalities/config/pydantic_if_types.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
 from modalities.utils.mfu import MFUCalculatorABC
 from modalities.utils.profilers.batch_generator import DatasetBatchGeneratorIF
+from modalities.utils.profilers.steppable_components import SteppableComponentIF


 class PydanticThirdPartyTypeIF:
@@ -88,3 +89,4 @@ def __get_pydantic_core_schema__(
 PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)]
 PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)]
 PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)]
+PydanticSteppableComponentIFType = Annotated[SteppableComponentIF, PydanticThirdPartyTypeIF(SteppableComponentIF)]
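
These annotated type aliases let pydantic config models accept already-instantiated components whose classes pydantic cannot validate natively. A minimal sketch of the intended usage, with a hypothetical model name and field:

from pydantic import BaseModel

from modalities.config.pydantic_if_types import PydanticSteppableComponentIFType

class ProfilerComponentsModel(BaseModel):  # hypothetical model for illustration
    # Accepts any SteppableComponentIF instance; PydanticThirdPartyTypeIF
    # supplies the core schema that checks the instance type at validation.
    steppable_component: PydanticSteppableComponentIFType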

src/modalities/models/components/layer_norms.py

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,4 @@
+import warnings
 from typing import Annotated

 import torch
@@ -10,7 +11,8 @@ class RMSLayerNorm(nn.Module):

     def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-5):
         """
-        Initializes a LayerNorm module.
+        RMS Norm implementation.
+        WARNING: THIS IMPLEMENTATION IS DEPRECATED! USE torch.nn.RMSNorm INSTEAD FOR BETTER PERFORMANCE!
         Args:
             ndim (int): The number of dimensions of the input tensor.
             bias (bool, optional): If True, adds a learnable bias to the normalized tensor. Defaults to True.
@@ -21,6 +23,7 @@ def __init__(self, ndim: int, bias: bool = True, epsilon: float = 1e-5):
         Returns:
             None
         """
+        warnings.warn("RMSLayerNorm is deprecated. Please use torch.nn.RMSNorm for better performance.", FutureWarning)

         super().__init__()
         self.epsilon = epsilon
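
A hedged migration sketch for the deprecation above, assuming PyTorch >= 2.4 where torch.nn.RMSNorm is available:

import torch
import torch.nn as nn

from modalities.models.components.layer_norms import RMSLayerNorm

# Deprecated path: construction now emits a FutureWarning.
legacy_norm = RMSLayerNorm(ndim=768, bias=False, epsilon=1e-5)

# Recommended replacement; note the different argument names.
norm = nn.RMSNorm(normalized_shape=768, eps=1e-5)
y = norm(torch.randn(2, 16, 768))  # output shape matches the input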

src/modalities/registry/components.py

Lines changed: 16 additions & 1 deletion
@@ -82,7 +82,12 @@
 from modalities.loss_functions import CLMCrossEntropyLoss
 from modalities.models.coca.coca_model import CoCa, CoCaConfig
 from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn
-from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig
+from modalities.models.components.layer_norms import (
+    LayerNormConfig,
+    PytorchRMSLayerNormConfig,
+    RMSLayerNorm,
+    RMSLayerNormConfig,
+)
 from modalities.models.gpt2.collator import GPT2LLMCollateFn
 from modalities.models.gpt2.gpt2_model import GPT2LLMConfig
 from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig
@@ -130,6 +135,8 @@
     NumTokensFromPackedMemMapDatasetContinuousConfig,
 )
 from modalities.utils.profilers.batch_generator import RandomDatasetBatchGenerator, RandomDatasetBatchGeneratorConfig
+from modalities.utils.profilers.steppable_component_configs import SteppableForwardPassConfig
+from modalities.utils.profilers.steppable_components import SteppableForwardPass


 @dataclass
@@ -326,6 +333,7 @@ class ComponentEntity:
     # layer norms
     ComponentEntity("layer_norm", "rms_norm", RMSLayerNorm, RMSLayerNormConfig),
     ComponentEntity("layer_norm", "layer_norm", nn.LayerNorm, LayerNormConfig),
+    ComponentEntity("layer_norm", "rms_norm_pytorch", nn.RMSNorm, PytorchRMSLayerNormConfig),
     # gradient clippers
     ComponentEntity("gradient_clipper", "fsdp1", FSDP1GradientClipper, FSDP1GradientClipperConfig),
     ComponentEntity(
@@ -416,4 +424,11 @@ class ComponentEntity:
         NumberConversion.get_num_steps_from_raw_dataset_index,
         NumStepsFromRawDatasetIndexConfig,
     ),
+    # Profiling components
+    ComponentEntity(
+        "steppable_component",
+        "forward_pass",
+        SteppableForwardPass,
+        SteppableForwardPassConfig,
+    ),
 ]
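
Each ComponentEntity maps a (component_key, variant_key) pair, as referenced from YAML configs, to a component class plus the pydantic config that validates its arguments. A hedged sketch of what the new layer norm entry resolves to at instantiation time; the constructor arguments are illustrative, and the fields of PytorchRMSLayerNormConfig are not shown in this diff:

import torch.nn as nn

# A config selecting component_key "layer_norm" with variant_key
# "rms_norm_pytorch" resolves to torch.nn.RMSNorm, with its arguments
# validated by PytorchRMSLayerNormConfig first.
norm = nn.RMSNorm(normalized_shape=1024, eps=1e-5)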

src/modalities/utils/profilers/batch_generator.py

Lines changed: 44 additions & 12 deletions
@@ -4,28 +4,60 @@
 from pydantic import BaseModel

 from modalities.batch import DatasetBatch
-
-
-class RandomDatasetBatchGeneratorConfig(BaseModel):
-    vocab_size: int
-    sequence_length: int
-    batch_size: int
+from modalities.config.lookup_enum import LookupEnum


 class DatasetBatchGeneratorIF(ABC):
     def get_dataset_batch(self) -> DatasetBatch:
         raise NotImplementedError


+class DataTypeEnum(LookupEnum):
+    float32 = torch.float32
+    bfloat16 = torch.bfloat16
+    int64 = torch.int64
+
+
+class RandomDatasetBatchGeneratorConfig(BaseModel):
+    dims: dict[str, int]
+    data_type: DataTypeEnum
+    min_val: int
+    max_val: int
+
+
 class RandomDatasetBatchGenerator(DatasetBatchGeneratorIF):
-    def __init__(self, vocab_size: int, sequence_length: int, batch_size: int):
-        self._vocab_size = vocab_size
-        self._sequence_length = sequence_length
-        self._batch_size = batch_size
+    def __init__(self, dims: dict[str, int], data_type: DataTypeEnum, min_val: int, max_val: int):
+        self._dims = dims
+        self._data_type = data_type
+        self._min_val = min_val
+        self._max_val = max_val
+        self._device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

     def get_dataset_batch(self) -> DatasetBatch:
+        size = tuple(self._dims.values())
+        if self._data_type == DataTypeEnum.int64:
+            inputs = torch.randint(low=self._min_val, high=self._max_val, size=size, device=self._device)
+            targets = torch.randint(low=self._min_val, high=self._max_val, size=size, device=self._device)
+        elif self._data_type in {DataTypeEnum.float32, DataTypeEnum.bfloat16}:
+            inputs = (
+                torch.rand(size=size, device=self._device, dtype=self._data_type.value)
+                * (self._max_val - self._min_val)
+                + self._min_val
+            )
+            targets = (
+                torch.rand(size=size, device=self._device, dtype=self._data_type.value)
+                * (self._max_val - self._min_val)
+                + self._min_val
+            )
+        else:
+            raise ValueError(f"Unsupported data type: {self._data_type}")
+
         batch = DatasetBatch(
-            samples={"input_ids": torch.randint(0, self._vocab_size, (self._batch_size, self._sequence_length))},
-            targets={"target_ids": torch.randint(0, self._vocab_size, (self._batch_size, self._sequence_length))},
+            samples={
+                "input_ids": inputs,
+            },
+            targets={
+                "target_ids": targets,
+            },
         )
         return batch
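
A short usage sketch of the generalized generator: the dim names are only labels, the tensor shape comes from the dict values in order, and the values chosen here are placeholders:

from modalities.utils.profilers.batch_generator import (
    DataTypeEnum,
    RandomDatasetBatchGenerator,
)

# Equivalent of the old (batch_size=4, sequence_length=2048, vocab_size=50000)
# interface: int64 token ids sampled uniformly from [min_val, max_val).
generator = RandomDatasetBatchGenerator(
    dims={"batch_size": 4, "sequence_length": 2048},  # -> tensor shape (4, 2048)
    data_type=DataTypeEnum.int64,
    min_val=0,
    max_val=50_000,
)
batch = generator.get_dataset_batch()
assert batch.samples["input_ids"].shape == (4, 2048)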
