Skip to content

Commit f5d0cce

Browse files
committed
Allow pytest to skip on optional modules by default
Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent 6b7f693 commit f5d0cce

File tree

6 files changed

+33
-7
lines changed

6 files changed

+33
-7
lines changed

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ include = [
6565
[tool.setuptools.package-data]
6666
"cuda.tile" = ["VERSION"]
6767

68-
[tool.pytest_env]
69-
CUDA_TILE_COMPILER_TIMEOUT_SEC = "60"
7068

7169
[tool.uv]
7270
managed = true

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ python_functions = test_* bench_*
88
addopts = --strict-markers
99
markers =
1010
use_mlir: mark tests that depend on mlir which requires the "internal" extension
11+
env =
12+
CUDA_TILE_COMPILER_TIMEOUT_SEC=60

test/bench_rms_norm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pytest
88
import torch
99
import cuda.tile as ct
10-
import cuda.tile_experimental as ct_experimental
1110
import itertools
1211
from math import ceil
1312
from util import estimate_bench_iter, next_power_of_2, is_ampere_or_ada
@@ -17,6 +16,8 @@
1716
from functools import partial
1817
from types import SimpleNamespace
1918

19+
ct_experimental = pytest.importorskip("cuda.tile_experimental")
20+
2021

2122
@pytest.fixture(params=[
2223
(262144, 1024),

test/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,28 @@
1515
from cuda.tile._compile import _get_max_supported_bytecode_version
1616

1717

18+
def pytest_addoption(parser):
19+
parser.addoption(
20+
"--error-on-import-skip",
21+
action="store_true",
22+
default=False,
23+
help="Treat import-related skips as errors",
24+
)
25+
26+
27+
def pytest_configure(config):
28+
if config.getoption("error_on_import_skip", default=False):
29+
_original = pytest.importorskip
30+
31+
def strict_importorskip(modname, *args, **kwargs):
32+
try:
33+
return _original(modname, *args, **kwargs)
34+
except pytest.skip.Exception as e:
35+
pytest.fail(f"Required import skipped: {e}")
36+
37+
pytest.importorskip = strict_importorskip
38+
39+
1840
@cache
1941
def get_tileiras_version():
2042
return _get_max_supported_bytecode_version(tempfile.gettempdir())

test/test_autotuner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
from functools import partial
1010
from util import assert_equal
1111

12-
import cuda.tile_experimental._autotuner as autotuner_mod
13-
from cuda.tile_experimental import autotune_launch, clear_autotune_cache
1412
from cuda.tile._cext import default_tile_context
1513
from cuda.tile._exception import TileCompilerTimeoutError, TileCompilerExecutionError
1614

15+
ct_experimental = pytest.importorskip("cuda.tile_experimental")
16+
autotuner_mod = ct_experimental._autotuner
17+
autotune_launch = ct_experimental.autotune_launch
18+
clear_autotune_cache = ct_experimental.clear_autotune_cache
19+
1720

1821
@ct.kernel
1922
def dummy_kernel(x, TILE_SIZE: ct.Constant[int]):

test/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ def _find_filecheck_bin() -> Optional[str]:
165165

166166

167167
def filecheck(bytecode_buf: bytearray, check_directive: str) -> None:
168-
from cuda.tile_internal._internal_cext import bytecode_to_mlir_text
169-
mlir_text = bytecode_to_mlir_text(bytecode_buf)
168+
mod = pytest.importorskip("cuda.tile_internal._internal_cext")
169+
mlir_text = mod.bytecode_to_mlir_text(bytecode_buf)
170170

171171
filecheck_bin = _find_filecheck_bin()
172172
with (

0 commit comments

Comments
 (0)