Skip to content

Commit f31c762

Browse files
committed
Support py314 free-threading
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent a2edb3f commit f31c762

12 files changed

Lines changed: 217 additions & 17 deletions

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,17 @@ Running Tests
149149
-------------
150150
cuTile uses the [pytest](https://pytest.org) framework for testing.
151151
Tests have extra dependencies, such as PyTorch, which can be installed with
152+
153+
For Python non-free-threading build:
152154
```
153155
pip install -r test/requirements.txt
154156
```
155157

158+
Or for Python free-threading build:
159+
```
160+
pip install -r test/requirements-ft.txt
161+
```
162+
156163
The tests are located in the [test/](test/) directory. To run a specific test file,
157164
for example `test_copy.py`, use the following command:
158165
```

cext/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,23 @@ if (MSVC)
1212
set(cext_compile_flags /GR- /GS- /EHs- -DNDEBUG)
1313
# /NODEFAULTLIB and /NOENTRY to disable dependency on runtime lib msvcrt.dll
1414
set(cext_link_flags /NODEFAULTLIB /NOENTRY)
15+
# set Py_GIL_DISABLED macro, at it is not defined automatically on Windows
16+
execute_process(
17+
COMMAND ${Python_EXECUTABLE} -c
18+
"import sysconfig; print(sysconfig.get_config_var('Py_GIL_DISABLED') or 0)"
19+
OUTPUT_VARIABLE PY_GIL_DISABLED
20+
OUTPUT_STRIP_TRAILING_WHITESPACE
21+
)
22+
if (PY_GIL_DISABLED EQUAL 1)
23+
list(APPEND cext_compile_flags -DPy_GIL_DISABLED=1)
24+
endif()
1525
else()
1626
set(cext_compile_flags -fno-exceptions -fno-rtti -fPIC -fvisibility=hidden)
27+
# For arm64 add flag to disable calls to out-of-line helpers for atomic operations
28+
# to avoid undefined reference error during compilation
29+
if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$")
30+
list(APPEND cext_compile_flags -mno-outline-atomics)
31+
endif()
1732
set(cext_link_flags ${cext_compile_flags})
1833
set(nostdlib_flags -nostdlib -DNDEBUG -fno-builtin)
1934
endif()

cext/cuda_helper.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,21 @@ PyObject* destroy_stream(PyObject* self, PyObject* arg) {
143143
Py_RETURN_NONE;
144144
}
145145

146-
static decltype(cuLaunchKernelEx)* g_real_cuLaunchKernelEx;
147-
static PyObject* g_cuLaunchKernelEx_spy_callback;
146+
static decltype(cuLaunchKernelEx)* g_real_cuLaunchKernelEx; // Protected by the GIL or g_spy_mutex
147+
static PyObject* g_cuLaunchKernelEx_spy_callback; // Protected by the GIL or g_spy_mutex
148+
149+
#ifdef Py_GIL_DISABLED
150+
static PyMutex g_spy_mutex = {0};
151+
#endif
148152

149153
static CUresult shim_cuLaunchKernelEx(
150154
const CUlaunchConfig *config,
151155
CUfunction f,
152156
void** kernelParams,
153157
void** extra) {
154-
158+
#ifdef Py_GIL_DISABLED
159+
PyCriticalSectionGuard guard(&g_spy_mutex);
160+
#endif
155161
PyPtr res = steal(PyObject_CallFunction(
156162
g_cuLaunchKernelEx_spy_callback,
157163
"(K III III I K)",
@@ -167,6 +173,9 @@ static CUresult shim_cuLaunchKernelEx(
167173
}
168174

169175
static PyObject* spy_on_cuLaunchKernel_begin(PyObject* self, PyObject* arg) {
176+
#ifdef Py_GIL_DISABLED
177+
PyCriticalSectionGuard guard(&g_spy_mutex);
178+
#endif
170179
if (g_real_cuLaunchKernelEx)
171180
return PyErr_Format(PyExc_RuntimeError, "Already spying");
172181

@@ -181,6 +190,9 @@ static PyObject* spy_on_cuLaunchKernel_begin(PyObject* self, PyObject* arg) {
181190
}
182191

183192
static PyObject* spy_on_cuLaunchKernel_end(PyObject* self, PyObject* arg) {
193+
#ifdef Py_GIL_DISABLED
194+
PyCriticalSectionGuard guard(&g_spy_mutex);
195+
#endif
184196
if (!g_real_cuLaunchKernelEx)
185197
return PyErr_Format(PyExc_RuntimeError, "Not spying");
186198

cext/cuda_loader.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,16 @@ static Status cuda_loader_init(DriverApi& driver_api) {
7777

7878
static constexpr int MIN_DRIVER_VERSION = 13000;
7979

80+
#ifdef Py_GIL_DISABLED
81+
static PyMutex g_driver_api_mutex = {0};
82+
#endif
83+
8084
Result<const DriverApi*> get_driver_api() {
8185
static bool initialized;
8286
static DriverApi instance;
87+
#ifdef Py_GIL_DISABLED
88+
PyCriticalSectionGuard guard(&g_driver_api_mutex);
89+
#endif
8390
if (!initialized) {
8491
if (!cuda_loader_init(instance))
8592
return ErrorRaised;

cext/module.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ PyMODINIT_FUNC PyInit__cext() {
3636
PyPtr m = steal(PyModule_Create(&module_def));
3737
if (!m) return nullptr;
3838

39+
#ifdef Py_GIL_DISABLED
40+
if (PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED) != 0 )
41+
return nullptr;
42+
#endif
43+
3944
if (!tile_kernel_init(m.get()))
4045
return nullptr;
4146

cext/py.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#pragma once
88

9+
#include "check.h"
910
#include "ref_ptr.h"
1011
#include <Python.h>
1112
#include <optional>
@@ -275,3 +276,22 @@ static inline PyPtr try_import(const char* modname, SavedException* exc = nullpt
275276
if (!ret && exc) *exc = save_raised_exception();
276277
return ret;
277278
}
279+
280+
#ifdef Py_GIL_DISABLED
281+
class PyCriticalSectionGuard {
282+
public:
283+
explicit PyCriticalSectionGuard(PyMutex* mutex) {
284+
CHECK(mutex);
285+
PyCriticalSection_BeginMutex(&_py_cs, mutex);
286+
}
287+
288+
~PyCriticalSectionGuard() {
289+
PyCriticalSection_End(&_py_cs);
290+
}
291+
292+
PyCriticalSectionGuard(const PyCriticalSectionGuard&) = delete;
293+
void operator=(const PyCriticalSectionGuard&) = delete;
294+
private:
295+
PyCriticalSection _py_cs;
296+
};
297+
#endif

cext/tile_kernel.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ static Word* push_single_word_cuarg(LaunchHelper& helper, Word word) {
377377
return ptr;
378378
}
379379

380-
static LaunchHelper* g_helper_freelist; // protected by the GIL
380+
static LaunchHelper* g_helper_freelist; // protected by the GIL or g_launch_mutex
381381

382382
namespace { struct LaunchHelperDeleter {
383383
void operator() (LaunchHelper* helper) const {
@@ -386,6 +386,10 @@ namespace { struct LaunchHelperDeleter {
386386
}
387387
}; }
388388

389+
#ifdef Py_GIL_DISABLED
390+
static PyMutex g_launch_mutex = {0};
391+
#endif
392+
389393
using LaunchHelperPtr = std::unique_ptr<LaunchHelper, LaunchHelperDeleter>;
390394

391395

@@ -1489,6 +1493,9 @@ struct CompareKey <Vec<PyTypeObject*>, Vec<PyPtr>> {
14891493
namespace { struct TileContext {
14901494
PyPtr config;
14911495
PyPtr autotune_cache;
1496+
#ifdef Py_GIL_DISABLED
1497+
PyMutex accessor_mutex = {0};
1498+
#endif
14921499

14931500
static PyTypeObject pytype;
14941501
}; }
@@ -1913,7 +1920,7 @@ static Result<CUstream> parse_stream(PyObject* py_stream) {
19131920

19141921
using StreamBufferPoolMap = HashMap<unsigned long long, StreamBufferPool*>;
19151922

1916-
// Protected by GIL.
1923+
// Protected by GIL or g_launch_mutex.
19171924
// We have no reliable way to detect when a context is destroyed, so we never clean these up.
19181925
static StreamBufferPoolMap* g_stream_buffer_pool_by_ctx_id;
19191926

@@ -2413,16 +2420,27 @@ static int TileContext_init(PyObject* self, PyObject* args, PyObject* kwargs) {
24132420

24142421

24152422
static PyObject * TileContext_get_config(PyObject* self, void *closure) {
2416-
return Py_NewRef(py_unwrap<TileContext>(self).config.get());
2423+
TileContext& context = py_unwrap<TileContext>(self);
2424+
#ifdef Py_GIL_DISABLED
2425+
PyCriticalSectionGuard guard(&context.accessor_mutex);
2426+
#endif
2427+
return Py_NewRef(context.config.get());
24172428
}
24182429

24192430

24202431
static PyObject * TileContext_get_autotune_cache(PyObject* self, void *closure) {
2421-
return Py_NewRef(py_unwrap<TileContext>(self).autotune_cache.get());
2432+
TileContext& context = py_unwrap<TileContext>(self);
2433+
#ifdef Py_GIL_DISABLED
2434+
PyCriticalSectionGuard guard(&context.accessor_mutex);
2435+
#endif
2436+
return Py_NewRef(context.autotune_cache.get());
24222437
}
24232438

24242439
static int TileContext_set_autotune_cache(PyObject* self, PyObject* value, void* closure) {
24252440
TileContext& context = py_unwrap<TileContext>(self);
2441+
#ifdef Py_GIL_DISABLED
2442+
PyCriticalSectionGuard guard(&context.accessor_mutex);
2443+
#endif
24262444

24272445
// `del ctx.autotune_cache` → set back to None
24282446
if (value == nullptr) {
@@ -2490,6 +2508,9 @@ PyTypeObject TileDispatcher::pytype = {
24902508
};
24912509

24922510
static PyObject* get_parameter_constraints_from_pyargs(PyObject* self, PyObject* args) {
2511+
#ifdef Py_GIL_DISABLED
2512+
PyCriticalSectionGuard guard(&g_launch_mutex);
2513+
#endif
24932514
PyObject* dispatcher_pyobj = nullptr;
24942515
PyObject* pyargs = nullptr;
24952516
PyObject* cconv = nullptr;
@@ -2683,6 +2704,9 @@ static Status parse_launch_args(PyObject* const* args, Py_ssize_t nargs, const c
26832704
static PyObject* launch_impl(PyObject* const* args, Py_ssize_t nargs,
26842705
PyObject* kwargs, const char* signature, bool with_block
26852706
) {
2707+
#ifdef Py_GIL_DISABLED
2708+
PyCriticalSectionGuard guard(&g_launch_mutex);
2709+
#endif
26862710
LaunchArgs launch_args;
26872711
if (!parse_launch_args(args, nargs, signature, with_block, &launch_args))
26882712
return nullptr;
@@ -2723,6 +2747,9 @@ static PyObject *launch_extended(PyObject *, PyObject *const *args,
27232747
#define BENCHMARK_SIGNATURE "_benchmark(stream, grid, kernel, pyargs_tuples, /)"
27242748

27252749
static PyObject* cuda_tile_benchmark(PyObject* mod, PyObject* const* args, Py_ssize_t nargs) {
2750+
#ifdef Py_GIL_DISABLED
2751+
PyCriticalSectionGuard guard(&g_launch_mutex);
2752+
#endif
27262753
LaunchArgs launch_args;
27272754
if (!parse_launch_args(args, nargs, BENCHMARK_SIGNATURE, false, &launch_args))
27282755
return nullptr;

changelog.d/support-py314.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Add Python3.14 support
2+
- Add Python3.14 free threading support, safe to launch kernels via different threads with GIL disabled.

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def _cmake(self, build_dir: str, build_type: str, dlpack_path: str):
4242
cmake_cmd = ["cmake", "-B", build_dir, project_root,
4343
f"-DDLPACK_PATH={dlpack_path}",
4444
f"-DCMAKE_BUILD_TYPE={build_type}",
45+
f"-DPython_EXECUTABLE={sys.executable}",
4546
"-DCMAKE_POLICY_VERSION_MINIMUM=3.5"]
4647
if self.disable_internal:
4748
cmake_cmd.append("-DDISABLE_INTERNAL=1")

test/requirements-ft.txt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
setuptools==80.9.0
6+
pytest==9.0.1
7+
pytest-benchmark==5.1.0
8+
pytest-cov==6.2.1
9+
pytest-env==1.2.0
10+
flake8==7.3.0
11+
pytest-cov==6.2.1
12+
13+
numpy==2.2.0; python_version < "3.11"
14+
numpy==2.4.4; python_version >= "3.11"
15+
16+
jax[cuda13]==0.10.0; sys_platform == 'linux' and python_version >= "3.11"
17+
flatbuffers==25.12.19; sys_platform == 'linux' and python_version >= "3.11"
18+
19+
20+
--extra-index-url https://download.pytorch.org/whl/cu130
21+
torch==2.10.0

0 commit comments

Comments
 (0)