File tree Expand file tree Collapse file tree 6 files changed +33
-7
lines changed
Expand file tree Collapse file tree 6 files changed +33
-7
lines changed Original file line number Diff line number Diff line change @@ -65,8 +65,6 @@ include = [
6565[tool .setuptools .package-data ]
6666"cuda.tile" = [" VERSION" ]
6767
68- [tool .pytest_env ]
69- CUDA_TILE_COMPILER_TIMEOUT_SEC = " 60"
7068
7169[tool .uv ]
7270managed = true
Original file line number Diff line number Diff line change @@ -8,3 +8,5 @@ python_functions = test_* bench_*
88addopts = --strict-markers
99markers =
1010 use_mlir: mark tests that depend on mlir which requires the " internal" extension
11+ env =
12+ CUDA_TILE_COMPILER_TIMEOUT_SEC =60
Original file line number Diff line number Diff line change 77import pytest
88import torch
99import cuda .tile as ct
10- import cuda .tile_experimental as ct_experimental
1110import itertools
1211from math import ceil
1312from util import estimate_bench_iter , next_power_of_2 , is_ampere_or_ada
1716from functools import partial
1817from types import SimpleNamespace
1918
19+ ct_experimental = pytest .importorskip ("cuda.tile_experimental" )
20+
2021
2122@pytest .fixture (params = [
2223 (262144 , 1024 ),
Original file line number Diff line number Diff line change 1515from cuda .tile ._compile import _get_max_supported_bytecode_version
1616
1717
18+ def pytest_addoption (parser ):
19+ parser .addoption (
20+ "--error-on-import-skip" ,
21+ action = "store_true" ,
22+ default = False ,
23+ help = "Treat import-related skips as errors" ,
24+ )
25+
26+
27+ def pytest_configure (config ):
28+ if config .getoption ("error_on_import_skip" , default = False ):
29+ _original = pytest .importorskip
30+
31+ def strict_importorskip (modname , * args , ** kwargs ):
32+ try :
33+ return _original (modname , * args , ** kwargs )
34+ except pytest .skip .Exception as e :
35+ pytest .fail (f"Required import skipped: { e } " )
36+
37+ pytest .importorskip = strict_importorskip
38+
39+
1840@cache
1941def get_tileiras_version ():
2042 return _get_max_supported_bytecode_version (tempfile .gettempdir ())
Original file line number Diff line number Diff line change 99from functools import partial
1010from util import assert_equal
1111
12- import cuda .tile_experimental ._autotuner as autotuner_mod
13- from cuda .tile_experimental import autotune_launch , clear_autotune_cache
1412from cuda .tile ._cext import default_tile_context
1513from cuda .tile ._exception import TileCompilerTimeoutError , TileCompilerExecutionError
1614
15+ ct_experimental = pytest .importorskip ("cuda.tile_experimental" )
16+ autotuner_mod = ct_experimental ._autotuner
17+ autotune_launch = ct_experimental .autotune_launch
18+ clear_autotune_cache = ct_experimental .clear_autotune_cache
19+
1720
1821@ct .kernel
1922def dummy_kernel (x , TILE_SIZE : ct .Constant [int ]):
Original file line number Diff line number Diff line change @@ -165,8 +165,8 @@ def _find_filecheck_bin() -> Optional[str]:
165165
166166
167167def filecheck (bytecode_buf : bytearray , check_directive : str ) -> None :
168- from cuda .tile_internal ._internal_cext import bytecode_to_mlir_text
169- mlir_text = bytecode_to_mlir_text (bytecode_buf )
168+ mod = pytest . importorskip ( " cuda.tile_internal._internal_cext" )
169+ mlir_text = mod . bytecode_to_mlir_text (bytecode_buf )
170170
171171 filecheck_bin = _find_filecheck_bin ()
172172 with (
You can’t perform that action at this time.
0 commit comments