Skip to content

Commit e664b16

Browse files
Merge pull request #4003 from ROCm:upstream-fix-decoupled-gcs-collection
PiperOrigin-RevId: 922845285
2 parents faef777 + 0a43924 commit e664b16

2 files changed

Lines changed: 25 additions & 4 deletions

File tree

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
from jax.experimental import multihost_utils
3737
from jaxtyping import Array
3838

39-
from google.cloud.storage import Client, transfer_manager
40-
4139
from safetensors import safe_open
4240
from safetensors.numpy import save_file as numpy_save_file
4341
from safetensors.numpy import save as numpy_save
@@ -50,9 +48,14 @@
5048

5149
from flax.training import train_state
5250
from maxtext.common import checkpointing
51+
from maxtext.common.gcloud_stub import gcs_storage
5352
from maxtext.utils import max_logging
5453
import orbax.checkpoint as ocp
5554

55+
_storage = gcs_storage()
56+
Client = _storage.Client
57+
transfer_manager = _storage.transfer_manager
58+
5659

5760
SAFE_TENSORS_CONFIG_FILE = "config.json"
5861
SAFE_TENSORS_WEIGHTS_FILE = "model.safetensors"

src/maxtext/common/gcloud_stub.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,27 @@ def bucket(self, *a, **k): # pylint: disable=unused-argument
237237
def list_blobs(self, *a, **k): # pylint: disable=unused-argument
238238
return iter([])
239239

240-
return SimpleNamespace(Client=_StubClient, _IS_STUB=True)
240+
def _stub_upload_many_from_filenames(*_a, **_k):
241+
"""No-op stub for transfer_manager.upload_many_from_filenames."""
242+
return []
243+
244+
transfer_manager_stub = SimpleNamespace(
245+
upload_many_from_filenames=_stub_upload_many_from_filenames,
246+
_IS_STUB=True,
247+
)
248+
249+
return SimpleNamespace(Client=_StubClient, transfer_manager=transfer_manager_stub, _IS_STUB=True)
241250

242251

243252
def gcs_storage():
244-
"""Return google.cloud.storage module or stub when decoupled or missing."""
253+
"""Return google.cloud.storage module (with transfer_manager attached) or stub.
254+
255+
The returned object always exposes both ``.Client`` and ``.transfer_manager``
256+
so callers can use ``storage.transfer_manager.upload_many_from_filenames(...)``
257+
without an extra import. ``transfer_manager`` is a submodule of
258+
``google.cloud.storage`` and is not auto-imported by ``from google.cloud
259+
import storage``; we explicitly import and attach it here.
260+
"""
245261
# In decoupled mode always prefer the stub, even if the library is installed,
246262
# to avoid accidental GCS calls in tests or local runs.
247263
if is_decoupled(): # fast path
@@ -250,7 +266,9 @@ def gcs_storage():
250266

251267
try: # pragma: no cover - attempt real import when not decoupled
252268
from google.cloud import storage # type: ignore # pylint: disable=import-outside-toplevel
269+
from google.cloud.storage import transfer_manager # type: ignore # pylint: disable=import-outside-toplevel
253270

271+
setattr(storage, "transfer_manager", transfer_manager)
254272
setattr(storage, "_IS_STUB", False)
255273
return storage
256274
except Exception: # ModuleNotFoundError / ImportError for partial installs # pylint: disable=broad-exception-caught

0 commit comments

Comments
 (0)