2828from recml .core .ops import embedding_ops
2929import tensorflow as tf
3030
31- from recml .core .training import mesh_context
32-
33-
3431with epy .lazy_imports ():
3532 # pylint: disable=g-import-not-at-top
3633 from jax_tpu_embedding .sparsecore .lib .flax .linen import embed
@@ -177,7 +174,6 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
177174 dataclasses .field (default = lambda n , bs : bs )
178175 )
179176
180- # Optional device information.
181177 local_device_count : int = dataclasses .field (
182178 default_factory = jax .local_device_count
183179 )
@@ -369,19 +365,9 @@ class SparsecoreEmbed(nn.Module):
369365 """
370366
371367 sparsecore_config : SparsecoreConfig
372- mesh : jax .sharding .Mesh | jax .sharding .AbstractMesh | None = None
373-
374- def get_mesh (self ) -> jax .sharding .Mesh :
375- # Try to get the mesh from our custom global context
376- mesh = mesh_context .get_global_mesh ()
377-
378- if mesh is None :
379- raise ValueError (
380- "No global mesh found. Make sure to call "
381- "`partitioning.partition_init` (which sets the mesh) "
382- "before initializing SparseCore."
383- )
384- return mesh
368+ mesh : jax .sharding .Mesh = dataclasses .field (
369+ default_factory = lambda : jax .sharding .Mesh (jax .devices (), ('batch' ,))
370+ )
385371
386372 def get_sharding_axis (
387373 self , mesh : jax .sharding .Mesh | jax .sharding .AbstractMesh
@@ -391,25 +377,25 @@ def get_sharding_axis(
391377 return self .sparsecore_config .sharding_axis
392378
393379 def setup (self ):
394- mesh = self .get_mesh ()
395- sharding_axis_name = self .get_sharding_axis (mesh )
396380
381+ sharding_axis_name = self .get_sharding_axis (self .mesh )
382+
397383 initializer = functools .partial (
398384 embedding .init_embedding_variables ,
399385 table_specs = embedding .get_table_specs (
400386 self .sparsecore_config .feature_specs
401387 ),
402388 global_sharding = jax .sharding .NamedSharding (
403- mesh , jax .sharding .PartitionSpec (sharding_axis_name , None )
389+ self . mesh , jax .sharding .PartitionSpec (sharding_axis_name , None )
404390 ),
405391 num_sparsecore_per_device = self .sparsecore_config .num_sc_per_device ,
406392 # We need to by-pass the mesh check to allow using an abstract mesh.
407- bypass_mesh_check = isinstance (mesh , jax .sharding .AbstractMesh ),
393+ bypass_mesh_check = isinstance (self . mesh , jax .sharding .AbstractMesh ),
408394 )
409395 self .embedding_table = self .param (
410396 name = EMBEDDING_PARAM_NAME ,
411397 init_fn = embed .with_sparsecore_layout (
412- initializer , (sharding_axis_name ,), mesh # type: ignore
398+ initializer , (sharding_axis_name ,), self . mesh # type: ignore
413399 ),
414400 )
415401
@@ -426,12 +412,13 @@ def __call__(
426412 Returns:
427413 The activations structure with the same structure as specs.
428414 """
429- mesh = self .get_mesh ()
430- sharding_axis_name = self .get_sharding_axis (mesh )
415+ # mesh = self.get_mesh()
416+ sharding_axis_name = self .get_sharding_axis (self .mesh )
417+
431418 activations = embedding_ops .sparsecore_lookup (
432419 embedding_ops .SparsecoreParams (
433420 feature_specs = self .sparsecore_config .feature_specs ,
434- mesh = mesh ,
421+ mesh = self . mesh ,
435422 data_axes = (sharding_axis_name ,),
436423 embedding_axes = (sharding_axis_name , None ),
437424 sharding_strategy = self .sparsecore_config .sharding_strategy ,
0 commit comments