1- import ctypes
21import functools
3- import itertools
4- import pathlib
5- import subprocess
62
73import pytest
84import torch
1511import tests .test_conv2d as conv2d
1612import tests .test_matmul as matmul
1713from ninetoothed import Tensor
18- from ninetoothed .aot import _DTYPE_MAPPING
1914from tests .utils import get_available_devices
2015
2116
@@ -40,7 +35,7 @@ def _application(input, other, output):
4035 kernel_name = f"add{ _generate_kernel_name_suffix ()} "
4136 output_dir = ninetoothed .generation .CACHE_DIR
4237
43- ninetoothed .make (
38+ kernel = ninetoothed .make (
4439 _arrangement ,
4540 _application ,
4641 tensors ,
@@ -49,8 +44,6 @@ def _application(input, other, output):
4944 output_dir = output_dir ,
5045 )
5146
52- launch_func = _generate_launch_func (kernel_name = kernel_name , output_dir = output_dir )
53-
5447 shape = (size ,)
5548
5649 if test_multi_device :
@@ -67,7 +60,7 @@ def _application(input, other, output):
6760 other = torch .randn (shape , dtype = dtype , device = device )
6861 output = torch .empty_like (input )
6962
70- _run_launch_func ( launch_func , input , other , output )
63+ kernel ( input , other , output )
7164
7265 expected = torch .add (input , other )
7366
@@ -93,7 +86,7 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
9386 kernel_name = f"addmm{ _generate_kernel_name_suffix ()} "
9487 output_dir = ninetoothed .generation .CACHE_DIR
9588
96- ninetoothed .make (
89+ kernel = ninetoothed .make (
9790 arrangement ,
9891 application ,
9992 tensors ,
@@ -102,8 +95,6 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
10295 output_dir = output_dir ,
10396 )
10497
105- launch_func = _generate_launch_func (kernel_name = kernel_name , output_dir = output_dir )
106-
10798 input = torch .randn ((m , n ), dtype = dtype , device = device )
10899 mat1 = torch .randn ((m , k ), dtype = dtype , device = device )
109100 mat2 = torch .randn ((k , n ), dtype = dtype , device = device )
@@ -113,7 +104,7 @@ def test_addmm(m, n, k, dtype, device, ninetoothed_dtype, atol):
113104 (mat1 .shape [0 ], mat2 .shape [1 ]), dtype = mat1 .dtype , device = mat1 .device
114105 )
115106
116- _run_launch_func ( launch_func , input , mat1 , mat2 , beta , alpha , output )
107+ kernel ( input , mat1 , mat2 , beta , alpha , output )
117108
118109 expected = torch .addmm (input , mat1 , mat2 , beta = beta , alpha = alpha )
119110
@@ -155,7 +146,7 @@ def test_attention(
155146 kernel_name = f"attention{ _generate_kernel_name_suffix ()} "
156147 output_dir = ninetoothed .generation .CACHE_DIR
157148
158- ninetoothed .make (
149+ kernel = ninetoothed .make (
159150 arrangement ,
160151 application ,
161152 tensors ,
@@ -164,8 +155,6 @@ def test_attention(
164155 output_dir = output_dir ,
165156 )
166157
167- launch_func = _generate_launch_func (kernel_name = kernel_name , output_dir = output_dir )
168-
169158 shape = (batch_size , num_heads , seq_len , emb_dim )
170159
171160 query = torch .randn (shape , dtype = dtype , device = device )
@@ -174,7 +163,7 @@ def test_attention(
174163 is_causal = torch .tensor (True )
175164 output = torch .empty (shape , dtype = dtype , device = device )
176165
177- _run_launch_func ( launch_func , query , key , value , is_causal , output )
166+ kernel ( query , key , value , is_causal , output )
178167
179168 expected = F .scaled_dot_product_attention (
180169 query , key , value , is_causal = True , scale = 1
@@ -200,7 +189,7 @@ def test_matmul(m, n, k, dtype, device, ninetoothed_dtype):
200189 kernel_name = f"matmul{ _generate_kernel_name_suffix ()} "
201190 output_dir = ninetoothed .generation .CACHE_DIR
202191
203- ninetoothed .make (
192+ kernel = ninetoothed .make (
204193 arrangement ,
205194 application ,
206195 tensors ,
@@ -209,13 +198,11 @@ def test_matmul(m, n, k, dtype, device, ninetoothed_dtype):
209198 output_dir = output_dir ,
210199 )
211200
212- launch_func = _generate_launch_func (kernel_name = kernel_name , output_dir = output_dir )
213-
214201 lhs = torch .randn ((m , k ), dtype = dtype , device = device )
215202 rhs = torch .randn ((k , n ), dtype = dtype , device = device )
216203 output = torch .empty ((lhs .shape [0 ], rhs .shape [1 ]), dtype = dtype , device = device )
217204
218- _run_launch_func ( launch_func , lhs , rhs , output )
205+ kernel ( lhs , rhs , output )
219206
220207 expected = torch .matmul (lhs , rhs )
221208
@@ -266,7 +253,7 @@ def test_conv2d(
266253 ((), {"block_size_m" : 128 , "block_size_n" : 32 , "block_size_k" : 64 }, {}),
267254 )
268255
269- ninetoothed .build (
256+ kernel = ninetoothed .build (
270257 premake ,
271258 configs ,
272259 caller = caller ,
@@ -276,7 +263,7 @@ def test_conv2d(
276263 else :
277264 arrangement , application , tensors = premake ()
278265
279- ninetoothed .make (
266+ kernel = ninetoothed .make (
280267 arrangement ,
281268 application ,
282269 tensors ,
@@ -285,8 +272,6 @@ def test_conv2d(
285272 output_dir = output_dir ,
286273 )
287274
288- launch_func = _generate_launch_func (kernel_name = kernel_name , output_dir = output_dir )
289-
290275 p = h - r + 1
291276 q = w - s + 1
292277
@@ -295,77 +280,17 @@ def test_conv2d(
295280 output = torch .empty (n , k , p , q , dtype = dtype , device = device )
296281
297282 if test_build :
298- config = (
299- tuple (_DTYPE_MAPPING .keys ()).index (ninetoothed_dtype ),
300- constexpr_shapes ,
301- ) + tuple (configs [0 ][1 ].values ())
283+ config = (ninetoothed_dtype , constexpr_shapes ) + tuple (configs [0 ][1 ].values ())
302284 else :
303285 config = ()
304286
305- _run_launch_func ( launch_func , input , filter , output , * config )
287+ kernel ( input , filter , output , * config )
306288
307289 expected = F .conv2d (input , filter )
308290
309291 assert torch .allclose (output , expected , rtol = rtol , atol = atol )
310292
311293
312- class _ArgumentTensor (ctypes .Structure ):
313- _fields_ = [
314- ("data" , ctypes .c_void_p ),
315- ("shape" , ctypes .POINTER (ctypes .c_uint64 )),
316- ("strides" , ctypes .POINTER (ctypes .c_int64 )),
317- ]
318-
319- @staticmethod
320- def from_torch_tensor (tensor ):
321- data = ctypes .c_void_p (tensor .data_ptr ())
322- shape = (ctypes .c_uint64 * len (tensor .shape ))(* tensor .shape )
323- strides = (ctypes .c_int64 * len (tensor .stride ()))(* tensor .stride ())
324-
325- return _ArgumentTensor (data , shape , strides )
326-
327-
328- def _run_launch_func (launch_func , * args , ** kwargs ):
329- stream = torch .cuda .Stream ()
330-
331- arguments = tuple (
332- _ArgumentTensor .from_torch_tensor (arg ) if isinstance (arg , torch .Tensor ) else arg
333- for arg in itertools .chain (args , kwargs .values ())
334- )
335-
336- with torch .cuda .stream (stream ):
337- launch_func (ctypes .c_void_p (stream .cuda_stream ), * arguments )
338-
339-
340- def _generate_launch_func (kernel_name , output_dir ):
341- output_dir = pathlib .Path (output_dir )
342-
343- _compile_library (kernel_name , output_dir )
344- library = _load_library (kernel_name , output_dir )
345- launch_func_name = f"launch_{ kernel_name } "
346- launch_func = getattr (library , launch_func_name )
347-
348- return launch_func
349-
350-
351- def _compile_library (kernel_name , output_dir ):
352- command = [
353- "nvcc" ,
354- "-shared" ,
355- "-Xcompiler" ,
356- "-fPIC" ,
357- "-lcuda" ,
358- "-o" ,
359- output_dir / f"{ kernel_name } .so" ,
360- ] + list (output_dir .glob (f"{ kernel_name } *.cpp" ))
361-
362- subprocess .run (command , check = True )
363-
364-
365- def _load_library (kernel_name , kernel_dir ):
366- return ctypes .CDLL (kernel_dir / f"{ kernel_name } .so" )
367-
368-
369294def _generate_kernel_name_suffix ():
370295 count = _generate_kernel_name_suffix ._kernel_count
371296 _generate_kernel_name_suffix ._kernel_count += 1
0 commit comments