Skip to content

Commit b4441ee

Browse files
authored
Overhaul and stabilize AOT kernel launching mechanism (#136)
* Move the `_run_launch_func` helper inside `_generate_launch_func` and have it return the wrapper function directly
* Move AOT kernel launch utilities from `tests/test_aot.py` to `src/ninetoothed/aot.py`
* Return launch function directly from `ninetoothed.aot` and `ninetoothed.build`
* Update the kernel launching mechanism to accept `ninetoothed` data types
* Check kernel launch return value in `_run_launch_func`
* Keep reference to original `torch.Tensor` in `_ArgumentTensor`
* Use current stream for kernel launch
* Load library via temporary copy
1 parent ee46030 commit b4441ee

3 files changed

Lines changed: 106 additions & 88 deletions

File tree

src/ninetoothed/aot.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import ast
2+
import ctypes
3+
import itertools
24
import pathlib
35
import re
6+
import shutil
47
import subprocess
58
import tempfile
69
import textwrap
@@ -39,6 +42,8 @@ def aot(
3942
with open(output_path, "w") as f:
4043
f.write(output_content)
4144

45+
return _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
46+
4247

4348
def _aot(func, caller, kernel_name, num_warps, num_stages):
4449
def _find_tensor_by_source_name(tensors, name):
@@ -351,6 +356,25 @@ def visit_Lambda(self, node):
351356
return node
352357

353358

359+
class _ArgumentTensor(ctypes.Structure):
360+
_fields_ = [
361+
("data", ctypes.c_void_p),
362+
("shape", ctypes.POINTER(ctypes.c_uint64)),
363+
("strides", ctypes.POINTER(ctypes.c_int64)),
364+
]
365+
366+
@staticmethod
367+
def from_torch_tensor(tensor):
368+
data = ctypes.c_void_p(tensor.data_ptr())
369+
shape = (ctypes.c_uint64 * len(tensor.shape))(*tensor.shape)
370+
strides = (ctypes.c_int64 * len(tensor.stride()))(*tensor.stride())
371+
372+
arg_tensor = _ArgumentTensor(data, shape, strides)
373+
arg_tensor._torch_tensor = tensor
374+
375+
return arg_tensor
376+
377+
354378
def _compile(path, name, signature, grid, num_warps, num_stages):
355379
with tempfile.TemporaryDirectory() as temp_dir:
356380
output_dir = pathlib.Path(temp_dir)
@@ -389,3 +413,65 @@ def _compile(path, name, signature, grid, num_warps, num_stages):
389413
output_contents[file.name.replace(output_name, name)] = f.read()
390414

391415
return signature_hash, output_contents
416+
417+
418+
def _generate_launch_func(kernel_name, output_dir):
419+
import torch
420+
421+
output_dir = pathlib.Path(output_dir)
422+
423+
_compile_library(kernel_name, output_dir)
424+
library = _load_library(kernel_name, output_dir)
425+
launch_func_name = f"launch_{kernel_name}"
426+
launch_func = getattr(library, launch_func_name)
427+
428+
def _run_launch_func(*args, **kwargs):
429+
arguments = []
430+
431+
for arg in itertools.chain(args, kwargs.values()):
432+
if isinstance(arg, torch.Tensor):
433+
argument = _ArgumentTensor.from_torch_tensor(arg)
434+
elif isinstance(arg, str) and arg in _DTYPE_MAPPING:
435+
argument = tuple(_DTYPE_MAPPING.keys()).index(arg)
436+
else:
437+
argument = arg
438+
439+
arguments.append(argument)
440+
441+
result = launch_func(
442+
ctypes.c_void_p(torch.cuda.current_stream().cuda_stream), *arguments
443+
)
444+
445+
if result != 0:
446+
raise RuntimeError(f"Kernel launch failed with error code: {result}.")
447+
448+
return _run_launch_func
449+
450+
451+
def _compile_library(kernel_name, output_dir):
452+
command = [
453+
"nvcc",
454+
"-shared",
455+
"-Xcompiler",
456+
"-fPIC",
457+
"-lcuda",
458+
"-o",
459+
output_dir / f"{kernel_name}.so",
460+
] + list(output_dir.glob(f"{kernel_name}*.cpp"))
461+
462+
subprocess.run(command, check=True)
463+
464+
465+
def _load_library(kernel_name, kernel_dir):
466+
suffix = ".so"
467+
468+
original_path = kernel_dir / f"{kernel_name}{suffix}"
469+
470+
with tempfile.NamedTemporaryFile(suffix=suffix) as temp_file:
471+
temp_path = temp_file.name
472+
473+
shutil.copy(original_path, temp_path)
474+
475+
library = ctypes.CDLL(temp_path)
476+
477+
return library

src/ninetoothed/build.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import pathlib
66

77
import ninetoothed
8-
from ninetoothed.aot import _DTYPE_MAPPING, _HEADER_PATH, _MACRO_MAPPING
8+
from ninetoothed.aot import (
9+
_DTYPE_MAPPING,
10+
_HEADER_PATH,
11+
_MACRO_MAPPING,
12+
_generate_launch_func,
13+
)
914

1015

1116
def build(premake, configs, *, caller=None, kernel_name=None, output_dir=None):
@@ -102,6 +107,8 @@ def build(premake, configs, *, caller=None, kernel_name=None, output_dir=None):
102107
(output_dir / source_file_name).write_text(source_content)
103108
(output_dir / header_file_name).write_text(header_content)
104109

110+
return _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
111+
105112

106113
def _make(premake, config, caller, kernel_name, output_dir):
107114
args, kwargs, compilation_configs = config

tests/test_aot.py

Lines changed: 12 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
import ctypes
21
import functools
3-
import itertools
4-
import pathlib
5-
import subprocess
62

73
import pytest
84
import torch
@@ -15,7 +11,6 @@
1511
import tests.test_conv2d as conv2d
1612
import tests.test_matmul as matmul
1713
from ninetoothed import Tensor
18-
from ninetoothed.aot import _DTYPE_MAPPING
1914
from tests.utils import get_available_devices
2015

2116

@@ -40,7 +35,7 @@ def _application(input, other, output):
4035
kernel_name = f"add{_generate_kernel_name_suffix()}"
4136
output_dir = ninetoothed.generation.CACHE_DIR
4237

43-
ninetoothed.make(
38+
kernel = ninetoothed.make(
4439
_arrangement,
4540
_application,
4641
tensors,
@@ -49,8 +44,6 @@ def _application(input, other, output):
4944
output_dir=output_dir,
5045
)
5146

52-
launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
53-
5447
shape = (size,)
5548

5649
if test_multi_device:
@@ -67,7 +60,7 @@ def _application(input, other, output):
6760
other = torch.randn(shape, dtype=dtype, device=device)
6861
output = torch.empty_like(input)
6962

70-
_run_launch_func(launch_func, input, other, output)
63+
kernel(input, other, output)
7164

7265
expected = torch.add(input, other)
7366

@@ -93,7 +86,7 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
9386
kernel_name = f"addmm{_generate_kernel_name_suffix()}"
9487
output_dir = ninetoothed.generation.CACHE_DIR
9588

96-
ninetoothed.make(
89+
kernel = ninetoothed.make(
9790
arrangement,
9891
application,
9992
tensors,
@@ -102,8 +95,6 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
10295
output_dir=output_dir,
10396
)
10497

105-
launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
106-
10798
input = torch.randn((m, n), dtype=dtype, device=device)
10899
mat1 = torch.randn((m, k), dtype=dtype, device=device)
109100
mat2 = torch.randn((k, n), dtype=dtype, device=device)
@@ -113,7 +104,7 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
113104
(mat1.shape[0], mat2.shape[1]), dtype=mat1.dtype, device=mat1.device
114105
)
115106

116-
_run_launch_func(launch_func, input, mat1, mat2, beta, alpha, output)
107+
kernel(input, mat1, mat2, beta, alpha, output)
117108

118109
expected = torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
119110

@@ -155,7 +146,7 @@ def test_attention(
155146
kernel_name = f"attention{_generate_kernel_name_suffix()}"
156147
output_dir = ninetoothed.generation.CACHE_DIR
157148

158-
ninetoothed.make(
149+
kernel = ninetoothed.make(
159150
arrangement,
160151
application,
161152
tensors,
@@ -164,8 +155,6 @@ def test_attention(
164155
output_dir=output_dir,
165156
)
166157

167-
launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
168-
169158
shape = (batch_size, num_heads, seq_len, emb_dim)
170159

171160
query = torch.randn(shape, dtype=dtype, device=device)
@@ -174,7 +163,7 @@ def test_attention(
174163
is_causal = torch.tensor(True)
175164
output = torch.empty(shape, dtype=dtype, device=device)
176165

177-
_run_launch_func(launch_func, query, key, value, is_causal, output)
166+
kernel(query, key, value, is_causal, output)
178167

179168
expected = F.scaled_dot_product_attention(
180169
query, key, value, is_causal=True, scale=1
@@ -200,7 +189,7 @@ def test_matmul(m, n, k, dtype, device, ninetoothed_dtype):
200189
kernel_name = f"matmul{_generate_kernel_name_suffix()}"
201190
output_dir = ninetoothed.generation.CACHE_DIR
202191

203-
ninetoothed.make(
192+
kernel = ninetoothed.make(
204193
arrangement,
205194
application,
206195
tensors,
@@ -209,13 +198,11 @@ def test_matmul(m, n, k, dtype, device, ninetoothed_dtype):
209198
output_dir=output_dir,
210199
)
211200

212-
launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
213-
214201
lhs = torch.randn((m, k), dtype=dtype, device=device)
215202
rhs = torch.randn((k, n), dtype=dtype, device=device)
216203
output = torch.empty((lhs.shape[0], rhs.shape[1]), dtype=dtype, device=device)
217204

218-
_run_launch_func(launch_func, lhs, rhs, output)
205+
kernel(lhs, rhs, output)
219206

220207
expected = torch.matmul(lhs, rhs)
221208

@@ -266,7 +253,7 @@ def test_conv2d(
266253
((), {"block_size_m": 128, "block_size_n": 32, "block_size_k": 64}, {}),
267254
)
268255

269-
ninetoothed.build(
256+
kernel = ninetoothed.build(
270257
premake,
271258
configs,
272259
caller=caller,
@@ -276,7 +263,7 @@ def test_conv2d(
276263
else:
277264
arrangement, application, tensors = premake()
278265

279-
ninetoothed.make(
266+
kernel = ninetoothed.make(
280267
arrangement,
281268
application,
282269
tensors,
@@ -285,8 +272,6 @@ def test_conv2d(
285272
output_dir=output_dir,
286273
)
287274

288-
launch_func = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
289-
290275
p = h - r + 1
291276
q = w - s + 1
292277

@@ -295,77 +280,17 @@ def test_conv2d(
295280
output = torch.empty(n, k, p, q, dtype=dtype, device=device)
296281

297282
if test_build:
298-
config = (
299-
tuple(_DTYPE_MAPPING.keys()).index(ninetoothed_dtype),
300-
constexpr_shapes,
301-
) + tuple(configs[0][1].values())
283+
config = (ninetoothed_dtype, constexpr_shapes) + tuple(configs[0][1].values())
302284
else:
303285
config = ()
304286

305-
_run_launch_func(launch_func, input, filter, output, *config)
287+
kernel(input, filter, output, *config)
306288

307289
expected = F.conv2d(input, filter)
308290

309291
assert torch.allclose(output, expected, rtol=rtol, atol=atol)
310292

311293

312-
class _ArgumentTensor(ctypes.Structure):
313-
_fields_ = [
314-
("data", ctypes.c_void_p),
315-
("shape", ctypes.POINTER(ctypes.c_uint64)),
316-
("strides", ctypes.POINTER(ctypes.c_int64)),
317-
]
318-
319-
@staticmethod
320-
def from_torch_tensor(tensor):
321-
data = ctypes.c_void_p(tensor.data_ptr())
322-
shape = (ctypes.c_uint64 * len(tensor.shape))(*tensor.shape)
323-
strides = (ctypes.c_int64 * len(tensor.stride()))(*tensor.stride())
324-
325-
return _ArgumentTensor(data, shape, strides)
326-
327-
328-
def _run_launch_func(launch_func, *args, **kwargs):
329-
stream = torch.cuda.Stream()
330-
331-
arguments = tuple(
332-
_ArgumentTensor.from_torch_tensor(arg) if isinstance(arg, torch.Tensor) else arg
333-
for arg in itertools.chain(args, kwargs.values())
334-
)
335-
336-
with torch.cuda.stream(stream):
337-
launch_func(ctypes.c_void_p(stream.cuda_stream), *arguments)
338-
339-
340-
def _generate_launch_func(kernel_name, output_dir):
341-
output_dir = pathlib.Path(output_dir)
342-
343-
_compile_library(kernel_name, output_dir)
344-
library = _load_library(kernel_name, output_dir)
345-
launch_func_name = f"launch_{kernel_name}"
346-
launch_func = getattr(library, launch_func_name)
347-
348-
return launch_func
349-
350-
351-
def _compile_library(kernel_name, output_dir):
352-
command = [
353-
"nvcc",
354-
"-shared",
355-
"-Xcompiler",
356-
"-fPIC",
357-
"-lcuda",
358-
"-o",
359-
output_dir / f"{kernel_name}.so",
360-
] + list(output_dir.glob(f"{kernel_name}*.cpp"))
361-
362-
subprocess.run(command, check=True)
363-
364-
365-
def _load_library(kernel_name, kernel_dir):
366-
return ctypes.CDLL(kernel_dir / f"{kernel_name}.so")
367-
368-
369294
def _generate_kernel_name_suffix():
370295
count = _generate_kernel_name_suffix._kernel_count
371296
_generate_kernel_name_suffix._kernel_count += 1

0 commit comments

Comments
 (0)