# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
# ################################################################################
#
# This demo illustrates how to use `cuda.core` to compile a CUDA kernel
# and launch it using PyTorch tensors as inputs.
#
# ## Usage: pip install "cuda-core[cu12]"
# ##        python pytorch_example.py
#
# ################################################################################
import sys

import torch

# cuda.core's public objects currently ship under the `experimental` namespace
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions, launch

# SAXPY kernel - `a` is passed as a pointer so the same source works for both
# the float and double instantiations without any scalar-type issues
code = """
template<typename T>
__global__ void saxpy_kernel(const T* a, const T* x, const T* y, T* out, size_t N) {
    const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < N) {
        // Dereference a to get the scalar value
        out[tid] = (*a) * x[tid] + y[tid];
    }
}
"""

dev = Device()
dev.set_current()

# Get PyTorch's current stream
pt_stream = torch.cuda.current_stream()
print(f"PyTorch stream: {pt_stream}", file=sys.stderr)

# Create a wrapper class that implements __cuda_stream__
class PyTorchStreamWrapper:
    def __init__(self, pt_stream):
        self.pt_stream = pt_stream

    def __cuda_stream__(self):
        stream_id = self.pt_stream.cuda_stream
        return (0, stream_id)  # Return format required by CUDA Python


stream = dev.create_stream(PyTorchStreamWrapper(pt_stream))
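# NOTE: `Device.create_stream` accepts any object implementing the
# `__cuda_stream__` protocol, i.e. one returning a `(version, handle)` tuple
# (protocol version 0 here). Wrapping PyTorch's stream this way means every
# cuda.core launch below is ordered with PyTorch's own operations.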

# prepare program
program_options = ProgramOptions(std="c++11", arch=f"sm_{dev.arch}")
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile(
    "cubin",
    logs=sys.stdout,
    name_expressions=("saxpy_kernel<float>", "saxpy_kernel<double>"),
)
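# NOTE: `name_expressions` asks NVRTC to instantiate the template for each
# listed specialization and record its mangled name, which is what makes the
# `mod.get_kernel("saxpy_kernel<...>")` lookups below possible.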

# Run in single precision
kernel = mod.get_kernel("saxpy_kernel<float>")
dtype = torch.float32

# prepare input/output
size = 64
# Use a single-element tensor for 'a'
a = torch.tensor([10.0], dtype=dtype, device="cuda")
x = torch.rand(size, dtype=dtype, device="cuda")
y = torch.rand(size, dtype=dtype, device="cuda")
out = torch.empty_like(x)

# prepare launch
block = 32
grid = int((size + block - 1) // block)
config = LaunchConfig(grid=grid, block=block)
kernel_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size)
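# NOTE: the integer addresses from `data_ptr()` are what the kernel receives
# as its `const float*`/`float*` arguments; `launch` packs each Python
# argument to match the kernel's signature. No explicit sync is needed before
# the `allclose` checks below, since both the launch and the comparison run
# on PyTorch's current stream (call `stream.sync()` if you ever read results
# outside that stream's ordering).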

# launch kernel on our stream
launch(stream, config, kernel, *kernel_args)

# check result
assert torch.allclose(out, a.item() * x + y)

# let's repeat with double precision
kernel = mod.get_kernel("saxpy_kernel<double>")
dtype = torch.float64

# prepare input
size = 128
# Use a single-element tensor for 'a'
a = torch.tensor([42.0], dtype=dtype, device="cuda")
x = torch.rand(size, dtype=dtype, device="cuda")
y = torch.rand(size, dtype=dtype, device="cuda")

# prepare output
out = torch.empty_like(x)

# prepare launch
block = 64
grid = int((size + block - 1) // block)
config = LaunchConfig(grid=grid, block=block)
kernel_args = (a.data_ptr(), x.data_ptr(), y.data_ptr(), out.data_ptr(), size)

# launch kernel on PyTorch's stream
launch(stream, config, kernel, *kernel_args)

# check result
assert torch.allclose(out, a * x + y)
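
# Optional final sync: the asserts above already synchronize implicitly, but
# draining the wrapped stream before exit is a conservative extra step.
stream.sync()
print("done!")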