Skip to content

Commit 4317d1c

Browse files
authored
Merge pull request #145 from AI-Hypercomputer/ajkv/v6-training-update
Added training functionality for HSTU and DLRM on v6 TPU
2 parents 61c08ca + 72bb9d5 commit 4317d1c

12 files changed

Lines changed: 656 additions & 64 deletions

File tree

recml/core/data/iterator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from etils import epath
2222
import numpy as np
2323
import tensorflow as tf
24+
import jax
2425

2526

2627
Iterator = clu_data.DatasetIterator
@@ -69,15 +70,17 @@ def _maybe_to_numpy(
6970
return x
7071
if hasattr(x, "_numpy"):
7172
numpy = x._numpy() # pylint: disable=protected-access
72-
else:
73+
elif hasattr(x, "numpy"):
7374
numpy = x.numpy()
75+
else:
76+
return x
77+
7478
if isinstance(numpy, np.ndarray):
75-
# `numpy` shares the same underlying buffer as the `x` Tensor.
7679
# Tensors are expected to be immutable, so we disable writes.
7780
numpy.setflags(write=False)
7881
return numpy
7982

80-
return tf.nest.map_structure(_maybe_to_numpy, batch)
83+
return jax.tree.map(_maybe_to_numpy, batch)
8184

8285
@property
8386
def element_spec(self) -> clu_data.ElementSpec:
@@ -109,7 +112,7 @@ def _to_element_spec(
109112
)
110113
return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))
111114

112-
element_spec = tf.nest.map_structure(_to_element_spec, batch)
115+
element_spec = jax.tree.map(_to_element_spec, batch)
113116
self._element_spec = element_spec
114117
return element_spec
115118

recml/core/ops/embedding_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class SparsecoreParams:
3838
"""Embedding parameters."""
3939

4040
feature_specs: Nested[FeatureSpec]
41-
mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh
41+
mesh: jax.sharding.Mesh
4242
data_axes: Sequence[str | None]
4343
embedding_axes: Sequence[str | None]
4444
sharding_strategy: str

recml/core/training/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import abc
1717
from collections.abc import Mapping, Sequence
18+
import contextlib
1819
import dataclasses
1920
import enum
2021
from typing import Any, Generic, TypeVar
@@ -24,6 +25,13 @@
2425
from recml.core.data import iterator
2526
import tensorflow as tf
2627

28+
# Patch jax.spmd_mode if it doesn't exist (removed in newer JAX versions).
29+
if not hasattr(jax, "spmd_mode"):
30+
@contextlib.contextmanager
31+
def _spmd_mode(*args, **kwargs):
32+
del args, kwargs
33+
yield
34+
jax.spmd_mode = _spmd_mode
2735

2836
# pylint: disable=logging-fstring-interpolation
2937

recml/core/training/partitioning.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
"""Utilities for partitioning."""
1516

1617
import abc
@@ -22,7 +23,6 @@
2223
import jax
2324
import numpy as np
2425

25-
2626
PyTree = Any
2727
State = Any
2828
CreateStateFn = Callable[[PyTree], State]
@@ -67,7 +67,8 @@ class DataParallelPartitioner(Partitioner):
6767
"""Data parallel partitioner."""
6868

6969
def __init__(self, data_axis: str = "batch"):
70-
self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,))
70+
devices = jax.devices()
71+
self.mesh = jax.sharding.Mesh(devices, (data_axis,))
7172
self.data_sharding = jax.sharding.NamedSharding(
7273
self.mesh, jax.sharding.PartitionSpec(data_axis)
7374
)
@@ -107,7 +108,7 @@ def _shard(x: np.ndarray) -> jax.Array:
107108
def partition_init(
108109
self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
109110
) -> CreateStateFn:
110-
with jax.sharding.use_mesh(self.mesh):
111+
with self.mesh:
111112
if abstract_batch is not None:
112113
abstract_state = jax.eval_shape(init_fn, abstract_batch)
113114
specs = nn.get_partition_spec(abstract_state)
@@ -117,7 +118,7 @@ def partition_init(
117118
init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
118119

119120
def _wrapped_init(batch: PyTree) -> State:
120-
with jax.sharding.use_mesh(self.mesh):
121+
with self.mesh:
121122
state = init_fn(batch)
122123
state = _maybe_unbox_state(state)
123124
return state
@@ -130,15 +131,15 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
130131
jit_kws["out_shardings"] = (self.state_sharding, None)
131132
jit_kws["donate_argnums"] = (1,)
132133

133-
with jax.sharding.use_mesh(self.mesh):
134+
with self.mesh:
134135
step_fn = jax.jit(
135136
fn,
136137
in_shardings=(self.data_sharding, self.state_sharding),
137138
**jit_kws,
138139
)
139140

140141
def _wrapped_step(batch: PyTree, state: State) -> Any:
141-
with jax.sharding.use_mesh(self.mesh):
142+
with self.mesh:
142143
return step_fn(batch, state)
143144

144145
return _wrapped_step
@@ -190,7 +191,7 @@ def __init__(
190191
if axis_sizes[0] == -1:
191192
axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])
192193

193-
self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
194+
self.mesh = jax.sharding.Mesh(devices, axis_names)
194195
self.rules = rules
195196
self.aot_compile = aot_compile
196197
self.options = options
@@ -213,12 +214,6 @@ def __init__(
213214
self.abstract_batch = None
214215
self.abstract_state = None
215216

216-
@property
217-
def mesh_context_manager(
218-
self,
219-
) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
220-
return jax.sharding.use_mesh
221-
222217
def shard_inputs(self, inputs: PyTree) -> PyTree:
223218
def _shard(x: np.ndarray) -> jax.Array:
224219
return jax.make_array_from_process_local_data(self.data_sharding, x)
@@ -234,7 +229,7 @@ def partition_init(
234229
" model parallel partitioner."
235230
)
236231

237-
with self.mesh_context_manager(self.mesh):
232+
with self.mesh:
238233
abstract_state = jax.eval_shape(init_fn, abstract_batch)
239234
specs = nn.get_partition_spec(abstract_state)
240235

@@ -247,7 +242,7 @@ def partition_init(
247242
compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding)
248243

249244
def _init(batch: PyTree) -> State:
250-
with self.mesh_context_manager(self.mesh):
245+
with self.mesh:
251246
state = compiled_init_fn(batch)
252247
state = _maybe_unbox_state(state)
253248
return state
@@ -265,7 +260,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
265260
else:
266261
jit_kws["out_shardings"] = None
267262

268-
with self.mesh_context_manager(self.mesh):
263+
264+
with self.mesh:
269265
step_fn = jax.jit(
270266
fn,
271267
in_shardings=(self.data_sharding, self.state_sharding),
@@ -286,7 +282,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
286282
)
287283

288284
def _step(batch: PyTree, state: State) -> Any:
289-
with self.mesh_context_manager(self.mesh):
285+
with self.mesh:
290286
return step_fn(batch, state)
291287

292288
return _step
@@ -302,4 +298,4 @@ def _maybe_unbox(x: Any) -> Any:
302298
_maybe_unbox,
303299
x,
304300
is_leaf=lambda k: isinstance(k, nn.Partitioned),
305-
)
301+
)

recml/core/training/partitioning_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def test_data_parallelism(
4040
self, partitioner_cls: type[partitioning.Partitioner]
4141
):
4242
if partitioner_cls is partitioning.ModelParallelPartitioner:
43-
kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1}
43+
devs = np.array(jax.devices()).reshape(-1, 1)
44+
kwargs = {"axes": [("data", -1), ("model", 1)], "dp_axes": 1, "devices": devs}
4445
else:
4546
kwargs = {}
4647
partitioner = partitioner_cls(**kwargs)
@@ -112,8 +113,12 @@ def _eval_step(
112113
)
113114

114115
def test_model_parallelism(self):
116+
devs = np.array(jax.devices()).reshape(1, -1)
117+
115118
partitioner = partitioning.ModelParallelPartitioner(
116-
axes=[("data", 1), ("model", jax.device_count())], dp_axes=1
119+
axes=[("data", 1), ("model", jax.device_count())],
120+
dp_axes=1,
121+
devices=devs
117122
)
118123

119124
inputs = np.zeros((128, 16), dtype=np.float32)

recml/examples/dlrm_experiment.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
import dataclasses
2020
from typing import Generic, Literal, TypeVar
2121

22+
import sys
23+
import os
24+
# Add the RecML folder to the system path
25+
sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))
26+
os.environ["KERAS_BACKEND"] = "jax"
27+
2228
from etils import epy
2329
import fiddle as fdl
2430
import flax.linen as nn

recml/examples/dlrm_experiment_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
# limitations under the License.
1414
"""Tests for the DLRM experiment."""
1515

16+
import sys
17+
import os
18+
# Add the RecML folder to the system path
19+
sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))
20+
os.environ["KERAS_BACKEND"] = "jax"
21+
1622
from absl.testing import absltest
1723
import fiddle as fdl
1824
from fiddle import selectors
@@ -32,8 +38,8 @@ def test_dlrm_experiment(self):
3238

3339
experiment = dlrm_experiment.experiment()
3440

35-
experiment.task.train_data.global_batch_size = 4
36-
experiment.task.eval_data.global_batch_size = 4
41+
experiment.task.train_data.global_batch_size = 128
42+
experiment.task.eval_data.global_batch_size = 128
3743
experiment.trainer.train_steps = 12
3844
experiment.trainer.steps_per_loop = 4
3945
experiment.trainer.steps_per_eval = 4

0 commit comments

Comments
 (0)