diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 041199f45..e3780ec40 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,9 +15,17 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: pip install --upgrade pip + - name: Create and activate venv for uv + run: | + python -m venv .venv + echo "VIRTUAL_ENV=$PWD/.venv" >> $GITHUB_ENV + echo "$PWD/.venv/bin" >> $GITHUB_PATH + - run: pip install --upgrade pip uv # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: | + uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + # Ensure PyTorch installs CPU wheels (with stubs) so pytype can resolve torch.nn + uv pip install --index-url https://download.pytorch.org/whl/cpu 'torch==2.1.1' 'torchvision==0.16.1' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. diff --git a/.gitignore b/.gitignore index 4d452573e..0a28a1119 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,4 @@ bazel-* # Emacs *~ +run_specific_test.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a1e66b4b..a4a6fb84a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,16 +14,19 @@ repos: - id: black name: black entry: black - language: system + language: python + additional_dependencies: ['black==23.12.0'] types: [python] - id: isort name: isort entry: isort - language: system + language: python + additional_dependencies: ['isort==5.13.2'] types: [python] - id: pylint name: pylint entry: pylint args: ['--msg-template="{abspath}:{line}: [{msg_id}({symbol}), {obj}] {msg}"'] - language: system + language: python + additional_dependencies: ['pylint==3.3.7'] types: [python] diff --git a/axlearn/audio/adapter.py b/axlearn/audio/adapter.py new file mode 100644 index 000000000..a5d2c459e --- /dev/null +++ b/axlearn/audio/adapter.py @@ -0,0 +1,246 @@ +# Copyright © 2023 Apple Inc. + +"""Audio model adapter for efficient fine-tuning.""" + +from typing import Optional + +import jax + +from axlearn.common.base_layer import BaseLayer +from axlearn.common.config import REQUIRED, Required, config_class +from axlearn.common.layers import BatchNorm, LayerNorm, Linear +from axlearn.common.module import Module +from axlearn.common.module import functional as F +from axlearn.common.param_init import DefaultInitializer, WeightInitializer + + +class AudioModelAdapter(BaseLayer): + """Adapter layer for efficient fine-tuning of audio models.""" + + @config_class + class Config(BaseLayer.Config): + """Configures AudioModelAdapter.""" + + # Input feature dimension. + input_dim: Required[int] = REQUIRED + # Bottleneck dimension (typically much smaller than input_dim). + bottleneck_dim: Required[int] = REQUIRED + # Whether to apply layer normalization before the adapter. + use_layer_norm: bool = True + # Whether to apply batch normalization in the adapter. + use_batch_norm: bool = False + # Scaling factor for the adapter output. + adapter_scale: float = 1.0 + # Activation function to use. + activation: str = "relu" + # Whether to add a residual connection. + residual: bool = True + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg = self.config + + # Initialize with small weights to make adapter less disruptive initially + weight_init = WeightInitializer.default_config().set( + distribution="normal", + fan="fan_in", + scale=0.01, + ) + + bias_init = WeightInitializer.default_config().set( + distribution="normal", + fan=None, + scale=0.01, + ) + + param_init = DefaultInitializer.default_config().set( + init_by_param_name={ + ".*weight": weight_init, + ".*bias": bias_init, + }, + ) + + # Down projection to bottleneck dimension + self._add_child( + "down_proj", + Linear.default_config().set( + input_dim=cfg.input_dim, + output_dim=cfg.bottleneck_dim, + bias=True, + param_init=param_init, + ), + ) + + # Optional batch normalization + if cfg.use_batch_norm: + self._add_child( + "batch_norm", + BatchNorm.default_config().set( + input_dim=cfg.bottleneck_dim, + decay=0.9, + ), + ) + + # Up projection back to input dimension + self._add_child( + "up_proj", + Linear.default_config().set( + input_dim=cfg.bottleneck_dim, + output_dim=cfg.input_dim, + bias=True, + param_init=param_init, + ), + ) + + # Optional layer normalization + if cfg.use_layer_norm: + self._add_child( + "layer_norm", + LayerNorm.default_config().set( + input_dim=cfg.input_dim, + ), + ) + + def forward(self, inputs, **_kwargs): + """Apply the adapter transformation. + + Args: + inputs: Input tensor of shape [batch_size, seq_len, input_dim]. + **_kwargs: Additional keyword arguments (unused, kept for API compatibility). + + Returns: + Tensor of the same shape as inputs. + """ + cfg = self.config + residual = inputs + + # Apply layer normalization if specified + x = inputs + if cfg.use_layer_norm: + x = self.layer_norm(x) + + # Down projection + x = self.down_proj(x) + + # Apply batch normalization if specified + if cfg.use_batch_norm: + # BatchNorm uses is_training from context automatically + x = self.batch_norm(x) + + # Activation + if cfg.activation == "relu": + x = jax.nn.relu(x) + elif cfg.activation == "gelu": + x = jax.nn.gelu(x) + + # Up projection + x = self.up_proj(x) + + # Scale the output + if cfg.adapter_scale != 1.0: + x = x * cfg.adapter_scale + + # Add residual connection if specified + if cfg.residual: + x = x + residual + + return x + + +class ASRModelAdapter(BaseLayer): + """Adapter for Automatic Speech Recognition (ASR) models.""" + + @config_class + class Config(BaseLayer.Config): + """Configures ASRModelAdapter.""" + + # Feature dimension of the encoder. + encoder_dim: Required[int] = REQUIRED + # Bottleneck dimension for encoder adapters. + encoder_bottleneck_dim: Required[int] = REQUIRED + # Feature dimension of the decoder. + decoder_dim: Optional[int] = None + # Bottleneck dimension for decoder adapters. + decoder_bottleneck_dim: Optional[int] = None + # Whether to add adapters to the encoder. + adapt_encoder: bool = True + # Whether to add adapters to the decoder. + adapt_decoder: bool = False + # Adapter configuration. + adapter: AudioModelAdapter.Config = AudioModelAdapter.default_config() + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg = self.config + + if cfg.adapt_encoder: + self._add_child( + "encoder_adapter", + cfg.adapter.clone( + input_dim=cfg.encoder_dim, + bottleneck_dim=cfg.encoder_bottleneck_dim, + ), + ) + + if ( + cfg.adapt_decoder + and cfg.decoder_dim is not None + and cfg.decoder_bottleneck_dim is not None + ): + self._add_child( + "decoder_adapter", + cfg.adapter.clone( + input_dim=cfg.decoder_dim, + bottleneck_dim=cfg.decoder_bottleneck_dim, + ), + ) + + def adapt_encoder_features(self, features, *, is_training=False, prng_key, state): + """Apply adaptation to encoder features. + + Args: + features: Encoder features to adapt. + is_training: Whether the model is in training mode. + prng_key: PRNG key for stochastic operations. + state: State for the adapter. + + Returns: + Adapted encoder features. + """ + cfg = self.config + if not cfg.adapt_encoder: + return features + + outputs, _ = F( + self.encoder_adapter, + inputs=(features,), + is_training=is_training, + prng_key=prng_key, + state=state["encoder_adapter"], + ) + return outputs + + def adapt_decoder_features(self, features, *, is_training=False, prng_key, state): + """Apply adaptation to decoder features. + + Args: + features: Decoder features to adapt. + is_training: Whether the model is in training mode. + prng_key: PRNG key for stochastic operations. + state: State for the adapter. + + Returns: + Adapted decoder features. + """ + cfg = self.config + if not cfg.adapt_decoder or not hasattr(self, "decoder_adapter"): + return features + + outputs, _ = F( + self.decoder_adapter, + inputs=(features,), + is_training=is_training, + prng_key=prng_key, + state=state["decoder_adapter"], + ) + return outputs diff --git a/axlearn/audio/adapter_test.py b/axlearn/audio/adapter_test.py new file mode 100644 index 000000000..e7b50b4d6 --- /dev/null +++ b/axlearn/audio/adapter_test.py @@ -0,0 +1,379 @@ +# Copyright © 2024 Apple Inc. + +"""Tests for audio adapters.""" + +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import parameterized + +from axlearn.audio.adapter import ASRModelAdapter, AudioModelAdapter +from axlearn.common.module import functional as F +from axlearn.common.test_utils import TestCase, assert_allclose + + +class AudioModelAdapterTest(TestCase): + """Tests AudioModelAdapter.""" + + def test_forward_basic(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + self.assertTrue(jnp.isfinite(outputs).all()) + + def test_forward_with_layer_norm(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + use_layer_norm=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + def test_forward_without_residual(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + residual=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + with self.assertRaises(AssertionError): + assert_allclose(outputs, inputs) + + def test_forward_with_scaling(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + adapter_scale = 0.5 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + adapter_scale=adapter_scale, + residual=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + @parameterized.parameters(["relu", "gelu"]) + def test_forward_with_activation(self, activation: str): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + activation=activation, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + self.assertTrue(jnp.isfinite(outputs).all()) + + def test_parameter_counts(self): + input_dim, bottleneck_dim = 256, 64 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + use_layer_norm=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + layer_params = layer.initialize_parameters_recursively(prng_key) + + down_proj_weight = layer_params["down_proj"]["weight"] + down_proj_bias = layer_params["down_proj"]["bias"] + up_proj_weight = layer_params["up_proj"]["weight"] + up_proj_bias = layer_params["up_proj"]["bias"] + layer_norm_scale = layer_params["layer_norm"]["scale"] + layer_norm_bias = layer_params["layer_norm"]["bias"] + + self.assertEqual(down_proj_weight.shape, (input_dim, bottleneck_dim)) + self.assertEqual(down_proj_bias.shape, (bottleneck_dim,)) + self.assertEqual(up_proj_weight.shape, (bottleneck_dim, input_dim)) + self.assertEqual(up_proj_bias.shape, (input_dim,)) + self.assertEqual(layer_norm_scale.shape, (input_dim,)) + self.assertEqual(layer_norm_bias.shape, (input_dim,)) + + total_params = np.prod(down_proj_weight.shape) + total_params += np.prod(down_proj_bias.shape) + total_params += np.prod(up_proj_weight.shape) + total_params += np.prod(up_proj_bias.shape) + total_params += np.prod(layer_norm_scale.shape) + total_params += np.prod(layer_norm_bias.shape) + + self.assertEqual(total_params, 33600) + + @parameterized.parameters([True, False]) + def test_training_vs_eval_mode(self, is_training: bool): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=is_training, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + +class ASRModelAdapterTest(TestCase): + """Tests ASRModelAdapter.""" + + def test_encoder_adapter_only(self): + encoder_dim = 256 + encoder_bottleneck_dim = 64 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=encoder_bottleneck_dim, + adapt_encoder=True, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) + + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_features.shape, encoder_features.shape) + + def test_decoder_adapter_only(self): + decoder_dim = 256 + decoder_bottleneck_dim = 64 + batch_size, seq_len = 4, 50 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=128, + encoder_bottleneck_dim=32, + decoder_dim=decoder_dim, + decoder_bottleneck_dim=decoder_bottleneck_dim, + adapt_encoder=False, + adapt_decoder=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + decoder_features = jax.random.normal(input_key, (batch_size, seq_len, decoder_dim)) + + adapted_features = layer.adapt_decoder_features( + decoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_features.shape, decoder_features.shape) + + def test_both_encoders_and_decoders(self): + encoder_dim, encoder_bottleneck_dim = 256, 64 + decoder_dim, decoder_bottleneck_dim = 256, 64 + batch_size, enc_seq_len, dec_seq_len = 4, 100, 50 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=encoder_bottleneck_dim, + decoder_dim=decoder_dim, + decoder_bottleneck_dim=decoder_bottleneck_dim, + adapt_encoder=True, + adapt_decoder=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key1, input_key2 = jax.random.split(prng_key, num=4) + layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key1, (batch_size, enc_seq_len, encoder_dim)) + decoder_features = jax.random.normal(input_key2, (batch_size, dec_seq_len, decoder_dim)) + + adapted_enc_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + adapted_dec_features = layer.adapt_decoder_features( + decoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_enc_features.shape, encoder_features.shape) + self.assertEqual(adapted_dec_features.shape, decoder_features.shape) + + def test_no_adaptation(self): + encoder_dim = 256 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=64, + adapt_encoder=False, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) + + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + assert_allclose(adapted_features, encoder_features) + + def test_direct_call_fallback(self): + encoder_dim = 256 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=64, + adapt_encoder=True, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) + + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_features.shape, encoder_features.shape) diff --git a/axlearn/common/adapter_torch.py b/axlearn/common/adapter_torch.py index b6b2620da..0faaec8f0 100644 --- a/axlearn/common/adapter_torch.py +++ b/axlearn/common/adapter_torch.py @@ -179,6 +179,7 @@ def _axlearn_weight_mapper(self, weight_name: str, weight: torch.Tensor) -> torc class LayerNorm(nn.LayerNorm, TorchModule): + # pylint: disable=useless-parent-delegation # maintains interface compatibility def __init__(self, *args, eps=1e-6, **kwargs): super().__init__(*args, eps=eps, **kwargs) diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 689dc65ae..8cb4d2143 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -111,7 +111,7 @@ def _key_value_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Ten return jnp.asarray(index_offset), jnp.asarray(index_offset_size) -def _mha_forward_kernel( +def _mha_forward_kernel( # pylint: disable=too-many-positional-arguments q_ref, k_ref, v_ref, diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index ba9e3d995..c33fa96b3 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -409,7 +409,7 @@ def forward( # We need to jit a function before shard_mapping it. @jax.jit - def jit_mamba_scan(x, a, b, c, delta, d): + def jit_mamba_scan(x, a, b, c, delta, d): # pylint: disable=too-many-positional-arguments y = compute_mamba_scan( # [batch_size, seq_len, inner_dim] x, a, @@ -459,7 +459,8 @@ def default_mamba_dim_to_partition_specs( the Pallas-based Mamba implementation. The inner dimension is sharded over the default tensor-parallel axis name if present, - and the the batch is sharded over the remainder of the axes. + the sequence dimension is sharded over the sequence-parallel axis name if present, + and the batch is sharded over the remainder of the axes. Args: mesh_axis_names: Mesh axis names. @@ -467,13 +468,14 @@ def default_mamba_dim_to_partition_specs( Returns: A dictionary keyed by Mamba tensor dims with partition spec values. """ - batch_axis_names = tuple(el for el in mesh_axis_names if el != "model") + batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "seq")) tp_axis_name = "model" if "model" in mesh_axis_names else None + seq_axis_name = "seq" if "seq" in mesh_axis_names else None - # TODO(swiseman): support sequence parallelism. - x_spec = PartitionSpec(batch_axis_names, None, tp_axis_name) + # Support sequence parallelism by sharding the sequence dimension (middle dim in btd/bts). + x_spec = PartitionSpec(batch_axis_names, seq_axis_name, tp_axis_name) a_spec = PartitionSpec(None, tp_axis_name) - b_spec = PartitionSpec(batch_axis_names, None, None) + b_spec = PartitionSpec(batch_axis_names, seq_axis_name, None) d_spec = PartitionSpec(None, tp_axis_name) partition_specs = {"btd": x_spec, "sd": a_spec, "bts": b_spec, "1d": d_spec} return partition_specs @@ -481,12 +483,13 @@ def default_mamba_dim_to_partition_specs( def default_output_partition_spec( mesh_axis_names: Sequence[str], -) -> dict[str, PartitionSpec]: +) -> PartitionSpec: """Builds a default output partition spec for the shard_mapped Pallas-based Mamba implementation. The inner dimension is sharded over the default tensor-parallel axis name if present, - and the the batch is sharded over the remainder of the axes. + the sequence dimension is sharded over the sequence-parallel axis name if present, + and the batch is sharded over the remainder of the axes. Args: mesh_axis_names: Mesh axis names. @@ -494,10 +497,11 @@ def default_output_partition_spec( Returns: A PartitionSpec. """ - batch_axis_names = tuple(el for el in mesh_axis_names if el != "model") + batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "seq")) tp_axis_name = "model" if "model" in mesh_axis_names else None - # TODO(swiseman): support sequence parallelism. - return PartitionSpec(batch_axis_names, None, tp_axis_name) + seq_axis_name = "seq" if "seq" in mesh_axis_names else None + # Support sequence parallelism by sharding the sequence dimension. + return PartitionSpec(batch_axis_names, seq_axis_name, tp_axis_name) def _at_least_float32(x: Tensor) -> Tensor: diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 0bef3b0dd..99f905699 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -15,6 +15,7 @@ """Tests Mamba/Mamba2 and Jamba implementations.""" +# pytype: disable=module-attr import math from typing import Optional @@ -23,6 +24,7 @@ import numpy as np import pytest import torch +import torch.nn.functional as F_torch from absl.testing import parameterized from jax._src.mesh import ResourceEnv, thread_resources from jax.experimental import mesh_utils @@ -47,6 +49,8 @@ RepeatedSSMLayer, StackedMixedSSMTransformerLayer, StackedSSMLayer, + default_mamba_dim_to_partition_specs, + default_output_partition_spec, ) from axlearn.common.ssm_kernels.ssd_kernels import ssd from axlearn.common.test_utils import TestCase, assert_allclose, set_threefry_partitionable @@ -233,6 +237,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None): hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: + # pylint: disable=not-callable # pad is callable, false positive conv_state = torch.nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) @@ -253,7 +258,8 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None): ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] - discrete_time_step = torch.nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] + # pylint: disable=not-callable # softplus is callable, false positive + discrete_time_step = F_torch.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] @@ -1240,7 +1246,7 @@ def teardown_class(cls): dtype=[jnp.float32, jnp.bfloat16], ) def forward( - self, input_dim: int, state_dim: int, num_heads: int, num_groups: int, dtype: jnp.dtype + self, input_dim: int, *, state_dim: int, num_heads: int, num_groups: int, dtype: jnp.dtype ): mamba2block_cfg = JambaMamba2Block.default_config().set( name="test", @@ -1283,6 +1289,7 @@ def extend_step( self, batch_size: int, input_dim: int, + *, seq_len: int, state_dim: int, num_heads: int, @@ -1352,6 +1359,7 @@ def test_prefill_states( self, batch_size: int, input_dim: int, + *, seq_len: int, state_dim: int, num_heads: int, @@ -1549,3 +1557,68 @@ def _j2t(param): torch_output_np = torch_output.cpu().detach().numpy() assert_allclose(torch_output_np, jax_output_np, atol=1e-2, rtol=1e-2) + + +class PartitionSpecTest(TestCase): + """Tests for Mamba partition spec helper functions.""" + + def test_default_mamba_dim_to_partition_specs_without_seq(self): + """Test partition specs without sequence parallelism.""" + mesh_axis_names = ("data", "fsdp", "model") + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # batch should be sharded over data and fsdp + # sequence should not be sharded (None) + # inner dim should be sharded over model + self.assertEqual(specs["btd"], PartitionSpec(("data", "fsdp"), None, "model")) + self.assertEqual(specs["sd"], PartitionSpec(None, "model")) + self.assertEqual(specs["bts"], PartitionSpec(("data", "fsdp"), None, None)) + self.assertEqual(specs["1d"], PartitionSpec(None, "model")) + + def test_default_mamba_dim_to_partition_specs_with_seq(self): + """Test partition specs with sequence parallelism enabled.""" + mesh_axis_names = ("data", "fsdp", "seq", "model") + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # batch should be sharded over data and fsdp (not seq or model) + # sequence should be sharded over seq + # inner dim should be sharded over model + self.assertEqual(specs["btd"], PartitionSpec(("data", "fsdp"), "seq", "model")) + self.assertEqual(specs["sd"], PartitionSpec(None, "model")) + self.assertEqual(specs["bts"], PartitionSpec(("data", "fsdp"), "seq", None)) + self.assertEqual(specs["1d"], PartitionSpec(None, "model")) + + def test_default_mamba_dim_to_partition_specs_only_seq(self): + """Test partition specs with only sequence axis.""" + mesh_axis_names = ("seq",) + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # Only seq parallelism, no batch or model sharding + self.assertEqual(specs["btd"], PartitionSpec((), "seq", None)) + self.assertEqual(specs["sd"], PartitionSpec(None, None)) + self.assertEqual(specs["bts"], PartitionSpec((), "seq", None)) + self.assertEqual(specs["1d"], PartitionSpec(None, None)) + + def test_default_output_partition_spec_without_seq(self): + """Test output partition spec without sequence parallelism.""" + mesh_axis_names = ("data", "fsdp", "model") + spec = default_output_partition_spec(mesh_axis_names) + + # batch over data and fsdp, no seq sharding, inner dim over model + self.assertEqual(spec, PartitionSpec(("data", "fsdp"), None, "model")) + + def test_default_output_partition_spec_with_seq(self): + """Test output partition spec with sequence parallelism enabled.""" + mesh_axis_names = ("data", "fsdp", "seq", "model") + spec = default_output_partition_spec(mesh_axis_names) + + # batch over data and fsdp, sequence over seq, inner dim over model + self.assertEqual(spec, PartitionSpec(("data", "fsdp"), "seq", "model")) + + def test_default_output_partition_spec_minimal(self): + """Test output partition spec with minimal mesh.""" + mesh_axis_names = ("model",) + spec = default_output_partition_spec(mesh_axis_names) + + # No batch or seq sharding, only model sharding + self.assertEqual(spec, PartitionSpec((), None, "model")) diff --git a/axlearn/vision/coco_utils.py b/axlearn/vision/coco_utils.py index e86395eb7..1d9b04540 100644 --- a/axlearn/vision/coco_utils.py +++ b/axlearn/vision/coco_utils.py @@ -65,6 +65,7 @@ def __init__(self, eval_type="box", annotation_file=None, gt_dataset=None): self.dataset = gt_dataset self.createIndex() + # pylint: disable=invalid-name # matches COCO API naming convention def loadRes(self, predictions: list[dict[str, Any]]) -> coco.COCO: """Loads result file and return a result api object.