Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions recml/examples/dlrm_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def sparsecore_config(self) -> sparsecore.SparsecoreConfig:
for f in self.features.sparse_features()
},
optimizer=self.embedding_optimizer,
allow_id_dropping=True,
)
object.__setattr__(self, '_sparsecore_config', sparsecore_config)
return sparsecore_config
Expand Down
5 changes: 4 additions & 1 deletion recml/layers/linen/sparsecore.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class SparsecoreConfig:
sharding_strategy: The sharding strategy to use for the embedding table.
Defaults to 'MOD' sharding. See the sparsecore documentation for more
details.
allow_id_dropping: Whether to allow dropping of IDs that do not fit within
the XLA buffers allocated for each partition. Defaults to False.
num_sc_per_device: The number of sparsecores per Jax device. By default, a
fixed mapping is used to determine this based on device 0. This may fail
on newer TPU architectures if the mapping is not updated or if device 0 is
Expand Down Expand Up @@ -163,6 +165,7 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
optimizer: OptimizerSpec
sharding_axis: str | int = 0
sharding_strategy: str = 'MOD'
allow_id_dropping: bool = False

# TODO(aahil): Come up with better defaults / heuristics here.
max_ids_per_partition_fn: Callable[[str, int], int] = dataclasses.field(
Expand Down Expand Up @@ -339,7 +342,7 @@ def _to_np(x: Any) -> np.ndarray:
global_device_count=self.sparsecore_config.global_device_count,
num_sc_per_device=self.sparsecore_config.num_sc_per_device,
sharding_strategy=self.sparsecore_config.sharding_strategy,
allow_id_dropping=False,
allow_id_dropping=self.sparsecore_config.allow_id_dropping,
batch_number=self._batch_number,
)

Expand Down
Loading