Skip to content

Commit 3664830

Browse files
committed
Isolate CI torch extension cache
1 parent f0bcdac commit 3664830

4 files changed

Lines changed: 34 additions & 16 deletions

File tree

.github/workflows/unit_tests.yml

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -244,8 +244,19 @@ jobs:
244244
--shell)"
245245
eval "$runtime_output"
246246
247+
test_cache_key="${{ matrix.test_script }}"
248+
test_cache_key="${test_cache_key//\//_}"
249+
test_cache_key="${test_cache_key//./_}"
250+
test_cache_key="${test_cache_key//-/_}"
251+
export GPTQMODEL_TORCH_EXTENSIONS_DIR="/tmp/gptqmodel/torch_extensions/${{ github.run_id }}/${{ github.run_attempt }}/${ENV_NAME}/${test_cache_key}"
252+
export TORCH_EXTENSIONS_DIR="$GPTQMODEL_TORCH_EXTENSIONS_DIR"
253+
mkdir -p "$GPTQMODEL_TORCH_EXTENSIONS_DIR"
254+
echo "GPTQMODEL_TORCH_EXTENSIONS_DIR=$GPTQMODEL_TORCH_EXTENSIONS_DIR" >> "$GITHUB_ENV"
255+
echo "TORCH_EXTENSIONS_DIR=$TORCH_EXTENSIONS_DIR" >> "$GITHUB_ENV"
256+
247257
echo "-- setting up env --"
248258
echo "env_name: $ENV_NAME"
259+
echo "torch extensions dir: $GPTQMODEL_TORCH_EXTENSIONS_DIR"
249260
# will clean later
250261
251262
echo "-- ls -ahl /opt/uv/venvs before --"

gptqmodel/exllamav3/ext.py

Lines changed: 1 addition & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
import os
1212
import sys
1313
from pathlib import Path
14-
from typing import Optional
1514

1615
import torch
1716

@@ -112,20 +111,6 @@ def _exllamav3_include_paths() -> list[str]:
112111
)
113112

114113

115-
def _legacy_build_root() -> Optional[Path]:
116-
build_root = os.environ.get("GPTQMODEL_EXT_BUILD")
117-
if not build_root:
118-
return None
119-
return Path(build_root) / extension_name
120-
121-
122-
def _default_build_root() -> Path:
123-
legacy_root = _legacy_build_root()
124-
if legacy_root is not None:
125-
return legacy_root
126-
return default_torch_ops_build_root("exllamav3")
127-
128-
129114
def _extra_cflags() -> list[str]:
130115
if windows:
131116
flags = ["/O2", "/std:c++17"]
@@ -184,7 +169,7 @@ def _prepare_build_env() -> None:
184169
),
185170
sources=_source_files,
186171
build_root_env="GPTQMODEL_EXLLAMAV3_BUILD_ROOT",
187-
default_build_root=_default_build_root,
172+
default_build_root=lambda: default_torch_ops_build_root("exllamav3"),
188173
display_name="ExLlamaV3",
189174
extra_cflags=_extra_cflags,
190175
extra_cuda_cflags=_extra_cuda_cflags,

gptqmodel/utils/cpp.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@
6060
# the best overall tradeoff across Marlin, AWQ, QQQ, ExLlama, and ParoQuant.
6161
_DEFAULT_NVCC_THREADS = "8"
6262
_GLOBAL_KERNEL_REBUILD_ENV = "GPTQMODEL_KERNEL_REBUILD"
63+
_TORCH_OPS_BUILD_ROOT_ENV = "GPTQMODEL_TORCH_EXTENSIONS_DIR"
6364

6465

6566
def _nvcc_path() -> Optional[str]:
@@ -271,6 +272,9 @@ def close(self, *, succeeded: bool, elapsed_seconds: Optional[float] = None) ->
271272
def default_torch_ops_build_root(subdir: str) -> Path:
272273
"""Return the default on-disk cache root for torch.ops JIT extensions."""
273274

275+
override_root = os.getenv(_TORCH_OPS_BUILD_ROOT_ENV)
276+
if override_root:
277+
return Path(override_root).expanduser() / subdir
274278
return Path.home() / ".cache" / "gptqmodel" / "torch_extensions" / subdir
275279

276280

tests/test_torch_ops_jit_extension.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -226,6 +226,24 @@ def test_default_torch_ops_build_root_ignores_removed_global_override(monkeypatc
226226
)
227227

228228

229+
def test_default_torch_ops_build_root_respects_ci_override(monkeypatch):
230+
monkeypatch.setenv("GPTQMODEL_TORCH_EXTENSIONS_DIR", "/tmp/gptqmodel-ci")
231+
232+
assert cpp_module.default_torch_ops_build_root("marlin") == Path("/tmp/gptqmodel-ci") / "marlin"
233+
234+
235+
def test_torch_ops_jit_extension_prefers_explicit_build_root_over_global_default(monkeypatch, tmp_path):
236+
loader = _make_loader(
237+
tmp_path,
238+
default_build_root=lambda: cpp_module.default_torch_ops_build_root("unit_test_ops"),
239+
)
240+
241+
monkeypatch.setenv("GPTQMODEL_TORCH_EXTENSIONS_DIR", "/tmp/gptqmodel-ci")
242+
monkeypatch.setenv("UNIT_TEST_BUILD_ROOT", "/tmp/unit-test-override")
243+
244+
assert loader.base_build_root() == Path("/tmp/unit-test-override")
245+
246+
229247
def test_torch_ops_jit_extension_prefers_cached_binary(monkeypatch, tmp_path):
230248
"""Guard cache reuse so startup skips expensive JIT rebuilds when ops are already built."""
231249

0 commit comments

Comments (0)