Skip to content

Commit 7ac9fb4

Browse files
Merge pull request #3503 from zhenying-liu:param-offload-fix
PiperOrigin-RevId: 889513863
2 parents bc6df9f + 55d2f21 commit 7ac9fb4

2 files changed

Lines changed: 8 additions & 1 deletion

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
355355
lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
356356
raw_grads,
357357
)
358+
if config.parameter_memory_host_offload:
359+
raw_grads = jax.device_put(
360+
raw_grads,
361+
max_utils.with_memory_kind(params_shardings, "device"),
362+
)
358363
intermediate_outputs = aux["intermediate_outputs"]
359364
total_weights = aux["total_weights"]
360365
moe_lb_loss = aux["moe_lb_loss"]

src/maxtext/utils/max_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import time
2626
from typing import Any
2727

28+
from packaging.version import Version
29+
2830
from etils import epath
2931
import flax
3032
import jax
@@ -82,7 +84,7 @@ def calculate_num_params_from_pytree(params):
8284
def device_space():
8385
"""Version guard for jax.memory.Space.Device."""
8486
# See b/436565838 for more.
85-
if jax.__version__ >= "0.7.1":
87+
if Version(jax.__version__) >= Version("0.7.1"):
8688
return jax.memory.Space.Device # pytype: disable=module-attr
8789
else:
8890
return jax._src.sharding_impls.TransferToMemoryKind("device") # pylint: disable=protected-access # pytype: disable=module-attr

0 commit comments

Comments
 (0)