Skip to content

Commit dcaca2a

Browse files
[JAX] Try to use pre-downloaded dataset artifacts first (NVIDIA#2345)
* Try to use pre-downloaded dataset artifacts first
* Set HF_HUB_OFFLINE to disable any network calls to HF when the pre-downloaded dataset is available

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
1 parent b6020e3 commit dcaca2a

File tree

6 files changed

+57
-18
lines changed

6 files changed

+57
-18
lines changed

examples/jax/encoder/common.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# See LICENSE for license information.
44
"""Shared functions for the encoder tests"""
55
from functools import lru_cache
6+
import os
7+
import pathlib
8+
import zipfile
69

710
import jax
811
import jax.numpy
@@ -120,12 +123,48 @@ def get_quantization_recipe_from_name_string(name: str):
120123
raise ValueError(f"Invalid quantization_recipe, got {name}")
121124

122125

123-
def hf_login_if_available():
124-
"""Login to HF hub if available"""
125-
try:
126-
from huggingface_hub import login
126+
@lru_cache(maxsize=None)
def _get_example_artifacts_dir() -> pathlib.Path:
    """Path to the directory with pre-downloaded dataset artifacts.

    Resolution order:
      1. The ``NVTE_TEST_CHECKPOINT_ARTIFACT_PATH`` environment variable, if set.
      2. ``<repo root>/artifacts/examples/jax`` (three levels up from this file).

    Returns:
        Resolved absolute path to the artifacts directory (may not exist).
    """
    # Check environment variable first so CI can point at a custom location.
    path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH")
    if path:
        return pathlib.Path(path).resolve()

    # Fallback to path in root dir.
    root_dir = pathlib.Path(__file__).resolve().parent.parent.parent
    return root_dir / "artifacts" / "examples" / "jax"


def _unpack_cached_dataset(artifacts_dir: pathlib.Path, folder_name: str) -> None:
    """Unpack a cached dataset into the HuggingFace cache, if available.

    Args:
        artifacts_dir: Directory that holds pre-downloaded dataset folders.
        folder_name: Name of the dataset folder (e.g. ``"mnist"``) under
            ``artifacts_dir`` whose ``*.zip`` archives should be extracted.
    """
    dataset_dir = artifacts_dir / folder_name
    if not dataset_dir.exists():
        print(f"Cached dataset {folder_name} not found at {dataset_dir}, skipping unpack")
        return

    # Disable any HF network calls since the dataset is cached locally
    os.environ["HF_HUB_OFFLINE"] = "1"

    # Hoisted loop invariant: archives are extracted into the default HF cache
    # so downstream dataset loaders find them.
    # NOTE(review): assumes ~/.cache/huggingface is the active HF cache —
    # confirm behavior if HF_HOME is overridden in some environments.
    hf_cache_dir = pathlib.Path.home() / ".cache" / "huggingface"

    # pathlib globbing (sorted for deterministic extraction order) replaces the
    # os.listdir + endswith(".zip") filter.
    for filepath in sorted(dataset_dir.glob("*.zip")):
        print(f"Unpacking cached dataset {folder_name} from {filepath}")

        # Archives are pre-downloaded, trusted test artifacts, so extractall
        # (which does not guard against hostile paths) is acceptable here.
        with zipfile.ZipFile(filepath, "r") as zip_ref:
            zip_ref.extractall(hf_cache_dir)
        print(f"Unpacked cached dataset {folder_name} to {hf_cache_dir}")


# This is cached so we don't have to unpack datasets multiple times
@lru_cache(maxsize=None)
def unpack_cached_datasets_if_available() -> None:
    """Unpack cached datasets if available"""
    artifacts_dir = _get_example_artifacts_dir()
    _unpack_cached_dataset(artifacts_dir, "mnist")
    _unpack_cached_dataset(artifacts_dir, "encoder")

examples/jax/encoder/test_model_parallel_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
is_bf16_supported,
2424
get_quantization_recipe_from_name_string,
2525
assert_params_sufficiently_sharded,
26-
hf_login_if_available,
26+
unpack_cached_datasets_if_available,
2727
)
2828
import transformer_engine.jax as te
2929
import transformer_engine.jax.cpp_extensions as tex
3030
import transformer_engine.jax.flax as te_flax
3131
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
3232

33-
hf_login_if_available()
33+
unpack_cached_datasets_if_available()
3434

3535
DEVICE_DP_AXIS = "data"
3636
DEVICE_TP_AXIS = "model"

examples/jax/encoder/test_multigpu_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
from common import (
2323
is_bf16_supported,
2424
get_quantization_recipe_from_name_string,
25-
hf_login_if_available,
25+
unpack_cached_datasets_if_available,
2626
)
2727
import transformer_engine.jax as te
2828
import transformer_engine.jax.cpp_extensions as tex
2929
import transformer_engine.jax.flax as te_flax
3030
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
3131

32-
hf_login_if_available()
32+
unpack_cached_datasets_if_available()
3333

3434
DEVICE_DP_AXIS = "data"
3535
PARAMS_KEY = "params"

examples/jax/encoder/test_multiprocessing_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
is_mxfp8_supported,
2828
is_nvfp4_supported,
2929
get_quantization_recipe_from_name_string,
30-
hf_login_if_available,
30+
unpack_cached_datasets_if_available,
3131
)
3232
import transformer_engine.jax as te
3333
import transformer_engine.jax.cpp_extensions as tex
3434
import transformer_engine.jax.flax as te_flax
3535

36-
hf_login_if_available()
36+
unpack_cached_datasets_if_available()
3737

3838
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3939
DEVICE_DP_AXIS = "data"

examples/jax/encoder/test_single_gpu_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
from common import (
2020
is_bf16_supported,
2121
get_quantization_recipe_from_name_string,
22-
hf_login_if_available,
22+
unpack_cached_datasets_if_available,
2323
)
2424
import transformer_engine.jax as te
2525
import transformer_engine.jax.flax as te_flax
2626
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
2727

28-
hf_login_if_available()
28+
unpack_cached_datasets_if_available()
2929

3030
PARAMS_KEY = "params"
3131
DROPOUT_KEY = "dropout"

examples/jax/mnist/test_single_gpu_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
from encoder.common import (
2626
is_bf16_supported,
2727
get_quantization_recipe_from_name_string,
28-
hf_login_if_available,
28+
unpack_cached_datasets_if_available,
2929
)
3030

31-
hf_login_if_available()
31+
unpack_cached_datasets_if_available()
3232

3333
IMAGE_H = 28
3434
IMAGE_W = 28

0 commit comments

Comments
 (0)