Skip to content

Commit eb22f3b

Browse files
Merge pull request #3956 from AI-Hypercomputer:debug-pathways-devices-fix
PiperOrigin-RevId: 918617201
2 parents 76315ad + ebed9dd commit eb22f3b

3 files changed

Lines changed: 11 additions & 1 deletion

File tree

.github/workflows/run_pathways_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ jobs:
6868
IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS: true
6969
JAX_PLATFORMS: "proxy"
7070
JAX_BACKEND_TARGET: "grpc://localhost:29000"
71+
TPU_VISIBLE_DEVICES: ""
7172
options: ${{ inputs.container_resource_option }}
7273
steps:
7374
- name: Checkout MaxText
@@ -85,7 +86,6 @@ jobs:
8586
source .venv/bin/activate
8687
maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null)
8788
uv pip install ${maxtext_wheel}[tpu] --resolution=lowest
88-
uv pip uninstall libtpu
8989
install_tpu_pre_train_extra_deps
9090
python3 --version
9191
python3 -m pip freeze

src/maxtext/utils/generate_param_only_checkpoint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def _generate_lora_decode_checkpoints(config, mesh):
137137
config.enable_checkpointing,
138138
config.async_checkpointing,
139139
config.checkpoint_period,
140+
use_ocdbt=config.checkpoint_storage_use_ocdbt,
141+
use_zarr3=config.checkpoint_storage_use_zarr3,
140142
)
141143

142144
lora_config, lora_state, lora_state_annotations = lora_utils.setup_initial_lora_state(
@@ -192,6 +194,8 @@ def generate_decode_checkpoint(config):
192194
config.enable_checkpointing,
193195
config.async_checkpointing,
194196
config.checkpoint_period,
197+
use_ocdbt=config.checkpoint_storage_use_ocdbt,
198+
use_zarr3=config.checkpoint_storage_use_zarr3,
195199
)
196200
# Read training state from config.load_paramaters_path
197201
max_logging.log(f"Read training checkpoint from: {config.load_full_state_path}")

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ def _custom_iter(self):
6666
except AttributeError:
6767
pass
6868

69+
import os
70+
71+
if os.getenv("JAX_PLATFORMS") == "proxy":
72+
# Import maxtext early to register the pathways proxy backend before JAX is queried.
73+
import maxtext # pylint: disable=unused-import
74+
6975
try:
7076
_HAS_TPU = any(d.platform == "tpu" for d in jax.devices())
7177
except Exception: # pragma: no cover pylint: disable=broad-exception-caught

0 commit comments

Comments
 (0)