Skip to content

Commit 0e836ce

Browse files
committed
Test cuda_core examples
1 parent a7670b6 commit 0e836ce

3 files changed

Lines changed: 19 additions & 63 deletions

File tree

cuda_core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ cu12 = ["cuda-bindings[all]==12.*"]
5656
cu13 = ["cuda-bindings[all]==13.*"]
5757

5858
[dependency-groups]
59-
test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4", "pytest-randomly", "pytest-repeat", "pytest-rerunfailures"]
59+
test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4", "pytest-randomly", "pytest-repeat", "pytest-rerunfailures", "cffi"]
6060
ml-dtypes = ["ml-dtypes>=0.5.4,<0.6.0"]
6161
test-cu12 = [ {include-group = "ml-dtypes" }, {include-group = "test" }, "cupy-cuda12x; python_version < '3.14'", "cuda-toolkit[cudart]==12.*"] # runtime headers needed by CuPy
6262
test-cu13 = [ {include-group = "ml-dtypes" }, {include-group = "test" }, "cupy-cuda13x; python_version < '3.14'", "cuda-toolkit[cudart]==13.*"] # runtime headers needed by CuPy
Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,33 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

44
# If we have subcategories of examples in the future, this file can be split along those lines
55

66
import glob
77
import os
8+
import subprocess
9+
import sys
810

911
import pytest
1012

11-
from cuda.core import Device
12-
13-
from .utils import run_example
13+
DISALLOW_LIST = {
14+
"gl_interop_plasma.py", # requires a display
15+
}
1416

1517
samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
16-
sample_files = glob.glob(samples_path + "**/*.py", recursive=True)
18+
sample_files = [
19+
x
20+
for x in (os.path.basename(x) for x in glob.glob(samples_path + "**/*.py", recursive=True))
21+
if x not in DISALLOW_LIST
22+
]
1723

1824

1925
@pytest.mark.parametrize("example", sample_files)
2026
class TestExamples:
21-
def test_example(self, example, deinit_cuda):
22-
run_example(samples_path, example)
23-
if Device().device_id != 0:
24-
Device(0).set_current()
27+
def test_example(self, example):
28+
example_path = os.path.join(samples_path, example)
29+
process = subprocess.run([sys.executable, example_path], capture_output=True) # noqa: S603
30+
if process.returncode != 0:
31+
print(process.stdout.decode())
32+
print(process.stderr.decode(), file=sys.stderr)
33+
raise AssertionError(f"Example failed with return code {process.returncode} ({example})")

cuda_core/tests/example_tests/utils.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)