Skip to content

Commit 72bb9d5

Browse files
committed
Removed the global mesh
1 parent ff021b7 commit 72bb9d5

4 files changed

Lines changed: 12 additions & 50 deletions

File tree

recml/core/training/mesh_context.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

recml/core/training/partitioning.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
import flax.linen as nn
2323
import jax
2424
import numpy as np
25-
from recml.core.training import mesh_context
26-
2725

2826
PyTree = Any
2927
State = Any
@@ -112,7 +110,6 @@ def partition_init(
112110
) -> CreateStateFn:
113111
with self.mesh:
114112
if abstract_batch is not None:
115-
mesh_context.set_global_mesh(self.mesh)
116113
abstract_state = jax.eval_shape(init_fn, abstract_batch)
117114
specs = nn.get_partition_spec(abstract_state)
118115
self.state_sharding = jax.tree.map(
@@ -135,7 +132,6 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
135132
jit_kws["donate_argnums"] = (1,)
136133

137134
with self.mesh:
138-
mesh_context.set_global_mesh(self.mesh)
139135
step_fn = jax.jit(
140136
fn,
141137
in_shardings=(self.data_sharding, self.state_sharding),
@@ -195,7 +191,6 @@ def __init__(
195191
if axis_sizes[0] == -1:
196192
axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])
197193

198-
# self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
199194
self.mesh = jax.sharding.Mesh(devices, axis_names)
200195
self.rules = rules
201196
self.aot_compile = aot_compile
@@ -235,7 +230,6 @@ def partition_init(
235230
)
236231

237232
with self.mesh:
238-
mesh_context.set_global_mesh(self.mesh)
239233
abstract_state = jax.eval_shape(init_fn, abstract_batch)
240234
specs = nn.get_partition_spec(abstract_state)
241235

@@ -268,7 +262,6 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
268262

269263

270264
with self.mesh:
271-
mesh_context.set_global_mesh(self.mesh)
272265
step_fn = jax.jit(
273266
fn,
274267
in_shardings=(self.data_sharding, self.state_sharding),

recml/examples/train_hstu_jax.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,6 @@ def _compute_metrics(self, loss, logits, targets):
184184
targets = jnp.squeeze(targets)
185185
metrics = {"loss": clu_metrics.Average.from_model_output(loss)}
186186

187-
# def get_acc(k):
188-
# _, top_k_indices = jax.nn.top_k(logits, k)
189-
# correct = jnp.sum(top_k_indices == targets[:, None], axis=-1)
190-
# return jnp.mean(correct)
191-
192-
# metrics["HR_10"] = clu_metrics.Average.from_model_output(get_acc(10))
193-
# metrics["HR_50"] = clu_metrics.Average.from_model_output(get_acc(50))
194-
# metrics["HR_200"] = clu_metrics.Average.from_model_output(get_acc(200))
195187
return metrics
196188

197189
def experiment() -> fdl.Config[recml.Experiment]:

recml/layers/linen/sparsecore.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@
2828
from recml.core.ops import embedding_ops
2929
import tensorflow as tf
3030

31-
from recml.core.training import mesh_context
32-
33-
3431
with epy.lazy_imports():
3532
# pylint: disable=g-import-not-at-top
3633
from jax_tpu_embedding.sparsecore.lib.flax.linen import embed
@@ -177,7 +174,6 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
177174
dataclasses.field(default=lambda n, bs: bs)
178175
)
179176

180-
# Optional device information.
181177
local_device_count: int = dataclasses.field(
182178
default_factory=jax.local_device_count
183179
)
@@ -369,19 +365,9 @@ class SparsecoreEmbed(nn.Module):
369365
"""
370366

371367
sparsecore_config: SparsecoreConfig
372-
mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None
373-
374-
def get_mesh(self) -> jax.sharding.Mesh:
375-
# Try to get the mesh from our custom global context
376-
mesh = mesh_context.get_global_mesh()
377-
378-
if mesh is None:
379-
raise ValueError(
380-
"No global mesh found. Make sure to call "
381-
"`partitioning.partition_init` (which sets the mesh) "
382-
"before initializing SparseCore."
383-
)
384-
return mesh
368+
mesh: jax.sharding.Mesh = dataclasses.field(
369+
default_factory=lambda: jax.sharding.Mesh(jax.devices(), ('batch',))
370+
)
385371

386372
def get_sharding_axis(
387373
self, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh
@@ -391,25 +377,25 @@ def get_sharding_axis(
391377
return self.sparsecore_config.sharding_axis
392378

393379
def setup(self):
394-
mesh = self.get_mesh()
395-
sharding_axis_name = self.get_sharding_axis(mesh)
396380

381+
sharding_axis_name = self.get_sharding_axis(self.mesh)
382+
397383
initializer = functools.partial(
398384
embedding.init_embedding_variables,
399385
table_specs=embedding.get_table_specs(
400386
self.sparsecore_config.feature_specs
401387
),
402388
global_sharding=jax.sharding.NamedSharding(
403-
mesh, jax.sharding.PartitionSpec(sharding_axis_name, None)
389+
self.mesh, jax.sharding.PartitionSpec(sharding_axis_name, None)
404390
),
405391
num_sparsecore_per_device=self.sparsecore_config.num_sc_per_device,
406392
# We need to by-pass the mesh check to allow using an abstract mesh.
407-
bypass_mesh_check=isinstance(mesh, jax.sharding.AbstractMesh),
393+
bypass_mesh_check=isinstance(self.mesh, jax.sharding.AbstractMesh),
408394
)
409395
self.embedding_table = self.param(
410396
name=EMBEDDING_PARAM_NAME,
411397
init_fn=embed.with_sparsecore_layout(
412-
initializer, (sharding_axis_name,), mesh # type: ignore
398+
initializer, (sharding_axis_name,), self.mesh # type: ignore
413399
),
414400
)
415401

@@ -426,12 +412,13 @@ def __call__(
426412
Returns:
427413
The activations structure with the same structure as specs.
428414
"""
429-
mesh = self.get_mesh()
430-
sharding_axis_name = self.get_sharding_axis(mesh)
415+
# mesh = self.get_mesh()
416+
sharding_axis_name = self.get_sharding_axis(self.mesh)
417+
431418
activations = embedding_ops.sparsecore_lookup(
432419
embedding_ops.SparsecoreParams(
433420
feature_specs=self.sparsecore_config.feature_specs,
434-
mesh=mesh,
421+
mesh=self.mesh,
435422
data_axes=(sharding_axis_name,),
436423
embedding_axes=(sharding_axis_name, None),
437424
sharding_strategy=self.sparsecore_config.sharding_strategy,

0 commit comments

Comments
 (0)