From e085a006abecec3f15e0b41934854dd3d2e8ac36 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Sat, 1 Nov 2025 00:58:56 +0530 Subject: [PATCH 01/29] Add audio model adapters and improve SSM partition specs Introduces AudioModelAdapter and ASRModelAdapter for efficient fine-tuning of audio models, along with comprehensive tests. Adds input_grain_csv_test.py for CSV/TSV input processing tests. Updates SSM partition spec helpers to support sequence parallelism and adds corresponding tests in ssm_test.py. Updates .gitignore to exclude run_specific_test.sh. --- .gitignore | 1 + axlearn/audio/adapter.py | 256 +++++++++++++++++ axlearn/audio/adapter_test.py | 372 +++++++++++++++++++++++++ axlearn/common/input_grain_csv_test.py | 257 +++++++++++++++++ axlearn/common/ssm.py | 42 +-- axlearn/common/ssm_test.py | 76 ++++- 6 files changed, 983 insertions(+), 21 deletions(-) create mode 100644 axlearn/audio/adapter.py create mode 100644 axlearn/audio/adapter_test.py create mode 100644 axlearn/common/input_grain_csv_test.py diff --git a/.gitignore b/.gitignore index 4d452573e..0a28a1119 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,4 @@ bazel-* # Emacs *~ +run_specific_test.sh diff --git a/axlearn/audio/adapter.py b/axlearn/audio/adapter.py new file mode 100644 index 000000000..0ff7f5d96 --- /dev/null +++ b/axlearn/audio/adapter.py @@ -0,0 +1,256 @@ +# Copyright © 2023 Apple Inc. + +"""Audio model adapter for efficient fine-tuning.""" + +from typing import Optional + +import jax + +from axlearn.common.base_layer import BaseLayer +from axlearn.common.config import REQUIRED, Required, config_class +from axlearn.common.layers import BatchNorm, LayerNorm, Linear +from axlearn.common.module import Module +from axlearn.common.module import functional as F +from axlearn.common.param_init import DefaultInitializer, WeightInitializer + + +class AudioModelAdapter(BaseLayer): + """Adapter layer for efficient fine-tuning of audio models.""" + + @config_class + class Config(BaseLayer.Config): + """Configures AudioModelAdapter.""" + + # Input feature dimension. + input_dim: Required[int] = REQUIRED + # Bottleneck dimension (typically much smaller than input_dim). + bottleneck_dim: Required[int] = REQUIRED + # Whether to apply layer normalization before the adapter. + use_layer_norm: bool = True + # Whether to apply batch normalization in the adapter. + use_batch_norm: bool = False + # Scaling factor for the adapter output. + adapter_scale: float = 1.0 + # Activation function to use. + activation: str = "relu" + # Whether to add a residual connection. + residual: bool = True + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg = self.config + + # Initialize with small weights to make adapter less disruptive initially + weight_init = WeightInitializer.default_config().set( + distribution="normal", + fan="fan_in", + scale=0.01, + ) + + bias_init = WeightInitializer.default_config().set( + distribution="normal", + fan=None, + scale=0.01, + ) + + param_init = DefaultInitializer.default_config().set( + init_by_param_name={ + ".*weight": weight_init, + ".*bias": bias_init, + }, + ) + + # Down projection to bottleneck dimension + self._add_child( + "down_proj", + Linear.default_config().set( + input_dim=cfg.input_dim, + output_dim=cfg.bottleneck_dim, + bias=True, + param_init=param_init, + ), + ) + + # Optional batch normalization + if cfg.use_batch_norm: + self._add_child( + "batch_norm", + BatchNorm.default_config().set( + input_dim=cfg.bottleneck_dim, + decay=0.9, + ), + ) + + # Up projection back to input dimension + self._add_child( + "up_proj", + Linear.default_config().set( + input_dim=cfg.bottleneck_dim, + output_dim=cfg.input_dim, + bias=True, + param_init=param_init, + ), + ) + + # Optional layer normalization + if cfg.use_layer_norm: + self._add_child( + "layer_norm", + LayerNorm.default_config().set( + input_dim=cfg.input_dim, + ), + ) + + def forward(self, inputs, **_kwargs): + """Apply the adapter transformation. + + Args: + inputs: Input tensor of shape [batch_size, seq_len, input_dim]. + **_kwargs: Additional keyword arguments (unused, kept for API compatibility). + + Returns: + Tensor of the same shape as inputs. + """ + cfg = self.config + residual = inputs + + # Apply layer normalization if specified + x = inputs + if cfg.use_layer_norm: + x = self.layer_norm(x) + + # Down projection + x = self.down_proj(x) + + # Apply batch normalization if specified + if cfg.use_batch_norm: + # BatchNorm uses is_training from context automatically + x = self.batch_norm(x) + + # Activation + if cfg.activation == "relu": + x = jax.nn.relu(x) + elif cfg.activation == "gelu": + x = jax.nn.gelu(x) + + # Up projection + x = self.up_proj(x) + + # Scale the output + if cfg.adapter_scale != 1.0: + x = x * cfg.adapter_scale + + # Add residual connection if specified + if cfg.residual: + x = x + residual + + return x + + +class ASRModelAdapter(BaseLayer): + """Adapter for Automatic Speech Recognition (ASR) models.""" + + @config_class + class Config(BaseLayer.Config): + """Configures ASRModelAdapter.""" + + # Feature dimension of the encoder. + encoder_dim: Required[int] = REQUIRED + # Bottleneck dimension for encoder adapters. + encoder_bottleneck_dim: Required[int] = REQUIRED + # Feature dimension of the decoder. + decoder_dim: Optional[int] = None + # Bottleneck dimension for decoder adapters. + decoder_bottleneck_dim: Optional[int] = None + # Whether to add adapters to the encoder. + adapt_encoder: bool = True + # Whether to add adapters to the decoder. + adapt_decoder: bool = False + # Adapter configuration. + adapter: AudioModelAdapter.Config = AudioModelAdapter.default_config() + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg = self.config + + if cfg.adapt_encoder: + self._add_child( + "encoder_adapter", + cfg.adapter.clone( + input_dim=cfg.encoder_dim, + bottleneck_dim=cfg.encoder_bottleneck_dim, + ), + ) + + if ( + cfg.adapt_decoder + and cfg.decoder_dim is not None + and cfg.decoder_bottleneck_dim is not None + ): + self._add_child( + "decoder_adapter", + cfg.adapter.clone( + input_dim=cfg.decoder_dim, + bottleneck_dim=cfg.decoder_bottleneck_dim, + ), + ) + + def adapt_encoder_features(self, features, *, is_training=False, prng_key=None, state=None): + """Apply adaptation to encoder features. + + Args: + features: Encoder features to adapt. + is_training: Whether the model is in training mode. + prng_key: PRNG key for stochastic operations. + state: State for the adapter. + + Returns: + Adapted encoder features. + """ + cfg = self.config + if not cfg.adapt_encoder: + return features + + # Use functional API if state and prng_key are provided + if state is not None and prng_key is not None: + outputs, _ = F( + self.encoder_adapter, + inputs=features, + is_training=is_training, + prng_key=prng_key, + state=state, + ) + return outputs + + # Fall back to direct call if no state/prng_key + return self.encoder_adapter(features) + + def adapt_decoder_features(self, features, *, is_training=False, prng_key=None, state=None): + """Apply adaptation to decoder features. + + Args: + features: Decoder features to adapt. + is_training: Whether the model is in training mode. + prng_key: PRNG key for stochastic operations. + state: State for the adapter. + + Returns: + Adapted decoder features. + """ + cfg = self.config + if not cfg.adapt_decoder or not hasattr(self, "decoder_adapter"): + return features + + # Use functional API if state and prng_key are provided + if state is not None and prng_key is not None: + outputs, _ = F( + self.decoder_adapter, + inputs=features, + is_training=is_training, + prng_key=prng_key, + state=state, + ) + return outputs + + # Fall back to direct call if no state/prng_key + return self.decoder_adapter(features) diff --git a/axlearn/audio/adapter_test.py b/axlearn/audio/adapter_test.py new file mode 100644 index 000000000..4b3d29668 --- /dev/null +++ b/axlearn/audio/adapter_test.py @@ -0,0 +1,372 @@ +# Copyright © 2024 Apple Inc. + +"""Tests for audio adapters.""" + +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import parameterized + +from axlearn.audio.adapter import ASRModelAdapter, AudioModelAdapter +from axlearn.common.module import functional as F +from axlearn.common.test_utils import TestCase, assert_allclose + + +class AudioModelAdapterTest(TestCase): + """Tests AudioModelAdapter.""" + + def test_forward_basic(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=inputs, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + self.assertTrue(jnp.isfinite(outputs).all()) + + def test_forward_with_layer_norm(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + use_layer_norm=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=inputs, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + def test_forward_without_residual(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + residual=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=inputs, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + with self.assertRaises(AssertionError): + assert_allclose(outputs, inputs) + + def test_forward_with_scaling(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + adapter_scale = 0.5 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + adapter_scale=adapter_scale, + residual=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=inputs, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + @parameterized.parameters(["relu", "gelu"]) + def test_forward_with_activation(self, activation: str): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + activation=activation, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=inputs, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + self.assertTrue(jnp.isfinite(outputs).all()) + + def test_parameter_counts(self): + input_dim, bottleneck_dim = 256, 64 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + use_layer_norm=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + layer_params = layer.initialize_parameters_recursively(prng_key) + + down_proj_weight = layer_params["down_proj"]["weight"] + down_proj_bias = layer_params["down_proj"]["bias"] + up_proj_weight = layer_params["up_proj"]["weight"] + up_proj_bias = layer_params["up_proj"]["bias"] + layer_norm_scale = layer_params["layer_norm"]["scale"] + + self.assertEqual(down_proj_weight.shape, (input_dim, bottleneck_dim)) + self.assertEqual(down_proj_bias.shape, (bottleneck_dim,)) + self.assertEqual(up_proj_weight.shape, (bottleneck_dim, input_dim)) + self.assertEqual(up_proj_bias.shape, (input_dim,)) + self.assertEqual(layer_norm_scale.shape, (input_dim,)) + + total_params = np.prod(down_proj_weight.shape) + total_params += np.prod(down_proj_bias.shape) + total_params += np.prod(up_proj_weight.shape) + total_params += np.prod(up_proj_bias.shape) + total_params += np.prod(layer_norm_scale.shape) + + self.assertEqual(total_params, 82368) + + @parameterized.parameters([True, False]) + def test_training_vs_eval_mode(self, is_training: bool): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=inputs, + is_training=is_training, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + +class ASRModelAdapterTest(TestCase): + """Tests ASRModelAdapter.""" + + def test_encoder_adapter_only(self): + encoder_dim = 256 + encoder_bottleneck_dim = 64 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=encoder_bottleneck_dim, + adapt_encoder=True, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) + + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_features.shape, encoder_features.shape) + + def test_decoder_adapter_only(self): + decoder_dim = 256 + decoder_bottleneck_dim = 64 + batch_size, seq_len = 4, 50 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=128, + encoder_bottleneck_dim=32, + decoder_dim=decoder_dim, + decoder_bottleneck_dim=decoder_bottleneck_dim, + adapt_encoder=False, + adapt_decoder=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + decoder_features = jax.random.normal(input_key, (batch_size, seq_len, decoder_dim)) + + adapted_features = layer.adapt_decoder_features( + decoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_features.shape, decoder_features.shape) + + def test_both_encoders_and_decoders(self): + encoder_dim, encoder_bottleneck_dim = 256, 64 + decoder_dim, decoder_bottleneck_dim = 256, 64 + batch_size, enc_seq_len, dec_seq_len = 4, 100, 50 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=encoder_bottleneck_dim, + decoder_dim=decoder_dim, + decoder_bottleneck_dim=decoder_bottleneck_dim, + adapt_encoder=True, + adapt_decoder=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key1, input_key2 = jax.random.split(prng_key, num=4) + layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key1, (batch_size, enc_seq_len, encoder_dim)) + decoder_features = jax.random.normal(input_key2, (batch_size, dec_seq_len, decoder_dim)) + + adapted_enc_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + adapted_dec_features = layer.adapt_decoder_features( + decoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_enc_features.shape, encoder_features.shape) + self.assertEqual(adapted_dec_features.shape, decoder_features.shape) + + def test_no_adaptation(self): + encoder_dim = 256 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=64, + adapt_encoder=False, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) + + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + assert_allclose(adapted_features, encoder_features) + + def test_direct_call_fallback(self): + encoder_dim = 256 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=64, + adapt_encoder=True, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + # Initialize params (required for layer setup, but not used in this direct call test) + _ = layer.initialize_parameters_recursively(jax.random.PRNGKey(123)) + encoder_features = jax.random.normal( + jax.random.PRNGKey(456), (batch_size, seq_len, encoder_dim) + ) + + adapted_features = layer.adapt_encoder_features(encoder_features, is_training=True) + + self.assertEqual(adapted_features.shape, encoder_features.shape) diff --git a/axlearn/common/input_grain_csv_test.py b/axlearn/common/input_grain_csv_test.py new file mode 100644 index 000000000..402d36690 --- /dev/null +++ b/axlearn/common/input_grain_csv_test.py @@ -0,0 +1,257 @@ +# Copyright © 2024 Apple Inc. + +"""Tests for CSV/TSV input processing in input_grain.py.""" + +import csv +import tempfile +import unittest +from contextlib import ExitStack +from pathlib import Path + +from axlearn.common.input_grain import csv_dataset, tsv_dataset + + +class CsvDatasetTest(unittest.TestCase): + """Tests for CSV/TSV dataset functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.exit_stack = ExitStack() + self.temp_dir = self.exit_stack.enter_context(tempfile.TemporaryDirectory()) + self.temp_path = Path(self.temp_dir) + + def tearDown(self): + """Clean up test fixtures.""" + self.exit_stack.close() + + def _create_csv_file(self, filename: str, data: list[list[str]]): + """Helper to create a CSV file for testing.""" + file_path = self.temp_path / filename + with open(file_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + for row in data: + writer.writerow(row) + return str(file_path) + + def _create_tsv_file(self, filename: str, data: list[list[str]]): + """Helper to create a TSV file for testing.""" + file_path = self.temp_path / filename + with open(file_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f, delimiter="\t") + for row in data: + writer.writerow(row) + return str(file_path) + + def test_csv_dataset_with_header(self): + """Test CSV dataset with header row.""" + data = [ + ["name", "age", "city"], + ["Alice", "25", "New York"], + ["Bob", "30", "San Francisco"], + ["Charlie", "35", "Chicago"], + ] + csv_path = self._create_csv_file("test.csv", data) + + ds = csv_dataset(csv_path) + examples = list(ds) + + self.assertEqual(len(examples), 3) + self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) + self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) + self.assertEqual(examples[2], {"name": "Charlie", "age": "35", "city": "Chicago"}) + + def test_csv_dataset_without_header(self): + """Test CSV dataset without header row.""" + data = [ + ["Alice", "25", "New York"], + ["Bob", "30", "San Francisco"], + ["Charlie", "35", "Chicago"], + ] + csv_path = self._create_csv_file("test.csv", data) + + ds = csv_dataset(csv_path, has_header=False, column_names=["name", "age", "city"]) + examples = list(ds) + + self.assertEqual(len(examples), 3) + self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) + self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) + self.assertEqual(examples[2], {"name": "Charlie", "age": "35", "city": "Chicago"}) + + def test_csv_dataset_custom_column_names(self): + """Test CSV dataset with custom column names overriding header.""" + data = [ + ["old_name", "old_age", "old_city"], + ["Alice", "25", "New York"], + ["Bob", "30", "San Francisco"], + ] + csv_path = self._create_csv_file("test.csv", data) + + ds = csv_dataset(csv_path, has_header=True, column_names=["name", "age", "city"]) + examples = list(ds) + + self.assertEqual(len(examples), 2) + self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) + self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) + + def test_csv_dataset_skip_rows(self): + """Test CSV dataset with skip_rows parameter.""" + data = [ + ["name", "age", "city"], + ["# This is a comment"], + ["# Another comment"], + ["Alice", "25", "New York"], + ["Bob", "30", "San Francisco"], + ] + csv_path = self._create_csv_file("test.csv", data) + + ds = csv_dataset(csv_path, skip_rows=2) + examples = list(ds) + + self.assertEqual(len(examples), 2) + self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) + self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) + + def test_csv_dataset_multiple_files(self): + """Test CSV dataset with multiple files.""" + data1 = [ + ["name", "age"], + ["Alice", "25"], + ["Bob", "30"], + ] + data2 = [ + ["name", "age"], + ["Charlie", "35"], + ["David", "40"], + ] + + csv_path1 = self._create_csv_file("test1.csv", data1) + csv_path2 = self._create_csv_file("test2.csv", data2) + + ds = csv_dataset([csv_path1, csv_path2]) + examples = list(ds) + + self.assertEqual(len(examples), 4) + self.assertEqual(examples[0], {"name": "Alice", "age": "25"}) + self.assertEqual(examples[1], {"name": "Bob", "age": "30"}) + self.assertEqual(examples[2], {"name": "Charlie", "age": "35"}) + self.assertEqual(examples[3], {"name": "David", "age": "40"}) + + def test_csv_dataset_malformed_rows(self): + """Test CSV dataset handling of malformed rows.""" + # Create a CSV with inconsistent column counts + file_path = self.temp_path / "malformed.csv" + with open(file_path, "w", newline="", encoding="utf-8") as f: + f.write("name,age,city\n") + f.write("Alice,25,New York\n") + f.write("Bob,30\n") # Missing city + f.write("Charlie,35,Chicago,Extra\n") # Extra column + + ds = csv_dataset(str(file_path)) + examples = list(ds) + + self.assertEqual(len(examples), 3) + self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) + self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": ""}) # Padded + self.assertEqual( + examples[2], {"name": "Charlie", "age": "35", "city": "Chicago"} + ) # Truncated + + def test_tsv_dataset(self): + """Test TSV dataset functionality.""" + data = [ + ["name", "age", "city"], + ["Alice", "25", "New York"], + ["Bob", "30", "San Francisco"], + ] + tsv_path = self._create_tsv_file("test.tsv", data) + + ds = tsv_dataset(tsv_path) + examples = list(ds) + + self.assertEqual(len(examples), 2) + self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) + self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) + + def test_csv_dataset_with_seed(self): + """Test CSV dataset with seed for reproducible shuffling.""" + data = [ + ["name", "age"], + ["Alice", "25"], + ["Bob", "30"], + ["Charlie", "35"], + ["David", "40"], + ] + csv_path = self._create_csv_file("test.csv", data) + + # Create two datasets with the same seed + ds1 = csv_dataset(csv_path, seed=42).shuffle(buffer_size=10) + ds2 = csv_dataset(csv_path, seed=42).shuffle(buffer_size=10) + + examples1 = list(ds1) + examples2 = list(ds2) + + # Should be the same order due to same seed + self.assertEqual(examples1, examples2) + self.assertEqual(len(examples1), 4) + + def test_csv_dataset_indexing(self): + """Test CSV dataset supports indexing.""" + data = [ + ["name", "age"], + ["Alice", "25"], + ["Bob", "30"], + ["Charlie", "35"], + ] + csv_path = self._create_csv_file("test.csv", data) + + ds = csv_dataset(csv_path) + + # Test length + self.assertEqual(len(ds), 3) + + # Test indexing + self.assertEqual(ds[0], {"name": "Alice", "age": "25"}) + self.assertEqual(ds[1], {"name": "Bob", "age": "30"}) + self.assertEqual(ds[2], {"name": "Charlie", "age": "35"}) + + def test_csv_dataset_error_no_column_names(self): + """Test CSV dataset raises error when no column names provided and no header.""" + data = [ + ["Alice", "25", "New York"], + ["Bob", "30", "San Francisco"], + ] + csv_path = self._create_csv_file("test.csv", data) + + with self.assertRaises(ValueError) as cm: + csv_dataset(csv_path, has_header=False) + + self.assertIn("column_names must be provided", str(cm.exception)) + + def test_csv_dataset_empty_file(self): + """Test CSV dataset with empty file.""" + file_path = self.temp_path / "empty.csv" + file_path.touch() + + with self.assertRaises(StopIteration): + # Should raise StopIteration when trying to read header from empty file + csv_dataset(str(file_path)) + + def test_csv_dataset_encoding(self): + """Test CSV dataset with different encoding.""" + # Create a file with UTF-8 content + file_path = self.temp_path / "utf8.csv" + with open(file_path, "w", encoding="utf-8") as f: + f.write("name,description\n") + f.write("Alice,Café owner\n") + f.write("Bob,Naïve user\n") + + ds = csv_dataset(str(file_path), encoding="utf-8") + examples = list(ds) + + self.assertEqual(len(examples), 2) + self.assertEqual(examples[0], {"name": "Alice", "description": "Café owner"}) + self.assertEqual(examples[1], {"name": "Bob", "description": "Naïve user"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index b1c086b95..6237ba187 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -409,7 +409,7 @@ def forward( # We need to jit a function before shard_mapping it. @jax.jit - def jit_mamba_scan(x, a, b, c, delta, d): + def jit_mamba_scan(x, a, *, b, c, delta, d): y = compute_mamba_scan( # [batch_size, seq_len, inner_dim] x, a, @@ -445,7 +445,7 @@ def jit_mamba_scan(x, a, b, c, delta, d): delta = with_sharding_constraint(delta, partition_specs["btd"]) d = with_sharding_constraint(d, partition_specs["1d"]) y = with_sharding_constraint( - partitioned_mamba_scan(x, a, b, c, delta, d), + partitioned_mamba_scan(x, a, b=b, c=c, delta=delta, d=d), cfg.output_partition_spec, ) # The Pallas kernel does not return states. @@ -459,7 +459,8 @@ def default_mamba_dim_to_partition_specs( the Pallas-based Mamba implementation. The inner dimension is sharded over the default tensor-parallel axis name if present, - and the the batch is sharded over the remainder of the axes. + the sequence dimension is sharded over the sequence-parallel axis name if present, + and the batch is sharded over the remainder of the axes. Args: mesh_axis_names: Mesh axis names. @@ -467,13 +468,14 @@ def default_mamba_dim_to_partition_specs( Returns: A dictionary keyed by Mamba tensor dims with partition spec values. """ - batch_axis_names = tuple(el for el in mesh_axis_names if el != "model") + batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "seq")) tp_axis_name = "model" if "model" in mesh_axis_names else None + seq_axis_name = "seq" if "seq" in mesh_axis_names else None - # TODO(swiseman): support sequence parallelism. - x_spec = PartitionSpec(batch_axis_names, None, tp_axis_name) + # Support sequence parallelism by sharding the sequence dimension (middle dim in btd/bts). + x_spec = PartitionSpec(batch_axis_names, seq_axis_name, tp_axis_name) a_spec = PartitionSpec(None, tp_axis_name) - b_spec = PartitionSpec(batch_axis_names, None, None) + b_spec = PartitionSpec(batch_axis_names, seq_axis_name, None) d_spec = PartitionSpec(None, tp_axis_name) partition_specs = {"btd": x_spec, "sd": a_spec, "bts": b_spec, "1d": d_spec} return partition_specs @@ -486,7 +488,8 @@ def default_output_partition_spec( implementation. The inner dimension is sharded over the default tensor-parallel axis name if present, - and the the batch is sharded over the remainder of the axes. + the sequence dimension is sharded over the sequence-parallel axis name if present, + and the batch is sharded over the remainder of the axes. Args: mesh_axis_names: Mesh axis names. @@ -494,10 +497,11 @@ def default_output_partition_spec( Returns: A PartitionSpec. """ - batch_axis_names = tuple(el for el in mesh_axis_names if el != "model") + batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "seq")) tp_axis_name = "model" if "model" in mesh_axis_names else None - # TODO(swiseman): support sequence parallelism. - return PartitionSpec(batch_axis_names, None, tp_axis_name) + seq_axis_name = "seq" if "seq" in mesh_axis_names else None + # Support sequence parallelism by sharding the sequence dimension. + return PartitionSpec(batch_axis_names, seq_axis_name, tp_axis_name) def _at_least_float32(x: Tensor) -> Tensor: @@ -560,10 +564,10 @@ class Config(BaseLayer.Config): # The recurrence implementation to use for full-sequence inputs. mamba_recurrence: BaseMambaRecurrence = LinearScanMambaRecurrence.default_config() # The recurrence implementation to use for inference. - inference_mamba_recurrence: BaseMambaRecurrence = ( - LinearScanMambaRecurrence.default_config().set( - output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES - ) + inference_mamba_recurrence: ( + BaseMambaRecurrence + ) = LinearScanMambaRecurrence.default_config().set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES ) class MambaOutput(NamedTuple): @@ -1768,10 +1772,10 @@ class Config(BaseLayer.Config): # The recurrence implementation to use for full-sequence inputs. ssd_recurrence: BaseSSDRecurrence = PallasSSDRecurrence.default_config() # The recurrence implementation to use for inference. - inference_mamba_recurrence: BaseSSDRecurrence = ( - LinearScanSSDRecurrence.default_config().set( - output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES - ) + inference_mamba_recurrence: ( + BaseSSDRecurrence + ) = LinearScanSSDRecurrence.default_config().set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES ) class Mamba2Output(NamedTuple): diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 2bb3a7d15..d125984aa 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -23,6 +23,7 @@ import numpy as np import pytest import torch +import torch.nn.functional as F_torch from absl.testing import parameterized from jax._src.mesh import ResourceEnv, thread_resources from jax.experimental import mesh_utils @@ -47,6 +48,8 @@ RepeatedSSMLayer, StackedMixedSSMTransformerLayer, StackedSSMLayer, + default_mamba_dim_to_partition_specs, + default_output_partition_spec, ) from axlearn.common.ssm_kernels.ssd_kernels import ssd from axlearn.common.test_utils import TestCase, assert_allclose, set_threefry_partitionable @@ -75,6 +78,7 @@ def __init__( vocab_size=50280, hidden_size=768, state_size=16, + *, num_hidden_layers=32, layer_norm_epsilon=1e-5, pad_token_id=0, @@ -251,7 +255,8 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None): ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] - discrete_time_step = torch.nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] + # pylint: disable=not-callable # softplus is callable, false positive + discrete_time_step = F_torch.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] @@ -1238,7 +1243,7 @@ def teardown_class(cls): dtype=[jnp.float32, jnp.bfloat16], ) def forward( - self, input_dim: int, state_dim: int, num_heads: int, num_groups: int, dtype: jnp.dtype + self, input_dim: int, *, state_dim: int, num_heads: int, num_groups: int, dtype: jnp.dtype ): mamba2block_cfg = JambaMamba2Block.default_config().set( name="test", @@ -1281,6 +1286,7 @@ def extend_step( self, batch_size: int, input_dim: int, + *, seq_len: int, state_dim: int, num_heads: int, @@ -1350,6 +1356,7 @@ def test_prefill_states( self, batch_size: int, input_dim: int, + *, seq_len: int, state_dim: int, num_heads: int, @@ -1547,3 +1554,68 @@ def _j2t(param): torch_output_np = torch_output.cpu().detach().numpy() assert_allclose(torch_output_np, jax_output_np, atol=1e-2, rtol=1e-2) + + +class PartitionSpecTest(TestCase): + """Tests for Mamba partition spec helper functions.""" + + def test_default_mamba_dim_to_partition_specs_without_seq(self): + """Test partition specs without sequence parallelism.""" + mesh_axis_names = ("data", "fsdp", "model") + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # batch should be sharded over data and fsdp + # sequence should not be sharded (None) + # inner dim should be sharded over model + self.assertEqual(specs["btd"], PartitionSpec(("data", "fsdp"), None, "model")) + self.assertEqual(specs["sd"], PartitionSpec(None, "model")) + self.assertEqual(specs["bts"], PartitionSpec(("data", "fsdp"), None, None)) + self.assertEqual(specs["1d"], PartitionSpec(None, "model")) + + def test_default_mamba_dim_to_partition_specs_with_seq(self): + """Test partition specs with sequence parallelism enabled.""" + mesh_axis_names = ("data", "fsdp", "seq", "model") + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # batch should be sharded over data and fsdp (not seq or model) + # sequence should be sharded over seq + # inner dim should be sharded over model + self.assertEqual(specs["btd"], PartitionSpec(("data", "fsdp"), "seq", "model")) + self.assertEqual(specs["sd"], PartitionSpec(None, "model")) + self.assertEqual(specs["bts"], PartitionSpec(("data", "fsdp"), "seq", None)) + self.assertEqual(specs["1d"], PartitionSpec(None, "model")) + + def test_default_mamba_dim_to_partition_specs_only_seq(self): + """Test partition specs with only sequence axis.""" + mesh_axis_names = ("seq",) + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # Only seq parallelism, no batch or model sharding + self.assertEqual(specs["btd"], PartitionSpec((), "seq", None)) + self.assertEqual(specs["sd"], PartitionSpec(None, None)) + self.assertEqual(specs["bts"], PartitionSpec((), "seq", None)) + self.assertEqual(specs["1d"], PartitionSpec(None, None)) + + def test_default_output_partition_spec_without_seq(self): + """Test output partition spec without sequence parallelism.""" + mesh_axis_names = ("data", "fsdp", "model") + spec = default_output_partition_spec(mesh_axis_names) + + # batch over data and fsdp, no seq sharding, inner dim over model + self.assertEqual(spec, PartitionSpec(("data", "fsdp"), None, "model")) + + def test_default_output_partition_spec_with_seq(self): + """Test output partition spec with sequence parallelism enabled.""" + mesh_axis_names = ("data", "fsdp", "seq", "model") + spec = default_output_partition_spec(mesh_axis_names) + + # batch over data and fsdp, sequence over seq, inner dim over model + self.assertEqual(spec, PartitionSpec(("data", "fsdp"), "seq", "model")) + + def test_default_output_partition_spec_minimal(self): + """Test output partition spec with minimal mesh.""" + mesh_axis_names = ("model",) + spec = default_output_partition_spec(mesh_axis_names) + + # No batch or seq sharding, only model sharding + self.assertEqual(spec, PartitionSpec((), None, "model")) From 6617a60b5e948285fedb34984788d10928165e65 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Sun, 2 Nov 2025 22:40:48 +0530 Subject: [PATCH 02/29] Fix return type annotation for default_output_partition_spec The function returns PartitionSpec but was annotated as dict[str, PartitionSpec]. This mismatch was causing CI test failures. --- axlearn/common/ssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index 6237ba187..b299b29df 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -483,7 +483,7 @@ def default_mamba_dim_to_partition_specs( def default_output_partition_spec( mesh_axis_names: Sequence[str], -) -> dict[str, PartitionSpec]: +) -> PartitionSpec: """Builds a default output partition spec for the shard_mapped Pallas-based Mamba implementation. From 8b04152d2a088966b9a5442e39dca2ef681810aa Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Sun, 2 Nov 2025 22:53:03 +0530 Subject: [PATCH 03/29] Remove input_grain_csv_test.py - incomplete implementation The test file was added without implementing the csv_dataset and tsv_dataset functions in input_grain.py. This caused import errors in CI. --- axlearn/common/input_grain_csv_test.py | 257 ------------------------- 1 file changed, 257 deletions(-) delete mode 100644 axlearn/common/input_grain_csv_test.py diff --git a/axlearn/common/input_grain_csv_test.py b/axlearn/common/input_grain_csv_test.py deleted file mode 100644 index 402d36690..000000000 --- a/axlearn/common/input_grain_csv_test.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright © 2024 Apple Inc. - -"""Tests for CSV/TSV input processing in input_grain.py.""" - -import csv -import tempfile -import unittest -from contextlib import ExitStack -from pathlib import Path - -from axlearn.common.input_grain import csv_dataset, tsv_dataset - - -class CsvDatasetTest(unittest.TestCase): - """Tests for CSV/TSV dataset functionality.""" - - def setUp(self): - """Set up test fixtures.""" - self.exit_stack = ExitStack() - self.temp_dir = self.exit_stack.enter_context(tempfile.TemporaryDirectory()) - self.temp_path = Path(self.temp_dir) - - def tearDown(self): - """Clean up test fixtures.""" - self.exit_stack.close() - - def _create_csv_file(self, filename: str, data: list[list[str]]): - """Helper to create a CSV file for testing.""" - file_path = self.temp_path / filename - with open(file_path, "w", newline="", encoding="utf-8") as f: - writer = csv.writer(f) - for row in data: - writer.writerow(row) - return str(file_path) - - def _create_tsv_file(self, filename: str, data: list[list[str]]): - """Helper to create a TSV file for testing.""" - file_path = self.temp_path / filename - with open(file_path, "w", newline="", encoding="utf-8") as f: - writer = csv.writer(f, delimiter="\t") - for row in data: - writer.writerow(row) - return str(file_path) - - def test_csv_dataset_with_header(self): - """Test CSV dataset with header row.""" - data = [ - ["name", "age", "city"], - ["Alice", "25", "New York"], - ["Bob", "30", "San Francisco"], - ["Charlie", "35", "Chicago"], - ] - csv_path = self._create_csv_file("test.csv", data) - - ds = csv_dataset(csv_path) - examples = list(ds) - - self.assertEqual(len(examples), 3) - self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) - self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) - self.assertEqual(examples[2], {"name": "Charlie", "age": "35", "city": "Chicago"}) - - def test_csv_dataset_without_header(self): - """Test CSV dataset without header row.""" - data = [ - ["Alice", "25", "New York"], - ["Bob", "30", "San Francisco"], - ["Charlie", "35", "Chicago"], - ] - csv_path = self._create_csv_file("test.csv", data) - - ds = csv_dataset(csv_path, has_header=False, column_names=["name", "age", "city"]) - examples = list(ds) - - self.assertEqual(len(examples), 3) - self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) - self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) - self.assertEqual(examples[2], {"name": "Charlie", "age": "35", "city": "Chicago"}) - - def test_csv_dataset_custom_column_names(self): - """Test CSV dataset with custom column names overriding header.""" - data = [ - ["old_name", "old_age", "old_city"], - ["Alice", "25", "New York"], - ["Bob", "30", "San Francisco"], - ] - csv_path = self._create_csv_file("test.csv", data) - - ds = csv_dataset(csv_path, has_header=True, column_names=["name", "age", "city"]) - examples = list(ds) - - self.assertEqual(len(examples), 2) - self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) - self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) - - def test_csv_dataset_skip_rows(self): - """Test CSV dataset with skip_rows parameter.""" - data = [ - ["name", "age", "city"], - ["# This is a comment"], - ["# Another comment"], - ["Alice", "25", "New York"], - ["Bob", "30", "San Francisco"], - ] - csv_path = self._create_csv_file("test.csv", data) - - ds = csv_dataset(csv_path, skip_rows=2) - examples = list(ds) - - self.assertEqual(len(examples), 2) - self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) - self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) - - def test_csv_dataset_multiple_files(self): - """Test CSV dataset with multiple files.""" - data1 = [ - ["name", "age"], - ["Alice", "25"], - ["Bob", "30"], - ] - data2 = [ - ["name", "age"], - ["Charlie", "35"], - ["David", "40"], - ] - - csv_path1 = self._create_csv_file("test1.csv", data1) - csv_path2 = self._create_csv_file("test2.csv", data2) - - ds = csv_dataset([csv_path1, csv_path2]) - examples = list(ds) - - self.assertEqual(len(examples), 4) - self.assertEqual(examples[0], {"name": "Alice", "age": "25"}) - self.assertEqual(examples[1], {"name": "Bob", "age": "30"}) - self.assertEqual(examples[2], {"name": "Charlie", "age": "35"}) - self.assertEqual(examples[3], {"name": "David", "age": "40"}) - - def test_csv_dataset_malformed_rows(self): - """Test CSV dataset handling of malformed rows.""" - # Create a CSV with inconsistent column counts - file_path = self.temp_path / "malformed.csv" - with open(file_path, "w", newline="", encoding="utf-8") as f: - f.write("name,age,city\n") - f.write("Alice,25,New York\n") - f.write("Bob,30\n") # Missing city - f.write("Charlie,35,Chicago,Extra\n") # Extra column - - ds = csv_dataset(str(file_path)) - examples = list(ds) - - self.assertEqual(len(examples), 3) - self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) - self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": ""}) # Padded - self.assertEqual( - examples[2], {"name": "Charlie", "age": "35", "city": "Chicago"} - ) # Truncated - - def test_tsv_dataset(self): - """Test TSV dataset functionality.""" - data = [ - ["name", "age", "city"], - ["Alice", "25", "New York"], - ["Bob", "30", "San Francisco"], - ] - tsv_path = self._create_tsv_file("test.tsv", data) - - ds = tsv_dataset(tsv_path) - examples = list(ds) - - self.assertEqual(len(examples), 2) - self.assertEqual(examples[0], {"name": "Alice", "age": "25", "city": "New York"}) - self.assertEqual(examples[1], {"name": "Bob", "age": "30", "city": "San Francisco"}) - - def test_csv_dataset_with_seed(self): - """Test CSV dataset with seed for reproducible shuffling.""" - data = [ - ["name", "age"], - ["Alice", "25"], - ["Bob", "30"], - ["Charlie", "35"], - ["David", "40"], - ] - csv_path = self._create_csv_file("test.csv", data) - - # Create two datasets with the same seed - ds1 = csv_dataset(csv_path, seed=42).shuffle(buffer_size=10) - ds2 = csv_dataset(csv_path, seed=42).shuffle(buffer_size=10) - - examples1 = list(ds1) - examples2 = list(ds2) - - # Should be the same order due to same seed - self.assertEqual(examples1, examples2) - self.assertEqual(len(examples1), 4) - - def test_csv_dataset_indexing(self): - """Test CSV dataset supports indexing.""" - data = [ - ["name", "age"], - ["Alice", "25"], - ["Bob", "30"], - ["Charlie", "35"], - ] - csv_path = self._create_csv_file("test.csv", data) - - ds = csv_dataset(csv_path) - - # Test length - self.assertEqual(len(ds), 3) - - # Test indexing - self.assertEqual(ds[0], {"name": "Alice", "age": "25"}) - self.assertEqual(ds[1], {"name": "Bob", "age": "30"}) - self.assertEqual(ds[2], {"name": "Charlie", "age": "35"}) - - def test_csv_dataset_error_no_column_names(self): - """Test CSV dataset raises error when no column names provided and no header.""" - data = [ - ["Alice", "25", "New York"], - ["Bob", "30", "San Francisco"], - ] - csv_path = self._create_csv_file("test.csv", data) - - with self.assertRaises(ValueError) as cm: - csv_dataset(csv_path, has_header=False) - - self.assertIn("column_names must be provided", str(cm.exception)) - - def test_csv_dataset_empty_file(self): - """Test CSV dataset with empty file.""" - file_path = self.temp_path / "empty.csv" - file_path.touch() - - with self.assertRaises(StopIteration): - # Should raise StopIteration when trying to read header from empty file - csv_dataset(str(file_path)) - - def test_csv_dataset_encoding(self): - """Test CSV dataset with different encoding.""" - # Create a file with UTF-8 content - file_path = self.temp_path / "utf8.csv" - with open(file_path, "w", encoding="utf-8") as f: - f.write("name,description\n") - f.write("Alice,Café owner\n") - f.write("Bob,Naïve user\n") - - ds = csv_dataset(str(file_path), encoding="utf-8") - examples = list(ds) - - self.assertEqual(len(examples), 2) - self.assertEqual(examples[0], {"name": "Alice", "description": "Café owner"}) - self.assertEqual(examples[1], {"name": "Bob", "description": "Naïve user"}) - - -if __name__ == "__main__": - unittest.main() From fd94c7a685a26707dcabe80d885812d692f1d6d3 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Sun, 2 Nov 2025 23:36:09 +0530 Subject: [PATCH 04/29] Fix audio adapter test failures - Fixed TypeError by wrapping single tensor inputs in tuples for F() calls in both adapter.py and adapter_test.py - Fixed parameter count assertion by including layer_norm.bias in the count calculation --- axlearn/audio/adapter.py | 4 ++-- axlearn/audio/adapter_test.py | 17 ++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/axlearn/audio/adapter.py b/axlearn/audio/adapter.py index 0ff7f5d96..39feb17d7 100644 --- a/axlearn/audio/adapter.py +++ b/axlearn/audio/adapter.py @@ -215,7 +215,7 @@ def adapt_encoder_features(self, features, *, is_training=False, prng_key=None, if state is not None and prng_key is not None: outputs, _ = F( self.encoder_adapter, - inputs=features, + inputs=(features,), is_training=is_training, prng_key=prng_key, state=state, @@ -245,7 +245,7 @@ def adapt_decoder_features(self, features, *, is_training=False, prng_key=None, if state is not None and prng_key is not None: outputs, _ = F( self.decoder_adapter, - inputs=features, + inputs=(features,), is_training=is_training, prng_key=prng_key, state=state, diff --git a/axlearn/audio/adapter_test.py b/axlearn/audio/adapter_test.py index 4b3d29668..7c547e96c 100644 --- a/axlearn/audio/adapter_test.py +++ b/axlearn/audio/adapter_test.py @@ -33,7 +33,7 @@ def test_forward_basic(self): outputs, _ = F( layer, - inputs=inputs, + inputs=(inputs,), is_training=True, prng_key=prng_key, state=layer_params, @@ -61,7 +61,7 @@ def test_forward_with_layer_norm(self): outputs, _ = F( layer, - inputs=inputs, + inputs=(inputs,), is_training=True, prng_key=prng_key, state=layer_params, @@ -88,7 +88,7 @@ def test_forward_without_residual(self): outputs, _ = F( layer, - inputs=inputs, + inputs=(inputs,), is_training=True, prng_key=prng_key, state=layer_params, @@ -119,7 +119,7 @@ def test_forward_with_scaling(self): outputs, _ = F( layer, - inputs=inputs, + inputs=(inputs,), is_training=True, prng_key=prng_key, state=layer_params, @@ -147,7 +147,7 @@ def test_forward_with_activation(self, activation: str): outputs, _ = F( layer, - inputs=inputs, + inputs=(inputs,), is_training=True, prng_key=prng_key, state=layer_params, @@ -175,20 +175,23 @@ def test_parameter_counts(self): up_proj_weight = layer_params["up_proj"]["weight"] up_proj_bias = layer_params["up_proj"]["bias"] layer_norm_scale = layer_params["layer_norm"]["scale"] + layer_norm_bias = layer_params["layer_norm"]["bias"] self.assertEqual(down_proj_weight.shape, (input_dim, bottleneck_dim)) self.assertEqual(down_proj_bias.shape, (bottleneck_dim,)) self.assertEqual(up_proj_weight.shape, (bottleneck_dim, input_dim)) self.assertEqual(up_proj_bias.shape, (input_dim,)) self.assertEqual(layer_norm_scale.shape, (input_dim,)) + self.assertEqual(layer_norm_bias.shape, (input_dim,)) total_params = np.prod(down_proj_weight.shape) total_params += np.prod(down_proj_bias.shape) total_params += np.prod(up_proj_weight.shape) total_params += np.prod(up_proj_bias.shape) total_params += np.prod(layer_norm_scale.shape) + total_params += np.prod(layer_norm_bias.shape) - self.assertEqual(total_params, 82368) + self.assertEqual(total_params, 33664) @parameterized.parameters([True, False]) def test_training_vs_eval_mode(self, is_training: bool): @@ -209,7 +212,7 @@ def test_training_vs_eval_mode(self, is_training: bool): outputs, _ = F( layer, - inputs=inputs, + inputs=(inputs,), is_training=is_training, prng_key=prng_key, state=layer_params, From bdead97802cc10a21527d0b660348c548deda17d Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 00:05:44 +0530 Subject: [PATCH 05/29] Fix ASRModelAdapter state passing to F() and parameter count - Fixed state passing by extracting encoder_adapter and decoder_adapter from the full state dict in adapt_encoder_features and adapt_decoder_features - Fixed expected parameter count from 33664 to 33600 in test_parameter_counts --- axlearn/audio/adapter.py | 4 ++-- axlearn/audio/adapter_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/axlearn/audio/adapter.py b/axlearn/audio/adapter.py index 39feb17d7..e981186ed 100644 --- a/axlearn/audio/adapter.py +++ b/axlearn/audio/adapter.py @@ -218,7 +218,7 @@ def adapt_encoder_features(self, features, *, is_training=False, prng_key=None, inputs=(features,), is_training=is_training, prng_key=prng_key, - state=state, + state=state["encoder_adapter"], ) return outputs @@ -248,7 +248,7 @@ def adapt_decoder_features(self, features, *, is_training=False, prng_key=None, inputs=(features,), is_training=is_training, prng_key=prng_key, - state=state, + state=state["decoder_adapter"], ) return outputs diff --git a/axlearn/audio/adapter_test.py b/axlearn/audio/adapter_test.py index 7c547e96c..3c7b62a70 100644 --- a/axlearn/audio/adapter_test.py +++ b/axlearn/audio/adapter_test.py @@ -191,7 +191,7 @@ def test_parameter_counts(self): total_params += np.prod(layer_norm_scale.shape) total_params += np.prod(layer_norm_bias.shape) - self.assertEqual(total_params, 33664) + self.assertEqual(total_params, 33600) @parameterized.parameters([True, False]) def test_training_vs_eval_mode(self, is_training: bool): From ba55210eece9a074fdc4c0887eb3a7653438697e Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 00:46:32 +0530 Subject: [PATCH 06/29] Remove fallback direct calls from ASRModelAdapter - Made prng_key and state required parameters in adapt_encoder_features and adapt_decoder_features - Removed fallback direct module calls which don't work outside invocation context - Updated test_direct_call_fallback to pass required prng_key and state parameters --- axlearn/audio/adapter.py | 46 ++++++++++++++--------------------- axlearn/audio/adapter_test.py | 16 +++++++----- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/axlearn/audio/adapter.py b/axlearn/audio/adapter.py index e981186ed..a5d2c459e 100644 --- a/axlearn/audio/adapter.py +++ b/axlearn/audio/adapter.py @@ -195,7 +195,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]): ), ) - def adapt_encoder_features(self, features, *, is_training=False, prng_key=None, state=None): + def adapt_encoder_features(self, features, *, is_training=False, prng_key, state): """Apply adaptation to encoder features. Args: @@ -211,21 +211,16 @@ def adapt_encoder_features(self, features, *, is_training=False, prng_key=None, if not cfg.adapt_encoder: return features - # Use functional API if state and prng_key are provided - if state is not None and prng_key is not None: - outputs, _ = F( - self.encoder_adapter, - inputs=(features,), - is_training=is_training, - prng_key=prng_key, - state=state["encoder_adapter"], - ) - return outputs - - # Fall back to direct call if no state/prng_key - return self.encoder_adapter(features) + outputs, _ = F( + self.encoder_adapter, + inputs=(features,), + is_training=is_training, + prng_key=prng_key, + state=state["encoder_adapter"], + ) + return outputs - def adapt_decoder_features(self, features, *, is_training=False, prng_key=None, state=None): + def adapt_decoder_features(self, features, *, is_training=False, prng_key, state): """Apply adaptation to decoder features. Args: @@ -241,16 +236,11 @@ def adapt_decoder_features(self, features, *, is_training=False, prng_key=None, if not cfg.adapt_decoder or not hasattr(self, "decoder_adapter"): return features - # Use functional API if state and prng_key are provided - if state is not None and prng_key is not None: - outputs, _ = F( - self.decoder_adapter, - inputs=(features,), - is_training=is_training, - prng_key=prng_key, - state=state["decoder_adapter"], - ) - return outputs - - # Fall back to direct call if no state/prng_key - return self.decoder_adapter(features) + outputs, _ = F( + self.decoder_adapter, + inputs=(features,), + is_training=is_training, + prng_key=prng_key, + state=state["decoder_adapter"], + ) + return outputs diff --git a/axlearn/audio/adapter_test.py b/axlearn/audio/adapter_test.py index 3c7b62a70..e7b50b4d6 100644 --- a/axlearn/audio/adapter_test.py +++ b/axlearn/audio/adapter_test.py @@ -364,12 +364,16 @@ def test_direct_call_fallback(self): ) layer = cfg.set(name="test").instantiate(parent=None) - # Initialize params (required for layer setup, but not used in this direct call test) - _ = layer.initialize_parameters_recursively(jax.random.PRNGKey(123)) - encoder_features = jax.random.normal( - jax.random.PRNGKey(456), (batch_size, seq_len, encoder_dim) - ) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) - adapted_features = layer.adapt_encoder_features(encoder_features, is_training=True) + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) self.assertEqual(adapted_features.shape, encoder_features.shape) From 6e22d6358f1b94a9fe67ad31ec7fd1c7da6375af Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:06:49 +0530 Subject: [PATCH 07/29] Use uv pip install in pre-commit workflow - Install uv and use uv pip install to respect ml-dtypes override - Fixes dependency conflict with ml-dtypes>=0.5,<0.6 vs tensorflow<0.5.0 --- .github/workflows/pre-commit.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 041199f45..8b04d6526 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,9 +15,9 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: pip install --upgrade pip + - run: pip install --upgrade pip uv # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From 8f41b40a16b4d8387979bedebe92feb4f447d500 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:11:54 +0530 Subject: [PATCH 08/29] Fix uv pip install to use --system flag - uv pip requires either a venv or --system flag - actions/setup-python creates venv but uv doesn't auto-detect it - Use --system to install into the Python environment directly --- .github/workflows/pre-commit.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 8b04d6526..eaa01fff1 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -17,7 +17,8 @@ jobs: cache: 'pip' - run: pip install --upgrade pip uv # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: | + uv pip install --system '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From c5692811192cd944e1161770cc7d62deff4a3b37 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:19:35 +0530 Subject: [PATCH 09/29] Add pylint disable for too-many-positional-arguments in jit_mamba_scan - jit_mamba_scan needs 6 positional args for shard_map compatibility - This is a nested function within JAX jit decorator - pylint 2.17+ flags this as R0917 (too-many-positional-arguments) --- axlearn/common/ssm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index b299b29df..c33fa96b3 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -409,7 +409,7 @@ def forward( # We need to jit a function before shard_mapping it. @jax.jit - def jit_mamba_scan(x, a, *, b, c, delta, d): + def jit_mamba_scan(x, a, b, c, delta, d): # pylint: disable=too-many-positional-arguments y = compute_mamba_scan( # [batch_size, seq_len, inner_dim] x, a, @@ -445,7 +445,7 @@ def jit_mamba_scan(x, a, *, b, c, delta, d): delta = with_sharding_constraint(delta, partition_specs["btd"]) d = with_sharding_constraint(d, partition_specs["1d"]) y = with_sharding_constraint( - partitioned_mamba_scan(x, a, b=b, c=c, delta=delta, d=d), + partitioned_mamba_scan(x, a, b, c, delta, d), cfg.output_partition_spec, ) # The Pallas kernel does not return states. From b2c673eae896cde017c2d63020a6ecbe2d688fbb Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 12:08:44 +0530 Subject: [PATCH 10/29] Add pylint disable for MambaConfig.__init__ too-many-positional-arguments - MambaConfig has 24 parameters to match HuggingFace's PretrainedConfig - pylint 2.17 added R0917 check which flags this legitimate case - Add disable comment to suppress the warning --- axlearn/common/ssm_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index d125984aa..21487a133 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -73,12 +73,11 @@ class MambaConfig(PretrainedConfig): model_type = "mamba" - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, vocab_size=50280, hidden_size=768, state_size=16, - *, num_hidden_layers=32, layer_norm_epsilon=1e-5, pad_token_id=0, From 3e94b99a3701c37754a838e1066925b5a30e937a Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:29:44 +0530 Subject: [PATCH 11/29] Use uv pip install without --system flag for pre-commit - uv pip install works correctly with actions/setup-python venv - Only uv pip install --system was causing issues - This matches the Dockerfile approach which works correctly --- .github/workflows/pre-commit.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index eaa01fff1..8b04d6526 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -17,8 +17,7 @@ jobs: cache: 'pip' - run: pip install --upgrade pip uv # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: | - uv pip install --system '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From be7466e64d43fd2cce140dc38a61f1f4dd16bca3 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:43:06 +0530 Subject: [PATCH 12/29] Revert to pip install matching upstream - uv doesn't work properly with actions/setup-python in GitHub Actions - Upstream uses pip successfully with same dependency versions - This should work correctly as-is --- .github/workflows/pre-commit.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 8b04d6526..041199f45 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,9 +15,9 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: pip install --upgrade pip uv + - run: pip install --upgrade pip # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From d13700dbca88068d9698828100b632ffcccd2621 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:51:05 +0530 Subject: [PATCH 13/29] Add uv back to workflow and debug venv detection --- .github/workflows/pre-commit.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 041199f45..c4137c505 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,9 +15,14 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: pip install --upgrade pip + - run: | + echo "VIRTUAL_ENV: $VIRTUAL_ENV" + echo "Python location: ${{ env.pythonLocation }}" + which python + which uv || echo "uv not found yet" + - run: pip install --upgrade pip uv # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From b67e66c7c724dbe11245eeab88efd5469ff26f3e Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:56:30 +0530 Subject: [PATCH 14/29] Revert to pip install to match upstream --- .github/workflows/pre-commit.yml | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index c4137c505..041199f45 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,14 +15,9 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: | - echo "VIRTUAL_ENV: $VIRTUAL_ENV" - echo "Python location: ${{ env.pythonLocation }}" - which python - which uv || echo "uv not found yet" - - run: pip install --upgrade pip uv + - run: pip install --upgrade pip # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From 5d6cddfead6f63bcb941d4f1a8db83930ea13f00 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:00:19 +0530 Subject: [PATCH 15/29] Use uv with explicit VIRTUAL_ENV for ml-dtypes override - pip doesn't respect tool.uv override-dependencies - Set VIRTUAL_ENV to pythonLocation so uv detects the venv - This allows uv to honor ml-dtypes>=0.5,<0.6 override --- .github/workflows/pre-commit.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 041199f45..0cfc6864c 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,9 +15,13 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: pip install --upgrade pip + - run: | + pip install --upgrade pip uv + # Ensure VIRTUAL_ENV is set for uv to detect the venv + export VIRTUAL_ENV=${{ env.pythonLocation }} + echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From d7a9e1fecb7626e5f26fa25a5a8acd2fa63ea2be Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:08:05 +0530 Subject: [PATCH 16/29] Fix VIRTUAL_ENV export in same step as uv pip install - Environment variables don't persist across GitHub Actions steps - Export VIRTUAL_ENV in the same run block as uv pip install - This ensures uv can detect the venv and honor ml-dtypes override --- .github/workflows/pre-commit.yml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 0cfc6864c..301604d5b 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,13 +15,11 @@ jobs: with: python-version: '3.10' cache: 'pip' + - run: pip install --upgrade pip uv + # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - run: | - pip install --upgrade pip uv - # Ensure VIRTUAL_ENV is set for uv to detect the venv export VIRTUAL_ENV=${{ env.pythonLocation }} - echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV - # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From d87892e0cee57b9b41109bb604f7a7da4af544be Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:10:45 +0530 Subject: [PATCH 17/29] Use pip with legacy resolver to handle ml-dtypes conflict - Revert to pip (matching upstream) with legacy resolver - Legacy resolver can install despite ml-dtypes version conflict - This allows tensorflow 2.17.1 and jax 0.6.2 to coexist --- .github/workflows/pre-commit.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 301604d5b..b9ba5bdf1 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,11 +15,10 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: pip install --upgrade pip uv + - run: pip install --upgrade pip # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: | - export VIRTUAL_ENV=${{ env.pythonLocation }} - uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + # Using --use-deprecated=legacy-resolver to handle ml-dtypes conflict + - run: pip install --use-deprecated=legacy-resolver '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From d346600b0fcb46fe9fa8b90e7016f6b50476be5c Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:34:01 +0530 Subject: [PATCH 18/29] Fix import order for isort 7.0 and add pylint disables - isort 7.0 has stricter import ordering than 5.x - Fixed 5 files to match CI expectations - Added pylint disables for too-many-positional-arguments in gpu_attention.py - These are pre-existing code style issues, not related to our changes --- axlearn/cloud/common/config_test.py | 5 ++++- .../cloud/gcp/monitoring/tpu_device_monitor.py | 4 +++- axlearn/common/flash_attention/gpu_attention.py | 16 +++++++++------- axlearn/experiments/text/gpt/fuji.py | 4 +++- axlearn/experiments/text/gpt/gspmd.py | 4 +++- 5 files changed, 22 insertions(+), 11 deletions(-) diff --git a/axlearn/cloud/common/config_test.py b/axlearn/cloud/common/config_test.py index 74e068d28..c16cd9a31 100644 --- a/axlearn/cloud/common/config_test.py +++ b/axlearn/cloud/common/config_test.py @@ -27,7 +27,10 @@ load_configs, ) from axlearn.cloud.common.config import main as config_main -from axlearn.cloud.common.config import update_configs, write_configs_with_header +from axlearn.cloud.common.config import ( + update_configs, + write_configs_with_header, +) from axlearn.common.test_utils import TestWithTemporaryCWD, temp_chdir diff --git a/axlearn/cloud/gcp/monitoring/tpu_device_monitor.py b/axlearn/cloud/gcp/monitoring/tpu_device_monitor.py index ba5740168..dad38cfe9 100644 --- a/axlearn/cloud/gcp/monitoring/tpu_device_monitor.py +++ b/axlearn/cloud/gcp/monitoring/tpu_device_monitor.py @@ -8,7 +8,9 @@ from absl import logging from tpu_info import device -from axlearn.cloud.gcp.monitoring.tpu_client import TPU_DEVICE_PLUGIN_METRICS_SERVER_ADDR +from axlearn.cloud.gcp.monitoring.tpu_client import ( + TPU_DEVICE_PLUGIN_METRICS_SERVER_ADDR, +) from axlearn.cloud.gcp.monitoring.tpu_client import MetricV2Name as MetricName from axlearn.cloud.gcp.monitoring.tpu_client import get_chip_metrics_v2 as get_chip_metrics from axlearn.cloud.gcp.monitoring.tpu_client import ( diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 9026f3a3e..87993909e 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -38,7 +38,9 @@ import numpy as np from absl import logging from jax import lax -from jax._src.cudnn.fused_attention_stablehlo import MaskType +from jax._src.cudnn.fused_attention_stablehlo import ( + MaskType, +) from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention as cudnn_dot_product_attention, ) @@ -111,7 +113,7 @@ def _key_value_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Ten return jnp.asarray(index_offset), jnp.asarray(index_offset_size) -def _mha_forward_kernel( +def _mha_forward_kernel( # pylint: disable=too-many-positional-arguments q_ref, k_ref, v_ref, @@ -245,7 +247,7 @@ def body(start_k, carry): # pylint: disable=unused-argument @functools.partial(jax.custom_vjp, nondiff_argnums=[6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) -def flash_attention( +def flash_attention( # pylint: disable=too-many-positional-arguments query: Tensor, key: Tensor, value: Tensor, @@ -286,7 +288,7 @@ def flash_attention( # pylint: enable=unused-argument -def _flash_attention_impl( +def _flash_attention_impl( # pylint: disable=too-many-positional-arguments query: Tensor, key: Tensor, value: Tensor, @@ -429,7 +431,7 @@ def _mha_forward(*args: Any): # TODO(lezhi): Add support arbitrary per-head-dim in backward pass. -def _mha_backward_kernel_dkdv( +def _mha_backward_kernel_dkdv( # pylint: disable=too-many-positional-arguments # Inputs. q_ref, k_ref, @@ -524,7 +526,7 @@ def inner_loop_dkdv(start_q, carry): pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype)) -def _mha_backward_kernel_dq( +def _mha_backward_kernel_dq( # pylint: disable=too-many-positional-arguments # Inputs. q_ref, k_ref, @@ -610,7 +612,7 @@ def inner_loop_dq(start_k, carry): pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype)) -def _mha_backward( +def _mha_backward( # pylint: disable=too-many-positional-arguments softmax_scale: float, mask_fn: Optional[MaskFn], dropout_rate: float, diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 96b45a070..514652cfa 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -64,7 +64,9 @@ mesh_shape_from_axes, ) from axlearn.experiments.text.gpt.common import model_config as common_model_config -from axlearn.experiments.text.gpt.common import scaled_hidden_dim +from axlearn.experiments.text.gpt.common import ( + scaled_hidden_dim, +) from axlearn.experiments.trainer_config_utils import V6eFlashConfigModifier MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B", "405B") diff --git a/axlearn/experiments/text/gpt/gspmd.py b/axlearn/experiments/text/gpt/gspmd.py index ffda14958..e1bbfef4d 100644 --- a/axlearn/experiments/text/gpt/gspmd.py +++ b/axlearn/experiments/text/gpt/gspmd.py @@ -27,7 +27,9 @@ mesh_shape_from_axes, ) from axlearn.experiments.text.gpt.common import model_config as common_model_config -from axlearn.experiments.text.gpt.common import scaled_hidden_dim +from axlearn.experiments.text.gpt.common import ( + scaled_hidden_dim, +) _VOCAB_SIZE = 32 * 1024 _MAX_SEQUENCE_LENGTH = 1024 From e29f6559c4feaf614272f86c0f8762612d803612 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:59:26 +0530 Subject: [PATCH 19/29] Reinstall google-cloud-aiplatform after legacy-resolver install - Legacy resolver may skip or break google-cloud-aiplatform installation - Force reinstall with --no-deps to fix pytype import errors - This ensures pytype can analyze vertexai_tensorboard.py --- .github/workflows/pre-commit.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index b9ba5bdf1..3774425e8 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -18,7 +18,10 @@ jobs: - run: pip install --upgrade pip # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) # Using --use-deprecated=legacy-resolver to handle ml-dtypes conflict - - run: pip install --use-deprecated=legacy-resolver '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: | + pip install --use-deprecated=legacy-resolver '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + # Reinstall google-cloud-aiplatform to fix pytype imports (legacy-resolver may have broken it) + pip install --force-reinstall --no-deps 'google-cloud-aiplatform[tensorboard]==1.61.0' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From 81b08d473e980630d58d542a098b36efa7a02998 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 18:14:22 +0530 Subject: [PATCH 20/29] Also reinstall transformers after legacy-resolver - Legacy resolver breaks transformers package installation - Add transformers==4.51.3 to force reinstall for pytype - This fixes param_converter.py import errors --- .github/workflows/pre-commit.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 3774425e8..0467684d2 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -20,8 +20,8 @@ jobs: # Using --use-deprecated=legacy-resolver to handle ml-dtypes conflict - run: | pip install --use-deprecated=legacy-resolver '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' - # Reinstall google-cloud-aiplatform to fix pytype imports (legacy-resolver may have broken it) - pip install --force-reinstall --no-deps 'google-cloud-aiplatform[tensorboard]==1.61.0' + # Reinstall packages that legacy-resolver may have broken (needed for pytype) + pip install --force-reinstall --no-deps 'google-cloud-aiplatform[tensorboard]==1.61.0' 'transformers==4.51.3' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From 18bb7468628b9fa6c3b9c7d0afe5d29773f3631d Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 18:31:29 +0530 Subject: [PATCH 21/29] CI: use dedicated venv + uv for dependency install - Create .venv and export VIRTUAL_ENV + PATH via GITHUB_ENV/GITHUB_PATH - Install pip+uv, then install extras with uv (honors ml-dtypes override) - Remove legacy resolver and post-install hacks (aiplatform/transformers) - Ensures pytype can resolve imports reliably --- .github/workflows/pre-commit.yml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 0467684d2..b364c716f 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,13 +15,14 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: pip install --upgrade pip + - name: Create and activate venv for uv + run: | + python -m venv .venv + echo "VIRTUAL_ENV=$PWD/.venv" >> $GITHUB_ENV + echo "$PWD/.venv/bin" >> $GITHUB_PATH + - run: pip install --upgrade pip uv # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - # Using --use-deprecated=legacy-resolver to handle ml-dtypes conflict - - run: | - pip install --use-deprecated=legacy-resolver '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' - # Reinstall packages that legacy-resolver may have broken (needed for pytype) - pip install --force-reinstall --no-deps 'google-cloud-aiplatform[tensorboard]==1.61.0' 'transformers==4.51.3' + - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From a92685a63ce2a9b09ba3f50f02a586788fabac67 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 18:38:56 +0530 Subject: [PATCH 22/29] Pin isort==7.0.0 to align CI and local formatting --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 41ab5d0df..0fbf4ef7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ dev = [ "black==23.1a1", # formatting "einops==0.8.0", # for unittests only. Do not use it in core impl because it's not thread-safe. "evaluate", - "isort", # formatting + "isort==7.0.0", # formatting; pin to match CI and avoid drift "pika==1.3.2", # used by event queue "pre-commit", # local pre commit hooks "pycocotools", # COCO evaluation tools From ac298751308c3760f43e1c6b479d029b97ef2d9b Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 18:45:23 +0530 Subject: [PATCH 23/29] CI: pin isort==5.13.2 to satisfy pylint (<6) and allow uv resolution --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0fbf4ef7a..4b59c5793 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ dev = [ "black==23.1a1", # formatting "einops==0.8.0", # for unittests only. Do not use it in core impl because it's not thread-safe. "evaluate", - "isort==7.0.0", # formatting; pin to match CI and avoid drift + "isort==5.13.2", # formatting; compatible with pylint<3 (<6 requirement) "pika==1.3.2", # used by event queue "pre-commit", # local pre commit hooks "pycocotools", # COCO evaluation tools From ebcb758ecbb1125ae35a50d04f8bbf3fb68dc252 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 20:03:33 +0530 Subject: [PATCH 24/29] chore(ci): trigger CI on latest workflow/deps changes From 5f7e1421590495f33b65022e90e9aca06eaf5f04 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 22:13:01 +0530 Subject: [PATCH 25/29] Apply isort 5.13.2 formatting to match CI --- axlearn/cloud/common/config_test.py | 5 +---- axlearn/cloud/gcp/monitoring/tpu_device_monitor.py | 4 +--- axlearn/common/flash_attention/gpu_attention.py | 4 +--- axlearn/experiments/text/gpt/fuji.py | 4 +--- axlearn/experiments/text/gpt/gspmd.py | 4 +--- 5 files changed, 5 insertions(+), 16 deletions(-) diff --git a/axlearn/cloud/common/config_test.py b/axlearn/cloud/common/config_test.py index c16cd9a31..74e068d28 100644 --- a/axlearn/cloud/common/config_test.py +++ b/axlearn/cloud/common/config_test.py @@ -27,10 +27,7 @@ load_configs, ) from axlearn.cloud.common.config import main as config_main -from axlearn.cloud.common.config import ( - update_configs, - write_configs_with_header, -) +from axlearn.cloud.common.config import update_configs, write_configs_with_header from axlearn.common.test_utils import TestWithTemporaryCWD, temp_chdir diff --git a/axlearn/cloud/gcp/monitoring/tpu_device_monitor.py b/axlearn/cloud/gcp/monitoring/tpu_device_monitor.py index dad38cfe9..ba5740168 100644 --- a/axlearn/cloud/gcp/monitoring/tpu_device_monitor.py +++ b/axlearn/cloud/gcp/monitoring/tpu_device_monitor.py @@ -8,9 +8,7 @@ from absl import logging from tpu_info import device -from axlearn.cloud.gcp.monitoring.tpu_client import ( - TPU_DEVICE_PLUGIN_METRICS_SERVER_ADDR, -) +from axlearn.cloud.gcp.monitoring.tpu_client import TPU_DEVICE_PLUGIN_METRICS_SERVER_ADDR from axlearn.cloud.gcp.monitoring.tpu_client import MetricV2Name as MetricName from axlearn.cloud.gcp.monitoring.tpu_client import get_chip_metrics_v2 as get_chip_metrics from axlearn.cloud.gcp.monitoring.tpu_client import ( diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 87993909e..ab587428f 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -38,9 +38,7 @@ import numpy as np from absl import logging from jax import lax -from jax._src.cudnn.fused_attention_stablehlo import ( - MaskType, -) +from jax._src.cudnn.fused_attention_stablehlo import MaskType from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention as cudnn_dot_product_attention, ) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 514652cfa..96b45a070 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -64,9 +64,7 @@ mesh_shape_from_axes, ) from axlearn.experiments.text.gpt.common import model_config as common_model_config -from axlearn.experiments.text.gpt.common import ( - scaled_hidden_dim, -) +from axlearn.experiments.text.gpt.common import scaled_hidden_dim from axlearn.experiments.trainer_config_utils import V6eFlashConfigModifier MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B", "405B") diff --git a/axlearn/experiments/text/gpt/gspmd.py b/axlearn/experiments/text/gpt/gspmd.py index e1bbfef4d..ffda14958 100644 --- a/axlearn/experiments/text/gpt/gspmd.py +++ b/axlearn/experiments/text/gpt/gspmd.py @@ -27,9 +27,7 @@ mesh_shape_from_axes, ) from axlearn.experiments.text.gpt.common import model_config as common_model_config -from axlearn.experiments.text.gpt.common import ( - scaled_hidden_dim, -) +from axlearn.experiments.text.gpt.common import scaled_hidden_dim _VOCAB_SIZE = 32 * 1024 _MAX_SEQUENCE_LENGTH = 1024 From f8d0ca748051c2f48b9f20d7c0fdb3d3b4e5b3b9 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 22:36:22 +0530 Subject: [PATCH 26/29] CI: pin torch==2.1.1 and torchvision==0.16.1 for pytype compatibility after uv install --- .github/workflows/pre-commit.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index b364c716f..ae496ee6f 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -22,7 +22,10 @@ jobs: echo "$PWD/.venv/bin" >> $GITHUB_PATH - run: pip install --upgrade pip uv # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: | + uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + # Pin torch to a pytype-friendly version (newer torch confuses old pytype) + uv pip install 'torch==2.1.1' 'torchvision==0.16.1' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From efe139702f94596e69d525f4ab0d4503f15a600c Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 23:03:34 +0530 Subject: [PATCH 27/29] CI: install PyTorch CPU wheels via index-url so pytype sees torch.nn stubs --- .github/workflows/pre-commit.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index ae496ee6f..e3780ec40 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -24,8 +24,8 @@ jobs: # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - run: | uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' - # Pin torch to a pytype-friendly version (newer torch confuses old pytype) - uv pip install 'torch==2.1.1' 'torchvision==0.16.1' + # Ensure PyTorch installs CPU wheels (with stubs) so pytype can resolve torch.nn + uv pip install --index-url https://download.pytorch.org/whl/cpu 'torch==2.1.1' 'torchvision==0.16.1' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. From 176d39ec9a34dad6b83e45fd7d5e1e209a9dbcbd Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Mon, 3 Nov 2025 23:33:46 +0530 Subject: [PATCH 28/29] pytype: disable module-attr in ssm_test to allow torch.nn usage under pytype --- axlearn/common/ssm_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 21487a133..3a03a360a 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -15,6 +15,7 @@ """Tests Mamba/Mamba2 and Jamba implementations.""" +# pytype: disable=module-attr import math from typing import Optional From 13115d45d72140f7559e4922e7d91f1de510894a Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Fri, 21 Nov 2025 23:48:34 +0530 Subject: [PATCH 29/29] Refactor with-statement formatting and minor style fixes Refactored multiple test files to use parenthesized 'with' statements for improved readability and consistency. Added or adjusted trailing commas, improved docstring formatting, and added or updated pylint disables where appropriate. No functional changes were made. --- .pre-commit-config.yaml | 9 ++- axlearn/audio/aligner/ctc_aligner.py | 1 + axlearn/audio/evaler_asr_test.py | 14 ++--- axlearn/cloud/common/bastion_test.py | 14 ++--- axlearn/cloud/gcp/job.py | 2 - axlearn/cloud/gcp/k8s_service.py | 3 +- axlearn/cloud/gcp/measurement_test.py | 7 ++- axlearn/cloud/gcp/pathways_utils.py | 10 +-- axlearn/cloud/gcp/tpu_health_check_test.py | 28 ++++++--- axlearn/common/adapter_torch.py | 1 + axlearn/common/aot_compilation_test.py | 1 + axlearn/common/array_serialization_test.py | 62 +++++++++++-------- axlearn/common/checkpointer.py | 8 ++- axlearn/common/inference_test.py | 11 ++-- axlearn/common/input_lm.py | 4 +- axlearn/common/input_t5.py | 2 +- axlearn/common/input_text.py | 2 +- axlearn/common/layers_test.py | 2 +- axlearn/common/optimizers.py | 6 +- axlearn/common/optimizers_test.py | 4 +- axlearn/common/quantized_dot_general/utils.py | 4 +- .../kernels/linear_attention_kernels.py | 2 +- axlearn/common/rattention/rattention.py | 1 + axlearn/common/rattention/rattention_test.py | 1 + axlearn/common/rattention/utils.py | 1 + axlearn/common/ssm_kernels/ssd_kernels.py | 2 +- axlearn/common/ssm_test.py | 1 + axlearn/vision/coco_utils.py | 1 + 28 files changed, 118 insertions(+), 86 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a1e66b4b..a4a6fb84a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,16 +14,19 @@ repos: - id: black name: black entry: black - language: system + language: python + additional_dependencies: ['black==23.12.0'] types: [python] - id: isort name: isort entry: isort - language: system + language: python + additional_dependencies: ['isort==5.13.2'] types: [python] - id: pylint name: pylint entry: pylint args: ['--msg-template="{abspath}:{line}: [{msg_id}({symbol}), {obj}] {msg}"'] - language: system + language: python + additional_dependencies: ['pylint==3.3.7'] types: [python] diff --git a/axlearn/audio/aligner/ctc_aligner.py b/axlearn/audio/aligner/ctc_aligner.py index 1bf6acb58..8ae402053 100644 --- a/axlearn/audio/aligner/ctc_aligner.py +++ b/axlearn/audio/aligner/ctc_aligner.py @@ -15,6 +15,7 @@ """ + from dataclasses import asdict, dataclass from typing import Literal, NamedTuple, Optional, Tuple diff --git a/axlearn/audio/evaler_asr_test.py b/axlearn/audio/evaler_asr_test.py index e6bb72340..53aca811d 100644 --- a/axlearn/audio/evaler_asr_test.py +++ b/axlearn/audio/evaler_asr_test.py @@ -133,13 +133,13 @@ def _compute_metrics( if brevity_penalty: decode_kwargs["brevity_penalty"] = brevity_penalty - cfg: WordErrorRateMetricCalculator.Config = ( - WordErrorRateMetricCalculator.default_config().set( - vocab=config_for_class(seqio.SentencePieceVocabulary).set( - sentencepiece_model_file=vocab_file, - ), - model_method_kwargs=decode_kwargs, - ) + cfg: ( + WordErrorRateMetricCalculator.Config + ) = WordErrorRateMetricCalculator.default_config().set( + vocab=config_for_class(seqio.SentencePieceVocabulary).set( + sentencepiece_model_file=vocab_file, + ), + model_method_kwargs=decode_kwargs, ) calculator: WordErrorRateMetricCalculator = cfg.set(name="test-metric").instantiate( parent=None, model=model, model_param_partition_specs={} diff --git a/axlearn/cloud/common/bastion_test.py b/axlearn/cloud/common/bastion_test.py index b3d6255e3..5e14c9ce1 100644 --- a/axlearn/cloud/common/bastion_test.py +++ b/axlearn/cloud/common/bastion_test.py @@ -1690,10 +1690,9 @@ def test_sync_jobs_for_valid_pending_to_sudden_invalid_jobs(self): mock_validator_cfg = MockStatefulJobValidator.default_config() mock_append_to_job_history = mock.MagicMock() - with self._patch_bastion( - validator_cfg=mock_validator_cfg - ) as mock_bastion, mock.patch.object( - mock_bastion, "_append_to_job_history", mock_append_to_job_history + with ( + self._patch_bastion(validator_cfg=mock_validator_cfg) as mock_bastion, + mock.patch.object(mock_bastion, "_append_to_job_history", mock_append_to_job_history), ): os.makedirs(mock_bastion._active_dir, exist_ok=True) os.makedirs(_JOB_DIR, exist_ok=True) @@ -1802,10 +1801,9 @@ def test_sync_jobs_for_immediate_invalid_pending_jobs(self): mock_validator_cfg = MockAlwaysInvalidValidator.default_config() mock_append_to_job_history = mock.MagicMock() - with self._patch_bastion( - validator_cfg=mock_validator_cfg - ) as mock_bastion, mock.patch.object( - mock_bastion, "_append_to_job_history", mock_append_to_job_history + with ( + self._patch_bastion(validator_cfg=mock_validator_cfg) as mock_bastion, + mock.patch.object(mock_bastion, "_append_to_job_history", mock_append_to_job_history), ): os.makedirs(mock_bastion._active_dir, exist_ok=True) os.makedirs(_JOB_DIR, exist_ok=True) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index c405e8dfb..cfc68a33e 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -33,7 +33,6 @@ class _ServiceProtocol(enum.Enum): - """https://kubernetes.io/docs/reference/networking/service-protocols/""" TCP = "TCP" @@ -42,7 +41,6 @@ class _ServiceProtocol(enum.Enum): class _ServiceType(enum.Enum): - """https://cloud.google.com/kubernetes-engine/docs/concepts/service#types-of-services sss""" CLUSTER_IP = "ClusterIP" diff --git a/axlearn/cloud/gcp/k8s_service.py b/axlearn/cloud/gcp/k8s_service.py index b8e81ac60..a27c25046 100644 --- a/axlearn/cloud/gcp/k8s_service.py +++ b/axlearn/cloud/gcp/k8s_service.py @@ -1,4 +1,5 @@ -""" k8s service module.""" +"""k8s service module.""" + import copy import logging from typing import Any, Optional diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py index 1d1f47161..05688e09e 100644 --- a/axlearn/cloud/gcp/measurement_test.py +++ b/axlearn/cloud/gcp/measurement_test.py @@ -177,9 +177,10 @@ def test_record_event_context_manager_handles_runtime_error(self): recorder = GoodputRecorder(cfg) with mock.patch("jax.process_index", return_value=0): - with mock.patch( - "ml_goodput_measurement.goodput.GoodputRecorder" - ) as mock_recorder_cls, mock.patch.object(logging, "warning") as mock_warning: + with ( + mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder_cls, + mock.patch.object(logging, "warning") as mock_warning, + ): mock_instance = mock_recorder_cls.return_value def raise_runtime_error(*args, **kwargs): diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 528ac543d..b4e1df10e 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -110,7 +110,7 @@ def get_pathways_tpu_version(gke_machine_type: str) -> str: def get_megascale_options( - xla_options: dict[str, Union[str, bool, int]] + xla_options: dict[str, Union[str, bool, int]], ) -> dict[str, Union[str, bool, int]]: """Filters XLA options for those pertaining to Megascale. @@ -125,7 +125,7 @@ def get_megascale_options( def get_xla_options( - xla_options: dict[str, Union[str, bool, int]] + xla_options: dict[str, Union[str, bool, int]], ) -> dict[str, Union[str, bool, int]]: """Filters XLA options for those starting with 'xla_'. @@ -962,9 +962,9 @@ def _build_head_container(self) -> dict: ], imagePullPolicy="Always", resources=resources, - ports=[dict(containerPort=self.config.target_port)] - if self.config.enable_service - else [], + ports=( + [dict(containerPort=self.config.target_port)] if self.config.enable_service else [] + ), ) def build_leader_pod(self) -> Nested[Any]: diff --git a/axlearn/cloud/gcp/tpu_health_check_test.py b/axlearn/cloud/gcp/tpu_health_check_test.py index 51772f6df..d2815fd8f 100644 --- a/axlearn/cloud/gcp/tpu_health_check_test.py +++ b/axlearn/cloud/gcp/tpu_health_check_test.py @@ -47,8 +47,11 @@ def test_parsing(self): def test_global_health_check(self): # On CPU CI, this should pass. - with mock.patch("os.kill") as mock_exit, mock.patch.dict( - os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + with ( + mock.patch("os.kill") as mock_exit, + mock.patch.dict( + os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + ), ): global_health_check("global=180", output_dir="") mock_exit.assert_not_called() @@ -63,10 +66,12 @@ def _check_failure_file(self, folder: str, keyword: str): self.fail("should not reach here") def test_global_health_check_timeout(self): - with mock.patch( - "os.kill" - ) as mock_exit, tempfile.TemporaryDirectory() as d, mock.patch.dict( - os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + with ( + mock.patch("os.kill") as mock_exit, + tempfile.TemporaryDirectory() as d, + mock.patch.dict( + os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + ), ): global_health_check("global=0.000001", output_dir=d) mock_exit.assert_called_once() @@ -79,10 +84,13 @@ def test_raises_with_no_megascale_env(self): pairwise_slice_health_check("pairwise=1", output_dir="") def test_global_health_check_failure(self): - with mock.patch("os.kill") as mock_exit, mock.patch( - f"{tpu_health_check_main.__name__}.main", lambda: False - ), tempfile.TemporaryDirectory() as d, mock.patch.dict( - os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + with ( + mock.patch("os.kill") as mock_exit, + mock.patch(f"{tpu_health_check_main.__name__}.main", lambda: False), + tempfile.TemporaryDirectory() as d, + mock.patch.dict( + os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + ), ): global_health_check("global=180", output_dir=d) mock_exit.assert_called_once() diff --git a/axlearn/common/adapter_torch.py b/axlearn/common/adapter_torch.py index b6b2620da..0faaec8f0 100644 --- a/axlearn/common/adapter_torch.py +++ b/axlearn/common/adapter_torch.py @@ -179,6 +179,7 @@ def _axlearn_weight_mapper(self, weight_name: str, weight: torch.Tensor) -> torc class LayerNorm(nn.LayerNorm, TorchModule): + # pylint: disable=useless-parent-delegation # maintains interface compatibility def __init__(self, *args, eps=1e-6, **kwargs): super().__init__(*args, eps=eps, **kwargs) diff --git a/axlearn/common/aot_compilation_test.py b/axlearn/common/aot_compilation_test.py index 4ea7f60ee..57c1e6c4d 100644 --- a/axlearn/common/aot_compilation_test.py +++ b/axlearn/common/aot_compilation_test.py @@ -1,4 +1,5 @@ """Tests aot_compilation utils.""" + from typing import cast from axlearn.common import test_utils diff --git a/axlearn/common/array_serialization_test.py b/axlearn/common/array_serialization_test.py index c086c7d5b..ecf240c8d 100644 --- a/axlearn/common/array_serialization_test.py +++ b/axlearn/common/array_serialization_test.py @@ -125,11 +125,13 @@ def transfer_to_host_patch(*args, **kwargs): return old_transfer(*args, **kwargs) d2h_future = array_serialization.futures.Future() - with mock.patch( - f"{array_serialization.__name__}.{_ts_open}", - ts_open_patch, - ), get_tensorstore_spec(arr) as spec, mock.patch( - f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch + with ( + mock.patch( + f"{array_serialization.__name__}.{_ts_open}", + ts_open_patch, + ), + get_tensorstore_spec(arr) as spec, + mock.patch(f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch), ): # Either RuntimeError(Array has been deleted with shape) or # ValueError(...Buffer has been deleted or donated...) may occur. @@ -151,11 +153,13 @@ def transfer_to_host_patch(*args, **kwargs): arr = self._create_partially_replicated_array(sharded) arr_host = jax.device_get(arr) d2h_future = array_serialization.futures.Future() - with mock.patch( - f"{array_serialization.__name__}.{_ts_open}", - ts_open_patch, - ), get_tensorstore_spec(arr) as spec, mock.patch( - f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch + with ( + mock.patch( + f"{array_serialization.__name__}.{_ts_open}", + ts_open_patch, + ), + get_tensorstore_spec(arr) as spec, + mock.patch(f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch), ): f = _CommitFuture( _run_serializer( @@ -185,10 +189,13 @@ async def ts_open_patch(*_, **__): raise RuntimeError("Test") d2h_future = array_serialization.futures.Future() - with mock.patch( - f"{array_serialization.__name__}.{_ts_open}", - ts_open_patch, - ), get_tensorstore_spec(arr) as spec: + with ( + mock.patch( + f"{array_serialization.__name__}.{_ts_open}", + ts_open_patch, + ), + get_tensorstore_spec(arr) as spec, + ): f = _CommitFuture( _run_serializer( [arr], [spec], [d2h_future], max_data_shard_degree=-1, shard_threshold_bytes=-1 @@ -202,10 +209,13 @@ def transfer_to_host_patch(*_): raise RuntimeError("Test") d2h_future = array_serialization.futures.Future() - with mock.patch( - f"{array_serialization.__name__}._transfer_to_host", - transfer_to_host_patch, - ), get_tensorstore_spec(arr) as spec: + with ( + mock.patch( + f"{array_serialization.__name__}._transfer_to_host", + transfer_to_host_patch, + ), + get_tensorstore_spec(arr) as spec, + ): f = _CommitFuture( _run_serializer( [arr], [spec], [d2h_future], max_data_shard_degree=-1, shard_threshold_bytes=-1 @@ -362,13 +372,15 @@ async def mock_ts_open(spec_arg, *args, **kwargs): return await original_ts_open(call_arg, *args, **kwargs) # Write the data to local files - with get_tensorstore_spec_for_deserialization(data) as ( - tensorstore_spec, - temp_path, - ), mock.patch( - f"{array_serialization.__name__}.{_ts_open}", new=mock_ts_open - ), mock.patch.dict( - "os.environ", {"JAX_PLATFORMS": jax_platforms, "ENABLE_GCS_GRPC": enable_gcs_grpc} + with ( + get_tensorstore_spec_for_deserialization(data) as ( + tensorstore_spec, + temp_path, + ), + mock.patch(f"{array_serialization.__name__}.{_ts_open}", new=mock_ts_open), + mock.patch.dict( + "os.environ", {"JAX_PLATFORMS": jax_platforms, "ENABLE_GCS_GRPC": enable_gcs_grpc} + ), ): manager.serialize( data, diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index f2162a3c9..b8c38ef90 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -474,9 +474,11 @@ def _get_spec(self, step: int, state: NestedTensor, ckpt_dir: str) -> Checkpoint spec.shardings.append( jax.sharding.NamedSharding( mesh, - jax.sharding.PartitionSpec() - if value.mesh_axes is None - else value.mesh_axes, + ( + jax.sharding.PartitionSpec() + if value.mesh_axes is None + else value.mesh_axes + ), ) ) elif isinstance(value, tf.data.Iterator): diff --git a/axlearn/common/inference_test.py b/axlearn/common/inference_test.py index d6d946ed3..20bdbf2be 100644 --- a/axlearn/common/inference_test.py +++ b/axlearn/common/inference_test.py @@ -779,10 +779,13 @@ def test_pipeline_summary_writer( mock_summary_writer = mock.Mock(return_value=None) - with mock.patch( - "axlearn.common.summary_writer.SummaryWriter.Config.instantiate", - mock.MagicMock(return_value=mock_summary_writer), - ), tempfile.TemporaryDirectory() as local_tmp_dir: + with ( + mock.patch( + "axlearn.common.summary_writer.SummaryWriter.Config.instantiate", + mock.MagicMock(return_value=mock_summary_writer), + ), + tempfile.TemporaryDirectory() as local_tmp_dir, + ): root_dir = local_tmp_dir if local_run else "gs://axlearn-public/testdata/inference_test" with set_data_dir(root_dir): prng_key = jax.random.PRNGKey(11) diff --git a/axlearn/common/input_lm.py b/axlearn/common/input_lm.py index e0f3d0ef4..f8c2b8fa0 100644 --- a/axlearn/common/input_lm.py +++ b/axlearn/common/input_lm.py @@ -628,7 +628,7 @@ def map_targets_out_of_class(example: dict[str, tf.Tensor]) -> dict[str, tf.Tens def _trim_and_pack_with_segments( - feature_lengths: dict[str, int] + feature_lengths: dict[str, int], ) -> input_tf_data.DatasetToDatasetFn: """Trim and pack inputs, injecting `*_segment_ids` and `*_positions`. @@ -683,7 +683,7 @@ def restore_intermediate_zeros(example: dict[str, tf.Tensor]): def _trim_and_pad_with_segments( - feature_lengths: dict[str, int] + feature_lengths: dict[str, int], ) -> input_tf_data.DatasetToDatasetFn: """Trim and pad inputs, injecting `*_segment_ids`, `*_positions`. diff --git a/axlearn/common/input_t5.py b/axlearn/common/input_t5.py index b47d13c3a..c55dc5329 100644 --- a/axlearn/common/input_t5.py +++ b/axlearn/common/input_t5.py @@ -232,7 +232,7 @@ def split_tokens( @seqio.map_over_dataset def split_tokens_example( - x: dict[str, tf.Tensor] + x: dict[str, tf.Tensor], ) -> tuple[dict[str, tf.Tensor], dict[str, tf.Tensor]]: """Split one token sequence into multiple sequences.""" tokens = x[input_key] diff --git a/axlearn/common/input_text.py b/axlearn/common/input_text.py index f2827b964..4d07acc30 100644 --- a/axlearn/common/input_text.py +++ b/axlearn/common/input_text.py @@ -179,7 +179,7 @@ def add_token_type_ids( input_key = [input_key] def example_fn( - example: dict[str, Union[tf.Tensor, tf.RaggedTensor]] + example: dict[str, Union[tf.Tensor, tf.RaggedTensor]], ) -> dict[str, Union[tf.Tensor, tf.RaggedTensor]]: token_type_ids = [] for i, key in enumerate(input_key): diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index 8914c6663..90220bf8d 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -1348,7 +1348,7 @@ def test_embed_with_constant_scale_validation(self): dim, num_embeddings, rng, - scale=Embedding.Scale.CONSTANT + scale=Embedding.Scale.CONSTANT, # Missing scale_constant ) self.assertIn("scale_constant must be specified", str(cm.exception)) diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index 2fa4baac9..1b3f8e9b0 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -2033,9 +2033,9 @@ def _update2(u: Tensor, param: OptParam, weight_decay_scale: float = 1.0): params, per_param_scale=weight_decay_per_param_scale ) updates2 = jax.tree.map( - lambda u, p, wds: None - if u is None - else _update2(u, param=p, weight_decay_scale=wds), + lambda u, p, wds: ( + None if u is None else _update2(u, param=p, weight_decay_scale=wds) + ), updates, params, weight_decay_scales, diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py index 6fe408239..e72d41f17 100644 --- a/axlearn/common/optimizers_test.py +++ b/axlearn/common/optimizers_test.py @@ -1310,8 +1310,8 @@ def test_param_ema(self, decay, dtype): self.assertEqual(new_state.count, 1) if isinstance(decay, float): - ema_fn = ( - lambda p: (1 - decay) * p.value + ema_fn = lambda p: ( + (1 - decay) * p.value if jnp.issubdtype(p.value.dtype, jnp.floating) else p.value ) diff --git a/axlearn/common/quantized_dot_general/utils.py b/axlearn/common/quantized_dot_general/utils.py index 292e42bf9..fa6ceeb7c 100644 --- a/axlearn/common/quantized_dot_general/utils.py +++ b/axlearn/common/quantized_dot_general/utils.py @@ -14,9 +14,7 @@ # Copyright 2024 The AQT Authors. # Licensed under the Apache License, Version 2.0 (the "License"). -"""QuantizedDotGeneral Utilities. Hosts default quantization configuration. - -""" +"""QuantizedDotGeneral Utilities. Hosts default quantization configuration.""" import functools import jax diff --git a/axlearn/common/rattention/kernels/linear_attention_kernels.py b/axlearn/common/rattention/kernels/linear_attention_kernels.py index d3542cc5f..2bb5f3381 100644 --- a/axlearn/common/rattention/kernels/linear_attention_kernels.py +++ b/axlearn/common/rattention/kernels/linear_attention_kernels.py @@ -1,6 +1,6 @@ # Copyright © 2025 Apple Inc. -""" Pallas kernels for Linear Attention (LA) specialized for sliding window attention. +"""Pallas kernels for Linear Attention (LA) specialized for sliding window attention. A specialized feature map from the following reference is used to support sliding window attention. The chunking strategy is similar to the one used in ssm_kernels/ssd_kernels.py. diff --git a/axlearn/common/rattention/rattention.py b/axlearn/common/rattention/rattention.py index 65f09b97d..669ab3e4a 100644 --- a/axlearn/common/rattention/rattention.py +++ b/axlearn/common/rattention/rattention.py @@ -1,4 +1,5 @@ """Implementation of RAttention with residual linear attention.""" + from functools import partial from typing import Callable, Optional, Union diff --git a/axlearn/common/rattention/rattention_test.py b/axlearn/common/rattention/rattention_test.py index 0b587fcc5..34f956d91 100644 --- a/axlearn/common/rattention/rattention_test.py +++ b/axlearn/common/rattention/rattention_test.py @@ -1,4 +1,5 @@ """Tests for RAttention and ResidualLinearAttention.""" + import copy import jax diff --git a/axlearn/common/rattention/utils.py b/axlearn/common/rattention/utils.py index 24559dca9..39c889721 100644 --- a/axlearn/common/rattention/utils.py +++ b/axlearn/common/rattention/utils.py @@ -1,4 +1,5 @@ """Utilities for RAttention.""" + from typing import Optional import jax diff --git a/axlearn/common/ssm_kernels/ssd_kernels.py b/axlearn/common/ssm_kernels/ssd_kernels.py index 9fa3bb244..3f96f275b 100644 --- a/axlearn/common/ssm_kernels/ssd_kernels.py +++ b/axlearn/common/ssm_kernels/ssd_kernels.py @@ -1,6 +1,6 @@ # Copyright © 2024 Apple Inc. -""" Pallas kernels for Mamba2 +"""Pallas kernels for Mamba2 High-level idea: this kernel implements a two-level chunking algorithm to balance memory consumption and running speed. Intuitively, we store chunk-level diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 17055f2e2..99f905699 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -237,6 +237,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None): hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: + # pylint: disable=not-callable # pad is callable, false positive conv_state = torch.nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) diff --git a/axlearn/vision/coco_utils.py b/axlearn/vision/coco_utils.py index e86395eb7..1d9b04540 100644 --- a/axlearn/vision/coco_utils.py +++ b/axlearn/vision/coco_utils.py @@ -65,6 +65,7 @@ def __init__(self, eval_type="box", annotation_file=None, gt_dataset=None): self.dataset = gt_dataset self.createIndex() + # pylint: disable=invalid-name # matches COCO API naming convention def loadRes(self, predictions: list[dict[str, Any]]) -> coco.COCO: """Loads result file and return a result api object.