|
1 | 1 | // HPC-AI-Optimization-Lab Python Bindings |
2 | | -// Using Nanobind for zero-copy PyTorch tensor integration |
| 2 | +// Thin nanobind wrappers over selected CUDA kernels. |
3 | 3 |
|
4 | 4 | #include <nanobind/nanobind.h> |
5 | 5 | #include <nanobind/tensor.h> |
| 6 | + |
| 7 | +#include <cmath> |
| 8 | +#include <stdexcept> |
| 9 | +#include <string> |
| 10 | + |
6 | 11 | #include <cuda_runtime.h> |
7 | 12 |
|
8 | | -// Include kernel headers |
9 | 13 | #include "01_elementwise/relu.cuh" |
10 | 14 | #include "01_elementwise/sigmoid.cuh" |
11 | 15 | #include "01_elementwise/transpose.cuh" |
12 | | -#include "01_elementwise/vector_add.cuh" |
13 | | -#include "02_reduction/softmax.cuh" |
14 | 16 | #include "02_reduction/layernorm.cuh" |
15 | 17 | #include "02_reduction/rmsnorm.cuh" |
| 18 | +#include "02_reduction/softmax.cuh" |
16 | 19 | #include "03_gemm/gemm.cuh" |
17 | | -#include "05_attention/flash_attention.cuh" |
18 | | -#include "05_attention/rope.cuh" |
19 | 20 |
|
20 | 21 | namespace nb = nanobind; |
21 | 22 |
|
22 | | -// Helper to get CUDA pointer from PyTorch tensor |
23 | | -template<typename T> |
24 | | -T* get_cuda_ptr(nb::tensor<T, nb::device::cuda>& tensor) { |
25 | | - return tensor.data(); |
| 23 | +namespace { |
| 24 | + |
| 25 | +template <typename T> |
| 26 | +size_t require_non_empty(const nb::tensor<T, nb::device::cuda>& tensor, const char* name) { |
| 27 | + const size_t size = tensor.size(); |
| 28 | + if (size == 0) { |
| 29 | + throw std::invalid_argument(std::string(name) + " must not be empty"); |
| 30 | + } |
| 31 | + return size; |
| 32 | +} |
| 33 | + |
| 34 | +inline void require_size(size_t actual, size_t expected, const char* name) { |
| 35 | + if (actual != expected) { |
| 36 | + throw std::invalid_argument( |
| 37 | + std::string(name) + " has unexpected size: expected " + |
| 38 | + std::to_string(expected) + ", got " + std::to_string(actual)); |
| 39 | + } |
| 40 | +} |
| 41 | + |
| 42 | +inline size_t require_positive_product(int lhs, int rhs, const char* lhs_name, const char* rhs_name) { |
| 43 | + if (lhs <= 0 || rhs <= 0) { |
| 44 | + throw std::invalid_argument(std::string(lhs_name) + " and " + rhs_name + " must be positive"); |
| 45 | + } |
| 46 | + return static_cast<size_t>(lhs) * static_cast<size_t>(rhs); |
26 | 47 | } |
27 | 48 |
|
28 | | -// Elementwise operations |
| 49 | +inline void require_finite_positive(float value, const char* name) { |
| 50 | + if (!std::isfinite(value) || value <= 0.0f) { |
| 51 | + throw std::invalid_argument(std::string(name) + " must be finite and positive"); |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +inline void require_finite(float value, const char* name) { |
| 56 | + if (!std::isfinite(value)) { |
| 57 | + throw std::invalid_argument(std::string(name) + " must be finite"); |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +} // namespace |
| 62 | + |
29 | 63 | void relu_wrapper(nb::tensor<float, nb::device::cuda>& input, |
30 | 64 | nb::tensor<float, nb::device::cuda>& output) { |
31 | | - size_t n = input.size(); |
| 65 | + const size_t n = require_non_empty(input, "input"); |
| 66 | + require_size(output.size(), n, "output"); |
32 | 67 | hpc::elementwise::relu<float, hpc::elementwise::OptLevel::GridStride>( |
33 | 68 | input.data(), output.data(), n, nullptr); |
34 | | - cudaDeviceSynchronize(); |
35 | 69 | } |
36 | 70 |
|
37 | 71 | void sigmoid_wrapper(nb::tensor<float, nb::device::cuda>& input, |
38 | 72 | nb::tensor<float, nb::device::cuda>& output) { |
39 | | - size_t n = input.size(); |
| 73 | + const size_t n = require_non_empty(input, "input"); |
| 74 | + require_size(output.size(), n, "output"); |
40 | 75 | hpc::elementwise::sigmoid<float, hpc::elementwise::OptLevel::GridStride>( |
41 | 76 | input.data(), output.data(), n, nullptr); |
42 | | - cudaDeviceSynchronize(); |
43 | 77 | } |
44 | 78 |
|
45 | 79 | void transpose_wrapper(nb::tensor<float, nb::device::cuda>& input, |
46 | 80 | nb::tensor<float, nb::device::cuda>& output, |
47 | 81 | int rows, int cols) { |
| 82 | + const size_t expected = require_positive_product(rows, cols, "rows", "cols"); |
| 83 | + require_size(input.size(), expected, "input"); |
| 84 | + require_size(output.size(), expected, "output"); |
48 | 85 | hpc::elementwise::transpose<float, hpc::elementwise::TransposeOpt::SharedMemPadded>( |
49 | 86 | input.data(), output.data(), rows, cols, nullptr); |
50 | | - cudaDeviceSynchronize(); |
51 | 87 | } |
52 | 88 |
|
53 | | -// Reduction operations |
54 | 89 | void softmax_wrapper(nb::tensor<float, nb::device::cuda>& input, |
55 | 90 | nb::tensor<float, nb::device::cuda>& output, |
56 | 91 | int batch, int seq_len) { |
| 92 | + const size_t expected = require_positive_product(batch, seq_len, "batch", "seq_len"); |
| 93 | + require_size(input.size(), expected, "input"); |
| 94 | + require_size(output.size(), expected, "output"); |
57 | 95 | hpc::reduction::softmax<float, hpc::reduction::SoftmaxOpt::OnlineSoftmax>( |
58 | 96 | input.data(), output.data(), batch, seq_len, nullptr); |
59 | | - cudaDeviceSynchronize(); |
60 | 97 | } |
61 | 98 |
|
62 | 99 | void layer_norm_wrapper(nb::tensor<float, nb::device::cuda>& input, |
63 | 100 | nb::tensor<float, nb::device::cuda>& gamma, |
64 | 101 | nb::tensor<float, nb::device::cuda>& beta, |
65 | 102 | nb::tensor<float, nb::device::cuda>& output, |
66 | 103 | int batch, int hidden_size, float eps) { |
| 104 | + const size_t expected = require_positive_product(batch, hidden_size, "batch", "hidden_size"); |
| 105 | + require_finite_positive(eps, "eps"); |
| 106 | + require_size(input.size(), expected, "input"); |
| 107 | + require_size(output.size(), expected, "output"); |
| 108 | + require_size(gamma.size(), static_cast<size_t>(hidden_size), "gamma"); |
| 109 | + require_size(beta.size(), static_cast<size_t>(hidden_size), "beta"); |
67 | 110 | hpc::reduction::layer_norm<float>( |
68 | 111 | input.data(), gamma.data(), beta.data(), output.data(), |
69 | 112 | batch, hidden_size, eps, nullptr); |
70 | | - cudaDeviceSynchronize(); |
71 | 113 | } |
72 | 114 |
|
73 | 115 | void rms_norm_wrapper(nb::tensor<float, nb::device::cuda>& input, |
74 | 116 | nb::tensor<float, nb::device::cuda>& gamma, |
75 | 117 | nb::tensor<float, nb::device::cuda>& output, |
76 | 118 | int batch, int hidden_size, float eps) { |
| 119 | + const size_t expected = require_positive_product(batch, hidden_size, "batch", "hidden_size"); |
| 120 | + require_finite_positive(eps, "eps"); |
| 121 | + require_size(input.size(), expected, "input"); |
| 122 | + require_size(output.size(), expected, "output"); |
| 123 | + require_size(gamma.size(), static_cast<size_t>(hidden_size), "gamma"); |
77 | 124 | hpc::reduction::rms_norm<float>( |
78 | 125 | input.data(), gamma.data(), output.data(), |
79 | 126 | batch, hidden_size, eps, nullptr); |
80 | | - cudaDeviceSynchronize(); |
81 | 127 | } |
82 | 128 |
|
83 | | -// GEMM |
84 | 129 | void matmul_wrapper(nb::tensor<float, nb::device::cuda>& A, |
85 | 130 | nb::tensor<float, nb::device::cuda>& B, |
86 | 131 | nb::tensor<float, nb::device::cuda>& C, |
87 | 132 | int M, int N, int K, |
88 | 133 | float alpha, float beta) { |
| 134 | + const size_t a_expected = require_positive_product(M, K, "M", "K"); |
| 135 | + const size_t b_expected = require_positive_product(K, N, "K", "N"); |
| 136 | + const size_t c_expected = require_positive_product(M, N, "M", "N"); |
| 137 | + require_finite(alpha, "alpha"); |
| 138 | + require_finite(beta, "beta"); |
| 139 | + require_size(A.size(), a_expected, "A"); |
| 140 | + require_size(B.size(), b_expected, "B"); |
| 141 | + require_size(C.size(), c_expected, "C"); |
89 | 142 | hpc::gemm::gemm<float, hpc::gemm::GemmOpt::SharedMemTiling>( |
90 | 143 | A.data(), B.data(), C.data(), M, N, K, alpha, beta, nullptr); |
91 | | - cudaDeviceSynchronize(); |
92 | 144 | } |
93 | 145 |
|
94 | 146 | NB_MODULE(hpc_ai_opt, m) { |
95 | 147 | m.doc() = "HPC-AI-Optimization-Lab CUDA Kernels"; |
96 | | - |
97 | | - // Elementwise submodule |
| 148 | + |
98 | 149 | auto elementwise = m.def_submodule("elementwise", "Elementwise operations"); |
99 | 150 | elementwise.def("relu", &relu_wrapper, "ReLU activation"); |
100 | 151 | elementwise.def("sigmoid", &sigmoid_wrapper, "Sigmoid activation"); |
101 | 152 | elementwise.def("transpose", &transpose_wrapper, "Matrix transpose"); |
102 | | - |
103 | | - // Reduction submodule |
| 153 | + |
104 | 154 | auto reduction = m.def_submodule("reduction", "Reduction operations"); |
105 | 155 | reduction.def("softmax", &softmax_wrapper, "Softmax"); |
106 | 156 | reduction.def("layer_norm", &layer_norm_wrapper, "Layer normalization"); |
107 | 157 | reduction.def("rms_norm", &rms_norm_wrapper, "RMS normalization"); |
108 | | - |
109 | | - // GEMM submodule |
| 158 | + |
110 | 159 | auto gemm = m.def_submodule("gemm", "Matrix multiplication"); |
111 | 160 | gemm.def("matmul", &matmul_wrapper, "Matrix multiplication"); |
112 | 161 | } |
0 commit comments