Fix the gradients' device memory space when enabling parameter offloading #3503
Merged: copybara-service[bot] merged 1 commit into AI-Hypercomputer:main on Mar 26, 2026
Conversation
gobbleturk approved these changes on Mar 25, 2026
NuojCheng approved these changes on Mar 25, 2026
Merged commit 7ac9fb4 into AI-Hypercomputer:main. 35 of 36 checks passed.
Description
Fix gradient memory placement when parameter offloading is enabled with parameter_memory_host_offload=True.
When parameter offloading is active, model parameters live in host memory and are streamed to device during the forward pass. However, JAX's autodiff computes gradients in the same memory space as its inputs, so the backward pass produces gradients that are also host-resident. Previously, these host-resident gradients were passed directly to gradient clipping and the optimizer update without first being moved back to device, causing a runtime error when they were consumed by device operations.
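For context, a minimal sketch of what this placement looks like in JAX (illustrative only; the mesh, axis name, and toy parameter below are assumptions, not MaxText's actual identifiers):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative only: a single-axis mesh and one replicated parameter.
mesh = Mesh(jax.devices(), ("data",))
host = NamedSharding(mesh, P(), memory_kind="pinned_host")

# With parameter offloading, params live in pinned host memory and are
# streamed to device for the forward pass.
params = jax.device_put(jnp.ones((1024,)), host)

# Per the description above, gradients of a loss w.r.t. these params come
# back in the same (host) memory space as the params themselves.
grad_fn = jax.jit(jax.grad(lambda p: jnp.sum(p * p)))
grads = grad_fn(params)
```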
This PR adds a jax.device_put call immediately after the backward pass to explicitly move raw_grads back to device memory with the correct sharding before any further gradient processing. The fix is a one-time transfer per training step and is gated behind the existing parameter_memory_host_offload config flag.
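A minimal sketch of the shape of the fix, under stated assumptions: `loss_fn`, `params`, `batch`, `offload_enabled`, and `grad_shardings` are placeholders, not MaxText's real identifiers, and the actual change lives inside MaxText's train step.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Stand-in loss; the real model's loss lives elsewhere.
    return jnp.mean((params["w"] * batch) ** 2)

def train_step(params, batch, offload_enabled, grad_shardings):
    loss, raw_grads = jax.value_and_grad(loss_fn)(params, batch)
    if offload_enabled:  # mirrors the parameter_memory_host_offload gate
        # One-time transfer per step: move the host-resident gradients to
        # device memory with the intended sharding before gradient clipping
        # and the optimizer update run on device.
        raw_grads = jax.device_put(
            raw_grads,
            jax.tree_util.tree_map(
                lambda s: s.with_memory_kind("device"), grad_shardings
            ),
        )
    # ... gradient clipping and optimizer update proceed on device ...
    return loss, raw_grads
```

The transfer reuses each gradient's existing sharding and only switches its memory kind, so no resharding is implied beyond the host-to-device move.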
Tests
The existing integration test test_gpu_parameter_offload in tests/integration/train_tests.py covers this fix. It runs 10 training steps with parameter_memory_host_offload=True on a GPU, exercising the backward pass and the gradient device_put added by this change.
```
pytest tests/integration/train_tests.py::GPUTest::test_gpu_parameter_offload -v
```
Checklist
Before submitting this PR, please make sure (put X in square brackets):
[ ] gemini-review label.