diff --git a/recml/core/ops/embedding_ops.py b/recml/core/ops/embedding_ops.py index c094596..a1de4f0 100644 --- a/recml/core/ops/embedding_ops.py +++ b/recml/core/ops/embedding_ops.py @@ -38,7 +38,7 @@ class SparsecoreParams: """Embedding parameters.""" feature_specs: Nested[FeatureSpec] - abstract_mesh: jax.sharding.AbstractMesh + mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh data_axes: Sequence[str | None] embedding_axes: Sequence[str | None] sharding_strategy: str @@ -53,11 +53,11 @@ def sparsecore_lookup( return shard_map.shard_map( functools.partial( embedding.tpu_sparse_dense_matmul, - global_device_count=sparsecore_params.abstract_mesh.size, + global_device_count=sparsecore_params.mesh.size, feature_specs=sparsecore_params.feature_specs, sharding_strategy=sparsecore_params.sharding_strategy, ), - mesh=sparsecore_params.abstract_mesh, + mesh=sparsecore_params.mesh, in_specs=( jax.sharding.PartitionSpec(*sparsecore_params.data_axes), jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes), @@ -90,7 +90,7 @@ def _emb_lookup_bwd( feature_specs=sparsecore_params.feature_specs, sharding_strategy=sparsecore_params.sharding_strategy, ), - mesh=sparsecore_params.abstract_mesh, + mesh=sparsecore_params.mesh, in_specs=( jax.sharding.PartitionSpec(*sparsecore_params.data_axes), jax.sharding.PartitionSpec(*sparsecore_params.data_axes), diff --git a/recml/core/training/partitioning.py b/recml/core/training/partitioning.py index 2eda2e0..4dc3b76 100644 --- a/recml/core/training/partitioning.py +++ b/recml/core/training/partitioning.py @@ -20,7 +20,6 @@ import flax.linen as nn import jax -from jax.experimental import mesh_utils import numpy as np @@ -68,7 +67,7 @@ class DataParallelPartitioner(Partitioner): """Data parallel partitioner.""" def __init__(self, data_axis: str = "batch"): - self.mesh = jax.sharding.Mesh(jax.devices(), (data_axis,)) + self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,)) self.data_sharding = jax.sharding.NamedSharding( self.mesh, jax.sharding.PartitionSpec(data_axis) ) @@ -109,6 +108,12 @@ def partition_init( self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None ) -> CreateStateFn: with jax.sharding.use_mesh(self.mesh): + if abstract_batch is not None: + abstract_state = jax.eval_shape(init_fn, abstract_batch) + specs = nn.get_partition_spec(abstract_state) + self.state_sharding = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self.mesh, x), specs + ) init_fn = jax.jit(init_fn, out_shardings=self.state_sharding) def _wrapped_init(batch: PyTree) -> State: @@ -145,12 +150,12 @@ class ModelParallelPartitioner(Partitioner): This only works with multi-controller Jax, i.e. communications along the ICI for TPUs. For scaling beyond a single TPU slice this needs to be extended to support Megascale XLA or single-controller Pathways. Consider using T5X, Pax, - or Gemax for these use cases. + MaxText externally or Gemax internally for these use cases. - Note: This assumes that all axes of the inputs except the final one are used - for data parallelism while the final one is used for model parallelism. - This tends to work well for 2D and 3D torus topologies since network latency - tends to be much higher for the leading axes. + By default, all axes of the input are used for data parallelism. This results + in fully-sharded data-parallelism for ND topologies or data-parallelism for 1D + topologies. The range of axes can be configured using the `dp_axes` argument, + i.e. axes[:dp_axes] will be used for data parallelism. 
IMPORTANT: `shard_inputs` operates on a per process batch. This means that the input batch size on CPU must already be the per process batch size, @@ -160,45 +165,49 @@ class ModelParallelPartitioner(Partitioner): def __init__( self, - axes: Sequence[tuple[str, int]], + axes: Sequence[tuple[str, int]] = (("batch", -1),), + dp_axes: int | None = None, rules: Mapping[str, str] | None = None, aot_compile: bool = False, options: jax.stages.CompilerOptions | None = None, + devices: Sequence[jax.Device] | None = None, ): - if len(axes) < 2: + if not axes: + raise ValueError("At least one axis must be specified in `axes`.") + if dp_axes == 0: + raise ValueError( + "Data parallelism axes range must be positive or negative." + ) + + devices = devices if devices is not None else jax.devices() + axis_names = [axis for axis, _ in axes] + axis_sizes = [dim for _, dim in axes] + if any(dim <= 0 for dim in axis_sizes[1:]): raise ValueError( - "`axes` cannot less than 2D, use data-parallel" - f" partitioner instead. Got axes: {axes}." + "All dimensions except the first in the axes must be positive" + f" integers. Got axes: {axes}." ) + if axis_sizes[0] == -1: + axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:]) - mesh_devices = mesh_utils.create_device_mesh([dim for _, dim, in axes]) - self.mesh = jax.sharding.Mesh(mesh_devices, [axis for axis, _ in axes]) + self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices) self.rules = rules self.aot_compile = aot_compile self.options = options - dp_axes, dp_dims = zip(*axes[:-1]) - _, mp_dim = axes[-1] - - if math.prod(dp_dims) % jax.process_count() != 0: + dp_axis_names, dp_axis_sizes = zip(*axes[:dp_axes]) + num_processes = jax.process_count() + if math.prod(dp_axis_sizes) % num_processes != 0: raise ValueError( "The data parallel dimensions in the mesh must be divisible by the" " number of processes as we assume data parallelism across" - f" processes. Got process count: {jax.process_count()} and data" - f" parallelism dimensions: {dp_dims} for axes: {axes} and mesh" - f" devices: {self.mesh.devices}." - ) - if jax.local_device_count() % mp_dim != 0: - raise ValueError( - "The number of local devices on each host must be divisible by the" - " model dimension as we assume model parallelism across local" - f" devices. Got local device count: {jax.local_device_count()} and" - f" model parallelism dimension: {mp_dim} for axes: {axes} and mesh" + f" processes. Got process count: {num_processes} and data" + f" parallelism dimensions: {dp_axis_sizes} for axes: {axes} and mesh" f" devices: {self.mesh.devices}." 
       )
     self.data_sharding = jax.sharding.NamedSharding(
-        self.mesh, jax.sharding.PartitionSpec(dp_axes)
+        self.mesh, jax.sharding.PartitionSpec(dp_axis_names)
     )
     self.state_sharding = None
     self.abstract_batch = None
diff --git a/recml/core/training/partitioning_test.py b/recml/core/training/partitioning_test.py
index 55d286f..5fa95c6 100644
--- a/recml/core/training/partitioning_test.py
+++ b/recml/core/training/partitioning_test.py
@@ -40,7 +40,7 @@ def test_data_parallelism(
       self, partitioner_cls: type[partitioning.Partitioner]
   ):
     if partitioner_cls is partitioning.ModelParallelPartitioner:
-      kwargs = {"axes": [("data", jax.device_count()), ("model", 1)]}
+      kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1}
     else:
       kwargs = {}
     partitioner = partitioner_cls(**kwargs)
@@ -113,7 +113,7 @@ def _eval_step(
 
   def test_model_parallelism(self):
     partitioner = partitioning.ModelParallelPartitioner(
-        axes=[("data", 1), ("model", jax.device_count())]
+        axes=[("data", 1), ("model", jax.device_count())], dp_axes=1
     )
     inputs = np.zeros((128, 16), dtype=np.float32)
diff --git a/recml/examples/dlrm_experiment.py b/recml/examples/dlrm_experiment.py
index 1f09d60..a373158 100644
--- a/recml/examples/dlrm_experiment.py
+++ b/recml/examples/dlrm_experiment.py
@@ -99,19 +99,19 @@ class DLRMModel(nn.Module):
   dcn_layers: int
   dcn_inner_dim: int
 
-  # We need to track the embedder on the Flax module to ensure it is not
-  # re-created on cloning. It is not possible to create an embedder inside
-  # setup() because it is called lazily at compile time. The embedder needs
+  # We need to track the sparsecore config on the Flax module to ensure it is
+  # not re-created on cloning. It is not possible to create a config inside
+  # setup() because it is called lazily at compile time. The config needs
   # to be created before `model.init` so we can use it to create a preprocessor.
-  # A simpler pattern that works is passing `embedder` directly to the module.
-  _embedder: sparsecore.SparsecoreEmbedder | None = None
+  # A simpler pattern that works is passing the config directly to the module.
+ _sparsecore_config: sparsecore.SparsecoreConfig | None = None @property - def embedder(self) -> sparsecore.SparsecoreEmbedder: - if self._embedder is not None: - return self._embedder + def sparsecore_config(self) -> sparsecore.SparsecoreConfig: + if self._sparsecore_config is not None: + return self._sparsecore_config - embedder = sparsecore.SparsecoreEmbedder( + sparsecore_config = sparsecore.SparsecoreConfig( specs={ f.name: sparsecore.EmbeddingSpec( input_dim=f.vocab_size, @@ -123,8 +123,8 @@ def embedder(self) -> sparsecore.SparsecoreEmbedder: }, optimizer=self.embedding_optimizer, ) - object.__setattr__(self, '_embedder', embedder) - return embedder + object.__setattr__(self, '_sparsecore_config', sparsecore_config) + return sparsecore_config def bottom_mlp(self, inputs: Mapping[str, jt.Array]) -> jt.Array: x = jnp.concatenate( @@ -174,7 +174,9 @@ def __call__( self, inputs: Mapping[str, jt.Array], training: bool = False ) -> jt.Array: dense_embeddings = self.bottom_mlp(inputs) - sparse_embeddings = self.embedder.make_sparsecore_module()(inputs) + sparse_embeddings = sparsecore.SparsecoreEmbed( + self.sparsecore_config, name='sparsecore_embed' + )(inputs) sparse_embeddings = jax.tree.flatten(sparse_embeddings)[0] concatenated_embeddings = jnp.concatenate( (dense_embeddings, *sparse_embeddings), axis=-1 @@ -239,11 +241,15 @@ def create_datasets(self) -> tuple[recml.data.Iterator, recml.data.Iterator]: global_batch_size = self.train_data.global_batch_size train_iter = recml.data.TFDatasetIterator( dataset=self.train_data.make(), - postprocessor=self.model.embedder.make_preprocessor(global_batch_size), + postprocessor=sparsecore.SparsecorePreprocessor( + self.model.sparsecore_config, global_batch_size + ), ) eval_iter = recml.data.TFDatasetIterator( dataset=self.eval_data.make(), - postprocessor=self.model.embedder.make_preprocessor(global_batch_size), + postprocessor=sparsecore.SparsecorePreprocessor( + self.model.sparsecore_config, global_batch_size + ), ) return train_iter, eval_iter diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py index ad59cfc..6c02b90 100644 --- a/recml/layers/linen/sparsecore.py +++ b/recml/layers/linen/sparsecore.py @@ -22,24 +22,19 @@ from etils import epy from flax import linen as nn -from flax import typing import jax -from jax.experimental import layout import jax.numpy as jnp import numpy as np from recml.core.ops import embedding_ops import tensorflow as tf -if jax.__version_info__ >= (0, 6, 3): - DLL = layout.Layout -else: - DLL = layout.DeviceLocalLayout # type: ignore - with epy.lazy_imports(): # pylint: disable=g-import-not-at-top + from jax_tpu_embedding.sparsecore.lib.flax import embed from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + from jax_tpu_embedding.sparsecore.lib.nn import table_stacking from jax_tpu_embedding.sparsecore.utils import utils # pylint: enable=g-import-not-at-top @@ -49,6 +44,7 @@ OptimizerSpec = Any +# TODO(aahil): This should be common between Keras, Flax, NNX. @dataclasses.dataclass class EmbeddingSpec: """Sparsecore embedding spec. @@ -92,12 +88,15 @@ def __hash__(self): @dataclasses.dataclass -class SparsecoreEmbedder: - """Sparsecore embedder. +class SparsecoreConfig: + """Sparsecore embedding configuration. Attributes: specs: A mapping from feature name to embedding specs. optimizer: The default optimizer to use for the embedding variables. + sharding_axis: The axis to use for sharding the embedding table. 
Can be
+      either an integer mesh axis index or a string mesh axis name. Defaults
+      to axis 0.
     sharding_strategy: The sharding strategy to use for the embedding table.
       Defaults to 'MOD' sharding. See the sparsecore documentation for more
       details.
@@ -105,17 +104,29 @@ class SparsecoreEmbedder:
     fixed mapping is used to determine this based on device 0. This may fail
     on newer TPU architectures if the mapping is not updated of if device 0
     is not a TPU device with a sparsecore.
+    max_ids_per_partition_fn: A function that accepts the name of the table
+      and its input size and returns the maximum number of IDs to process per
+      partition. Defaults to the size of the inputs.
+    max_unique_ids_per_partition_fn: A function that accepts the name of the
+      table and its input size and returns the maximum number of unique IDs
+      to process on each partition. Defaults to the size of the inputs.
+    local_device_count: The number of Jax devices in the local process.
+      Defaults to `jax.local_device_count`.
+    global_device_count: The number of Jax devices across all processes.
+      Defaults to `jax.device_count`.
+    num_sc_per_device: The number of sparsecores per Jax device. If not set,
+      tries to fetch it from a fixed mapping.
 
   Example usage:
 
   ```python
   class DLRMModel(nn.Module):
-    # The embedder must be a property of the Flax model and cannot be created
+    # The config must be a property of the Flax model and cannot be created
     # inside setup().
-    embedder: sparsecore.SparsecoreEmbedder
+    sparsecore_config: sparsecore.SparsecoreConfig
     ...
 
     def setup(self):
-      self.sparsecore_module = self.embedder.make_sparsecore_module()
+      self.sparsecore_module = SparsecoreEmbed(self.sparsecore_config)
       ...
 
     def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
@@ -123,10 +134,10 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
       ...
 
   # Instantiate the model and the embedder.
-  model = DLRMModel(embedder=embedder)
+  model = DLRMModel(sparsecore_config, ...)
 
   # Create the eager preprocessor.
-  preprocessor = model.embedder.make_preprocessor(global_batch_size)
+  preprocessor = SparsecorePreprocessor(sparsecore_config, global_batch_size)
 
   # Fetch and preprocess the inputs on CPU.
   inputs = ...
@@ -138,26 +149,70 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
   sharded_inputs = ...
 
   # Initialize and call the model on TPU inside JIT as usual.
-  vars = model.init(jax.random.key(0), sharded_inputs)
+  vars = model.init(jax.random.key(0), ...)
   embedding_activations = model.apply(vars, sharded_inputs)
   ```
   """
 
   specs: Mapping[str, EmbeddingSpec]
   optimizer: OptimizerSpec
+  sharding_axis: str | int = 0
   sharding_strategy: str = 'MOD'
 
-  def __post_init__(self):
-    self._feature_specs = None
-    self._global_batch_size = None
-    self._num_sc_per_device = utils.num_sparsecores_per_device()
+  # TODO(aahil): Come up with better defaults / heuristics here.
+  max_ids_per_partition_fn: Callable[[str, int], int] = dataclasses.field(
+      default=lambda n, bs: bs
+  )
+  max_unique_ids_per_partition_fn: Callable[[str, int], int] = (
+      dataclasses.field(default=lambda n, bs: bs)
+  )
+
+  # Optional device information.
+ local_device_count: int = dataclasses.field( + default_factory=jax.local_device_count + ) + global_device_count: int = dataclasses.field(default_factory=jax.device_count) + num_sc_per_device: int = dataclasses.field( + default_factory=utils.num_sparsecores_per_device + ) + + _feature_specs: Mapping[str, embedding_ops.FeatureSpec] | None = ( + dataclasses.field(init=False, default=None) + ) + _global_batch_size: int | None = dataclasses.field(init=False, default=None) - def _init_feature_specs( - self, batch_size: int - ) -> Mapping[str, embedding_ops.FeatureSpec]: + @property + def feature_specs(self) -> Mapping[str, embedding_ops.FeatureSpec]: """Returns the feature specs for sparsecore embedding lookup.""" - if self._feature_specs is not None: - return self._feature_specs + if self._feature_specs is None: + raise ValueError( + 'The feature specs are not initialized. Make sure to call' + ' `init_feature_specs` before accessing the' + ' feature specs.' + ) + return self._feature_specs + + @property + def global_batch_size(self) -> int: + """Returns the global batch size for sparsecore embedding lookup.""" + if self._global_batch_size is None: + raise ValueError( + 'The global batch size is not initialized. Make sure to call' + ' `init_feature_specs` before accessing the' + ' global batch size.' + ) + return self._global_batch_size + + def init_feature_specs(self, batch_size: int): + """Creates the feature specs for sparsecore embedding lookup.""" + if self._feature_specs is not None and self._global_batch_size is not None: + if batch_size != self._global_batch_size: + raise ValueError( + 'The batch size is already initialized to' + f' {self._global_batch_size}. It cannot be changed to' + f' {batch_size}.' + ) + return feature_specs = {} shared_tables = {} @@ -194,28 +249,40 @@ def _init_feature_specs( embedding.auto_stack_tables( feature_specs, - jax.device_count(), - self._num_sc_per_device, - stack_to_max_ids_per_partition=lambda n, bs: bs, - stack_to_max_unique_ids_per_partition=lambda n, bs: bs, + self.global_device_count, + self.num_sc_per_device, + stack_to_max_ids_per_partition=self.max_ids_per_partition_fn, + stack_to_max_unique_ids_per_partition=self.max_unique_ids_per_partition_fn, ) embedding.prepare_feature_specs_for_training( feature_specs, - jax.device_count(), - self._num_sc_per_device, + self.global_device_count, + self.num_sc_per_device, ) self._feature_specs = feature_specs self._global_batch_size = batch_size - return feature_specs - def make_preprocessor(self, batch_size: int) -> Callable[..., Any]: + +@dataclasses.dataclass +class SparsecorePreprocessor: + """Preprocessor for sparsecore embedding lookup. + + Attributes: + sparsecore_config: The sparsecore config used to create the tables. + global_batch_size: The global batch size across all devices to partition the + inputs across. + """ + + sparsecore_config: SparsecoreConfig + global_batch_size: int + + def __post_init__(self): + self.sparsecore_config.init_feature_specs(self.global_batch_size) + + def __call__( + self, inputs: Mapping[str, Any] | tuple[Mapping[str, Any], ...] 
+ ) -> Mapping[str, Any] | tuple[Mapping[str, Any], ...]: """Returns a preprocessor for sparsecore embedding lookup.""" - feature_specs = self._init_feature_specs(batch_size) - weights_names = { - name: spec.weight_name - for name, spec in self.specs.items() - if spec.weight_name is not None - } def _to_np(x: Any) -> np.ndarray: if isinstance(x, np.ndarray): @@ -223,86 +290,76 @@ def _to_np(x: Any) -> np.ndarray: if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): raise NotImplementedError( 'Sparsecore embedding layer does not support sparse or' - ' raggedtensors.' + ' ragged tensors yet.' ) if isinstance(x, tf.Tensor): - return x.numpy() + return x.numpy() # pylint: disable=attribute-error if isinstance(x, jax.Array): return jax.device_get(x) return np.array(x) - def _preprocessor(inputs): - if isinstance(inputs, tuple): - inputs, *rem = inputs + if isinstance(inputs, tuple): + rem = inputs[1:] + inputs = inputs[0] + else: + rem = None + + sparse_features = set() + features = {} + weights = {} + for key in self.sparsecore_config.feature_specs: + features[key] = _to_np(inputs[key]) + sparse_features.add(key) + if self.sparsecore_config.specs[key].weight_name is not None: + weights[key] = _to_np( + inputs[self.sparsecore_config.specs[key].weight_name] + ) + sparse_features.add(self.sparsecore_config.specs[key].weight_name) else: - rem = None - - sparse_features = set() - features = {} - weights = {} - for key in feature_specs: - features[key] = _to_np(inputs[key]) - sparse_features.add(key) - if key in weights_names: - weights[key] = _to_np(inputs[weights_names[key]]) - sparse_features.add(weights_names[key]) - else: - weights[key] = np.ones_like(features[key]) - - csr_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( - features=features, - features_weights=weights, - feature_specs=feature_specs, - local_device_count=jax.local_device_count(), - global_device_count=jax.device_count(), - num_sc_per_device=self._num_sc_per_device, - sharding_strategy=self.sharding_strategy, - allow_id_dropping=False, - ) - - processed_inputs = { - k: v for k, v in inputs.items() if k not in sparse_features - } - processed_inputs[CSR_INPUTS_KEY] = csr_inputs - - if rem is not None: - processed_inputs = (processed_inputs, *rem) - return processed_inputs + weights[key] = np.ones_like(features[key]) + + if self.sparsecore_config.specs[key].max_sequence_length is not None: + features[key] = np.reshape(features[key], (-1, 1)) + if weights[key] is not None: + weights[key] = np.reshape(weights[key], (-1, 1)) + + csr_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + features=features, + features_weights=weights, + feature_specs=self.sparsecore_config.feature_specs, + local_device_count=self.sparsecore_config.local_device_count, + global_device_count=self.sparsecore_config.global_device_count, + num_sc_per_device=self.sparsecore_config.num_sc_per_device, + sharding_strategy=self.sparsecore_config.sharding_strategy, + allow_id_dropping=False, + ) - return _preprocessor + processed_inputs = { + k: v for k, v in inputs.items() if k not in sparse_features + } + processed_inputs[CSR_INPUTS_KEY] = csr_inputs - def make_sparsecore_module(self, **kwargs) -> _SparsecoreEmbed: - """Returns the sparsecore embedding layer.""" - if self._feature_specs is None or self._global_batch_size is None: - raise ValueError( - 'The feature specs are not initialized. Make sure to call' - ' `make_preprocessor` before calling `sparsecore_layer`.' 
- ) + if rem is not None: + processed_inputs = (processed_inputs, *rem) + return processed_inputs - def _key(k: str | tuple[str, str]) -> str: - return k[0] if isinstance(k, tuple) else k - return _SparsecoreEmbed( - feature_specs=self._feature_specs, - global_batch_size=self._global_batch_size, - sharding_axis=0, - sharding_strategy=self.sharding_strategy, - num_sc_per_device=self._num_sc_per_device, - **kwargs, - ) +class SparsecoreEmbed(nn.Module): + """Sparsecore embedding layer. + Attributes: + sparsecore_config: A sparsecore config specifying how to create the tables. + mesh: The mesh to use for the embedding layer. If not provided, the global + mesh set by `jax.sharding.use_mesh` will be used. If neither is set, + an error will be raised. + """ -class _SparsecoreEmbed(nn.Module): - """Sparsecore embedding layer.""" - - feature_specs: embedding_ops.Nested[embedding_ops.FeatureSpec] - global_batch_size: int - sharding_axis: str | int - sharding_strategy: str - num_sc_per_device: int + sparsecore_config: SparsecoreConfig + mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None - @property - def abstract_mesh(self) -> jax.sharding.AbstractMesh: + def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh: + if self.mesh is not None: + return self.mesh abstract_mesh = jax.sharding.get_abstract_mesh() if not abstract_mesh.shape_tuple: raise ValueError( @@ -311,32 +368,33 @@ def abstract_mesh(self) -> jax.sharding.AbstractMesh: ) return abstract_mesh - @property - def sharding_axis_name(self) -> str: - if isinstance(self.sharding_axis, int): - return self.abstract_mesh.axis_names[self.sharding_axis] - return self.sharding_axis - - @property - def num_shards(self) -> int: - return self.abstract_mesh.shape[self.sharding_axis_name] + def get_sharding_axis( + self, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh + ) -> str: + if isinstance(self.sparsecore_config.sharding_axis, int): + return mesh.axis_names[self.sparsecore_config.sharding_axis] + return self.sparsecore_config.sharding_axis def setup(self): + mesh = self.get_mesh() + sharding_axis_name = self.get_sharding_axis(mesh) + initializer = functools.partial( embedding.init_embedding_variables, - table_specs=embedding.get_table_specs(self.feature_specs), + table_specs=embedding.get_table_specs( + self.sparsecore_config.feature_specs + ), global_sharding=jax.sharding.NamedSharding( - self.abstract_mesh, - jax.sharding.PartitionSpec(self.sharding_axis_name, None), + mesh, jax.sharding.PartitionSpec(sharding_axis_name, None) ), - num_sparsecore_per_device=self.num_sc_per_device, - # We need to by-pass the mesh check to use the abstract mesh. - bypass_mesh_check=True, + num_sparsecore_per_device=self.sparsecore_config.num_sc_per_device, + # We need to by-pass the mesh check to allow using an abstract mesh. + bypass_mesh_check=isinstance(mesh, jax.sharding.AbstractMesh), ) self.embedding_table = self.param( name=EMBEDDING_PARAM_NAME, - init_fn=_with_sparsecore_layout( - initializer, (self.sharding_axis_name,), self.abstract_mesh + init_fn=embed.with_sparsecore_layout( + initializer, (sharding_axis_name,), mesh # type: ignore ), ) @@ -353,13 +411,15 @@ def __call__( Returns: The activations structure with the same structure as specs. 
""" + mesh = self.get_mesh() + sharding_axis_name = self.get_sharding_axis(mesh) activations = embedding_ops.sparsecore_lookup( embedding_ops.SparsecoreParams( - feature_specs=self.feature_specs, - abstract_mesh=self.abstract_mesh, - data_axes=(self.sharding_axis_name,), - embedding_axes=(self.sharding_axis_name, None), - sharding_strategy=self.sharding_strategy, + feature_specs=self.sparsecore_config.feature_specs, + mesh=mesh, + data_axes=(sharding_axis_name,), + embedding_axes=(sharding_axis_name, None), + sharding_strategy=self.sparsecore_config.sharding_strategy, ), self.embedding_table, inputs[CSR_INPUTS_KEY], @@ -368,12 +428,12 @@ def __call__( # Reshape the activations if the batch size is not the same as the global # batch size. def _maybe_reshape_activation(activation: jax.Array) -> jax.Array: - if activation.shape[0] != self.global_batch_size: + if activation.shape[0] != self.sparsecore_config.global_batch_size: return jnp.reshape( activation, ( - self.global_batch_size, - activation.shape[0] // self.global_batch_size, + self.sparsecore_config.global_batch_size, + activation.shape[0] // self.sparsecore_config.global_batch_size, activation.shape[1], ), ) @@ -382,23 +442,182 @@ def _maybe_reshape_activation(activation: jax.Array) -> jax.Array: return jax.tree.map(_maybe_reshape_activation, activations) -class SparsecoreLayout(nn.Partitioned[A]): +def gather_table( + sparsecore_config: SparsecoreConfig, + sc_params: Mapping[str, embedding.EmbeddingVariables], + name: str, +) -> jax.Array: + """Gathers a table from a stacked table on device. + + Args: + sparsecore_config: The sparsecore config used to create the tables. + sc_params: A mapping from table name to the embedding variables. This must + consist of the Flax variables corresponding to the sparsecore module + returned by `SparsecoreEmbed`. + name: The name of the table to gather. + + Returns: + The unstacked and unsharded embedding table on device. 
+  """
-  def get_sharding(self, _):
-    assert self.mesh is not None
-    return layout.Format(
-        DLL(major_to_minor=(0, 1), _tiling=((8,),)),
-        jax.sharding.NamedSharding(self.mesh, self.get_partition_spec()),
+  embedding_specs = embedding.create_proto_from_feature_specs(
+      feature_specs=sparsecore_config.feature_specs,
+      global_device_count=sparsecore_config.global_device_count,
+      num_sparsecore_per_device=sparsecore_config.num_sc_per_device,
+  )
+
+  for stacked_table_spec in embedding_specs.stacked_table_specs:
+    table_specs = {ts.table_name: ts for ts in stacked_table_spec.table_specs}
+    if f'{name}_table' not in table_specs:
+      continue
+
+    table_spec = table_specs[f'{name}_table']
+
+    num_sparsecores = stacked_table_spec.num_sparsecores
+    stacked_embedding_dim = stacked_table_spec.stack_embedding_dim
+    row_offset = table_spec.row_offset_in_shard
+    chunk_size = table_spec.padded_vocab_size // num_sparsecores
+    rotation = table_spec.shard_rotation
+    vocab_size = table_spec.vocab_size
+    embedding_dim = table_spec.embedding_dim
+
+    stacked_table = sc_params[stacked_table_spec.stack_name]
+    stacked_table_3d = jnp.reshape(
+        stacked_table.table, (num_sparsecores, -1, stacked_embedding_dim)
     )
+    shards = stacked_table_3d[:, row_offset : row_offset + chunk_size, :]
+
+    # Undo the shard rotation (note '-' for reverse direction)
+    shards = jnp.roll(shards, -rotation, axis=0)
+
+    # Undo the mod sharding
+    un_mod_shard = shards.transpose((1, 0, 2))
+
+    # Merge the interleaved per-sparsecore shards back into a single table
+    table = un_mod_shard.reshape(-1, stacked_embedding_dim)
-def _with_sparsecore_layout(
-    fn: Callable[..., Any],
-    names: typing.LogicalNames,
-    abstract_mesh: jax.sharding.AbstractMesh,
-):
-  @functools.wraps(fn)
-  def wrapper(*args, **kwargs):
-    return SparsecoreLayout(fn(*args, **kwargs), names, mesh=abstract_mesh)  # pytype: disable=wrong-arg-types
+    # Remove paddings
+    table = table[:vocab_size, :embedding_dim]
-    return wrapper
+    return table
+
+  raise ValueError(
+      f'Table {name} not found in feature specs'
+      f' {sparsecore_config.feature_specs}.'
+  )
+
+
+def fetch_tables(
+    sparsecore_config: SparsecoreConfig,
+    sc_params: Mapping[str, embedding.EmbeddingVariables],
+    as_tf_variables: bool = True,
+    donate: bool = True,
+) -> Mapping[str, jax.Array] | Mapping[str, tf.Variable]:
+  """Unstacks and unshards the stacked tables and fetches them to the host.
+
+  Args:
+    sparsecore_config: The sparsecore config used to create the tables.
+    sc_params: A mapping from table name to the embedding variables. This must
+      consist of the Flax variables corresponding to the sparsecore module
+      returned by `SparsecoreEmbed`.
+    as_tf_variables: Whether to return the tables as TF variables. Defaults to
+      True.
+    donate: Whether to donate the stacked tables, i.e. remove them from device
+      HBM, to save memory. Defaults to True.
+
+  Returns:
+    A mapping from table name to the unstacked and unsharded embedding table
+    on the host.
+  """
+  tables = table_stacking.unstack_and_unshard_stacked_tables(
+      stacked_tables={k: v.table for k, v in sc_params.items()},
+      embedding_specs=embedding.create_proto_from_feature_specs(
+          feature_specs=sparsecore_config.feature_specs,
+          global_device_count=sparsecore_config.global_device_count,
+          num_sparsecore_per_device=sparsecore_config.num_sc_per_device,
+      ),
+      donate=donate,
+  )
+  tables = {k.removesuffix('_table'): v for k, v in tables.items()}
+  if as_tf_variables:
+    tables = {
+        k: tf.Variable(tf.convert_to_tensor(v)) for k, v in tables.items()
+    }
+  return tables
+
+
+def cpu_lookup(
+    sparsecore_config: SparsecoreConfig,
+    tables: Mapping[str, tf.Variable],
+    inputs: Mapping[str, tf.Tensor | tf.SparseTensor | tf.RaggedTensor],
+) -> Mapping[str, tf.Tensor]:
+  """Performs embedding lookups on the host.
+
+  Args:
+    sparsecore_config: The sparsecore config used to create the tables.
+    tables: A mapping of the embedding tables on the host. This must be in the
+      same format as the output of `fetch_tables`.
+    inputs: A mapping of the input features on the host. This must be in the
+      same format as the input to `SparsecorePreprocessor`.
+
+  Returns:
+    A mapping of the embedding activations on the host. This has the same
+    structure as the output of the sparsecore module.
+  """
+
+  activations = {}
+  for name, spec in sparsecore_config.specs.items():
+    feature = inputs[name]
+    weight = inputs[spec.weight_name] if spec.weight_name is not None else None
+    if isinstance(feature, tf.Tensor):
+      activation = tf.nn.embedding_lookup(tables[name], feature)
+
+      if spec.max_sequence_length is None:
+        activation = _reduce(activation, weight, spec.combiner)
+
+      activations[name] = activation
+    else:
+      raise NotImplementedError(
+          'Sparsecore embedding layer does not support sparse or ragged'
+          ' tensors yet.'
+      )
+
+  return activations
+
+
+def _reduce(
+    inputs: tf.Tensor | tf.RaggedTensor,
+    weights: tf.Tensor | tf.RaggedTensor | None = None,
+    combiner: Literal['sum', 'mean', 'sqrtn'] = 'mean',
+) -> tf.Tensor:
+  """Performs a weighted reduction across the penultimate dimension of a tensor.
+
+  Args:
+    inputs: A dense or ragged tensor of shape [D_1, ..., D_N] to reduce.
+    weights: Optional weights to apply to the reduction. If given, the
+      dimensions of the weights must be [D_1, ..., D_{N-1}], i.e. they must
+      match the inputs up to the penultimate dimension. Note that if the
+      inputs have ragged dimensions, the weights must have the same ragged
+      dimensions.
+    combiner: The combiner to use for the reduction. Can be one of ['sum',
+      'mean', 'sqrtn'].
+
+  Returns:
+    The reduced inputs of rank N - 1.
+  """
+  if weights is not None:
+    weights = tf.expand_dims(tf.cast(weights, inputs.dtype), axis=-1)
+    inputs = inputs * weights
+  else:
+    # Assume unit weights when none are given so 'mean' and 'sqrtn' still work.
+    weights = tf.ones_like(inputs)
+
+  out = tf.reduce_sum(inputs, axis=-2)
+  if combiner == 'mean':
+    weight_sum = tf.reduce_sum(weights, axis=-2)
+    out = tf.math.divide_no_nan(out, weight_sum)
+  elif combiner == 'sqrtn':
+    weight_sum = tf.math.sqrt(tf.reduce_sum(weights**2, axis=-2))
+    out = tf.math.divide_no_nan(out, weight_sum)
+  elif combiner != 'sum':
+    raise ValueError("`combiner` must be one of ['mean', 'sqrtn', 'sum'].")
+
+  return out
diff --git a/recml/layers/linen/sparsecore_test.py b/recml/layers/linen/sparsecore_test.py
index 8728610..6ec24f0 100644
--- a/recml/layers/linen/sparsecore_test.py
+++ b/recml/layers/linen/sparsecore_test.py
@@ -18,8 +18,10 @@
 from absl.testing import absltest
 from etils import epy
 import jax
+import numpy as np
 from recml.core.training import partitioning
 from recml.layers.linen import sparsecore
+import tensorflow as tf
 
 with epy.lazy_imports():
   from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec  # pylint: disable=g-import-not-at-top
@@ -31,38 +33,43 @@ def test_sparsecore_embedder_equivalence(self):
     if jax.devices()[0].platform != "tpu":
       self.skipTest("Test only supported on TPUs.")
 
-    k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
+    tf.random.set_seed(0)
     inputs = {
-        "a": jax.random.randint(k1, (32, 16), minval=1, maxval=100),
-        "b": jax.random.randint(k2, (32, 16), minval=1, maxval=100),
-        "w": jax.random.normal(k3, (32, 16)),
+        "a": tf.random.uniform(
+            shape=(32, 16), minval=1, maxval=100, dtype=tf.int64
+        ),
+        "b": tf.random.uniform(
+            shape=(32, 16), minval=1, maxval=100, dtype=tf.int64
+        ),
+        "w": tf.random.normal(shape=(32, 16), dtype=tf.float32),
     }
     dp_partitioner = partitioning.DataParallelPartitioner()
-    embedder = sparsecore.SparsecoreEmbedder(
+    sparsecore_config = sparsecore.SparsecoreConfig(
        specs={
            "a": sparsecore.EmbeddingSpec(
                input_dim=100,
-                embedding_dim=16,
+                embedding_dim=64,
                combiner="mean",
                weight_name="w",
            ),
            "b": sparsecore.EmbeddingSpec(
                input_dim=100,
-                embedding_dim=16,
-                max_sequence_length=10,
+                embedding_dim=64,
+                max_sequence_length=16,
            ),
        },
        optimizer=embedding_spec.AdagradOptimizerSpec(learning_rate=0.01),
     )
-    preprocessor = embedder.make_preprocessor(32)
-    layer = embedder.make_sparsecore_module()
+    preprocessor = sparsecore.SparsecorePreprocessor(sparsecore_config, 32)
+    layer = sparsecore.SparsecoreEmbed(sparsecore_config)
 
     sc_inputs = dp_partitioner.shard_inputs(preprocessor(inputs))
-    sc_vars = dp_partitioner.partition_init(functools.partial(layer.init, k4))(
-        sc_inputs
-    )
+    sc_vars = dp_partitioner.partition_init(
+        functools.partial(layer.init, jax.random.key(0)),
+        abstract_batch=sc_inputs,
+    )(sc_inputs)
 
     def step(inputs, params):
       return layer.apply(params, inputs)
@@ -70,8 +77,33 @@ def step(inputs, params):
     p_step = dp_partitioner.partition_step(step, training=False)
     sparsecore_activations = jax.device_get(p_step(sc_inputs, sc_vars))
 
-    self.assertEqual(sparsecore_activations["a"].shape, (32, 16))
-    self.assertEqual(sparsecore_activations["b"].shape, (32, 10, 16))
+    self.assertEqual(sparsecore_activations["a"].shape, (32, 64))
+    self.assertEqual(sparsecore_activations["b"].shape, (32, 16, 64))
+
+    tables = sparsecore.fetch_tables(
+        sparsecore_config,
+        sc_vars["params"][sparsecore.EMBEDDING_PARAM_NAME],
+        donate=False,
+    )
+
+    activations = sparsecore.cpu_lookup(sparsecore_config, tables, inputs)
+    np.testing.assert_allclose(
+        sparsecore_activations["a"], activations["a"], rtol=1e-5, atol=1e-5
+    )
+
np.testing.assert_allclose( + sparsecore_activations["b"], activations["b"], rtol=1e-5, atol=1e-5 + ) + + np.testing.assert_allclose( + sparsecore.gather_table( + sparsecore_config, + sc_vars["params"][sparsecore.EMBEDDING_PARAM_NAME], + "a", + ), + tables["a"], + rtol=1e-5, + atol=1e-5, + ) if __name__ == "__main__":
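
Reviewer note (not part of the patch): below is a minimal end-to-end sketch of the reworked API, condensed from the updated `SparsecoreConfig` docstring example and `sparsecore_test.py` above. The feature names, shapes, and Adagrad optimizer spec mirror the test; a TPU with sparsecores is assumed, and this is an illustration rather than additional code in the change.

```python
import functools

import jax
import tensorflow as tf

from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
from recml.core.training import partitioning
from recml.layers.linen import sparsecore

# One config now drives preprocessing, the Flax layer, and the host helpers.
sparsecore_config = sparsecore.SparsecoreConfig(
    specs={
        "a": sparsecore.EmbeddingSpec(
            input_dim=100, embedding_dim=64, combiner="mean", weight_name="w"
        ),
    },
    optimizer=embedding_spec.AdagradOptimizerSpec(learning_rate=0.01),
)

# Eager CPU preprocessing turns raw id/weight features into CSR inputs.
preprocessor = sparsecore.SparsecorePreprocessor(sparsecore_config, 32)
inputs = {
    "a": tf.random.uniform(shape=(32, 16), minval=1, maxval=100, dtype=tf.int64),
    "w": tf.random.normal(shape=(32, 16), dtype=tf.float32),
}

partitioner = partitioning.DataParallelPartitioner()
layer = sparsecore.SparsecoreEmbed(sparsecore_config)

sc_inputs = partitioner.shard_inputs(preprocessor(inputs))
sc_vars = partitioner.partition_init(
    functools.partial(layer.init, jax.random.key(0)),
    abstract_batch=sc_inputs,
)(sc_inputs)

step = partitioner.partition_step(lambda x, p: layer.apply(p, x), training=False)
activations = jax.device_get(step(sc_inputs, sc_vars))  # activations["a"]: (32, 64)

# Export the tables to the host and run an equivalent CPU lookup for parity checks.
tables = sparsecore.fetch_tables(
    sparsecore_config,
    sc_vars["params"][sparsecore.EMBEDDING_PARAM_NAME],
    donate=False,
)
cpu_activations = sparsecore.cpu_lookup(sparsecore_config, tables, inputs)
```

This is the same flow that `dlrm_experiment.py` now uses: the config is threaded through `SparsecorePreprocessor` and `SparsecoreEmbed`, replacing the old `SparsecoreEmbedder.make_preprocessor` / `make_sparsecore_module` pair.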