
Fix the gradients' device memory space when enabling parameter offloading #3503

Merged
copybara-service[bot] merged 1 commit into AI-Hypercomputer:main from zhenying-liu:param-offload-fix on Mar 26, 2026

Conversation

@zhenying-liu (Contributor) commented Mar 25, 2026

Description

Fix gradient memory placement when parameter offloading is enabled with parameter_memory_host_offload=True.

When parameter offloading is active, model parameters live in host memory and are streamed to device during the forward pass. However, JAX's autodiff computes gradients in the same memory space as the inputs, meaning the backward pass produces gradients that are also on the host. Previously, these host-resident gradients were passed directly to gradient clipping and the optimizer update without being moved to device first, causing a runtime error when gradients are used in a device operation.
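The placement semantics described above can be illustrated with JAX's memory-kind API. This is a minimal standalone sketch, not MaxText code; it assumes a recent JAX version with the memories feature, and the "pinned_host" kind may not be available on every backend:

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
# Each device exposes one or more memory spaces; accelerators typically
# report a "device" (HBM) kind and a "pinned_host" kind.
kinds = {m.kind for m in dev.addressable_memories()}
print(kinds)

x = jnp.ones((8,))
if "pinned_host" in kinds:
    # Park the array in host memory, as parameter offloading does.
    x = jax.device_put(x, SingleDeviceSharding(dev, memory_kind="pinned_host"))
    print(x.sharding.memory_kind)  # "pinned_host"
```

Arrays placed this way keep their memory kind through autodiff, which is why the backward pass emits host-resident gradients.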

This PR adds a jax.device_put call immediately after the backward pass to explicitly move raw_grads back to device memory with the correct sharding before any further gradient processing. The fix is a one-time transfer per training step and is gated behind the existing parameter_memory_host_offload config flag.
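The pattern can be sketched as follows. This is illustrative only, not the actual MaxText change (which lives in src/maxtext/utils/max_utils.py and uses the training mesh's shardings); the loss, helper names, and the flag spelling `param_offload_enabled` here are hypothetical:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy loss; stands in for the model's forward pass.
    return jnp.sum(params["w"] * batch)

def train_step(params, batch, param_offload_enabled=True):
    raw_grads = jax.grad(loss_fn)(params, batch)
    if param_offload_enabled:
        # With offloading, grads inherit the params' (host) placement;
        # move them back to device memory before clipping / optimizer.
        raw_grads = jax.tree_util.tree_map(
            lambda g: jax.device_put(g, jax.devices()[0]), raw_grads)
    # ... gradient clipping and the optimizer update would follow here.
    return raw_grads

grads = train_step({"w": jnp.ones((4,))}, jnp.arange(4.0))
print(grads["w"])  # [0. 1. 2. 3.]
```

Doing the transfer once on the whole gradient tree, before any gradient processing, keeps all subsequent ops (clipping, optimizer update) on device.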

Tests

The existing integration test test_gpu_parameter_offload in tests/integration/train_tests.py covers this fix. It runs 10 training steps with parameter_memory_host_offload=True on a GPU, exercising the backward pass and the gradient device_put added by this change.

pytest tests/integration/train_tests.py::GPUTest::test_gpu_parameter_offload -v

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [ ] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • [ ] I have added necessary comments to my code, particularly in hard-to-understand areas.
  • [ ] I have run end-to-end tests and provided workload links above if applicable.
  • [ ] I have made or will make corresponding changes to the docs if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Mar 25, 2026

Codecov Report

❌ Patch coverage is 75.00000% with 1 line in your changes missing coverage. Please review.

Files with missing lines       | Patch % | Lines
src/maxtext/utils/max_utils.py | 50.00%  | 0 Missing and 1 partial ⚠️


@copybara-service copybara-service Bot merged commit 7ac9fb4 into AI-Hypercomputer:main Mar 26, 2026
35 of 36 checks passed
