
Commit 6a5cb00

shijiashuai committed
fix: validate Python binding inputs
Check tensor sizes and numeric arguments in the nanobind wrappers and stop forcing device-wide synchronization so Python calls fail fast without serializing unrelated CUDA work.
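
A minimal caller-side sketch of the new behavior (not part of this commit; it assumes the extension imports as hpc_ai_opt, that PyTorch CUDA tensors cross the boundary via nanobind's DLPack interop, and that nanobind's default exception translation surfaces std::invalid_argument as Python's ValueError):

    import torch
    import hpc_ai_opt

    x = torch.randn(1024, device="cuda")
    y = torch.empty(512, device="cuda")  # deliberately wrong size

    try:
        hpc_ai_opt.elementwise.relu(x, y)  # rejected before any kernel launch
    except ValueError as err:
        print(err)  # output has unexpected size: expected 1024, got 512

    # The wrappers no longer call cudaDeviceSynchronize(); synchronize
    # explicitly before reading results on the host.
    y = torch.empty(1024, device="cuda")
    hpc_ai_opt.elementwise.relu(x, y)
    torch.cuda.synchronize()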
1 parent 1ac5630 commit 6a5cb00

1 file changed: 77 additions & 28 deletions

python/bindings/bindings.cpp
@@ -1,112 +1,161 @@
 // HPC-AI-Optimization-Lab Python Bindings
-// Using Nanobind for zero-copy PyTorch tensor integration
+// Thin nanobind wrappers over selected CUDA kernels.
 
 #include <nanobind/nanobind.h>
 #include <nanobind/tensor.h>
+
+#include <cmath>
+#include <stdexcept>
+#include <string>
+
 #include <cuda_runtime.h>
 
-// Include kernel headers
 #include "01_elementwise/relu.cuh"
 #include "01_elementwise/sigmoid.cuh"
 #include "01_elementwise/transpose.cuh"
-#include "01_elementwise/vector_add.cuh"
-#include "02_reduction/softmax.cuh"
 #include "02_reduction/layernorm.cuh"
 #include "02_reduction/rmsnorm.cuh"
+#include "02_reduction/softmax.cuh"
 #include "03_gemm/gemm.cuh"
-#include "05_attention/flash_attention.cuh"
-#include "05_attention/rope.cuh"
 
 namespace nb = nanobind;
 
-// Helper to get CUDA pointer from PyTorch tensor
-template<typename T>
-T* get_cuda_ptr(nb::tensor<T, nb::device::cuda>& tensor) {
-    return tensor.data();
+namespace {
+
+template <typename T>
+size_t require_non_empty(const nb::tensor<T, nb::device::cuda>& tensor, const char* name) {
+    const size_t size = tensor.size();
+    if (size == 0) {
+        throw std::invalid_argument(std::string(name) + " must not be empty");
+    }
+    return size;
+}
+
+inline void require_size(size_t actual, size_t expected, const char* name) {
+    if (actual != expected) {
+        throw std::invalid_argument(
+            std::string(name) + " has unexpected size: expected " +
+            std::to_string(expected) + ", got " + std::to_string(actual));
+    }
+}
+
+inline size_t require_positive_product(int lhs, int rhs, const char* lhs_name, const char* rhs_name) {
+    if (lhs <= 0 || rhs <= 0) {
+        throw std::invalid_argument(std::string(lhs_name) + " and " + rhs_name + " must be positive");
+    }
+    return static_cast<size_t>(lhs) * static_cast<size_t>(rhs);
 }
 
-// Elementwise operations
+inline void require_finite_positive(float value, const char* name) {
+    if (!std::isfinite(value) || value <= 0.0f) {
+        throw std::invalid_argument(std::string(name) + " must be finite and positive");
+    }
+}
+
+inline void require_finite(float value, const char* name) {
+    if (!std::isfinite(value)) {
+        throw std::invalid_argument(std::string(name) + " must be finite");
+    }
+}
+
+} // namespace
+
 void relu_wrapper(nb::tensor<float, nb::device::cuda>& input,
                   nb::tensor<float, nb::device::cuda>& output) {
-    size_t n = input.size();
+    const size_t n = require_non_empty(input, "input");
+    require_size(output.size(), n, "output");
     hpc::elementwise::relu<float, hpc::elementwise::OptLevel::GridStride>(
         input.data(), output.data(), n, nullptr);
-    cudaDeviceSynchronize();
 }
 
 void sigmoid_wrapper(nb::tensor<float, nb::device::cuda>& input,
                      nb::tensor<float, nb::device::cuda>& output) {
-    size_t n = input.size();
+    const size_t n = require_non_empty(input, "input");
+    require_size(output.size(), n, "output");
     hpc::elementwise::sigmoid<float, hpc::elementwise::OptLevel::GridStride>(
         input.data(), output.data(), n, nullptr);
-    cudaDeviceSynchronize();
 }
 
 void transpose_wrapper(nb::tensor<float, nb::device::cuda>& input,
                        nb::tensor<float, nb::device::cuda>& output,
                        int rows, int cols) {
+    const size_t expected = require_positive_product(rows, cols, "rows", "cols");
+    require_size(input.size(), expected, "input");
+    require_size(output.size(), expected, "output");
     hpc::elementwise::transpose<float, hpc::elementwise::TransposeOpt::SharedMemPadded>(
         input.data(), output.data(), rows, cols, nullptr);
-    cudaDeviceSynchronize();
 }
 
-// Reduction operations
 void softmax_wrapper(nb::tensor<float, nb::device::cuda>& input,
                      nb::tensor<float, nb::device::cuda>& output,
                      int batch, int seq_len) {
+    const size_t expected = require_positive_product(batch, seq_len, "batch", "seq_len");
+    require_size(input.size(), expected, "input");
+    require_size(output.size(), expected, "output");
     hpc::reduction::softmax<float, hpc::reduction::SoftmaxOpt::OnlineSoftmax>(
         input.data(), output.data(), batch, seq_len, nullptr);
-    cudaDeviceSynchronize();
 }
 
 void layer_norm_wrapper(nb::tensor<float, nb::device::cuda>& input,
                         nb::tensor<float, nb::device::cuda>& gamma,
                         nb::tensor<float, nb::device::cuda>& beta,
                         nb::tensor<float, nb::device::cuda>& output,
                         int batch, int hidden_size, float eps) {
+    const size_t expected = require_positive_product(batch, hidden_size, "batch", "hidden_size");
+    require_finite_positive(eps, "eps");
+    require_size(input.size(), expected, "input");
+    require_size(output.size(), expected, "output");
+    require_size(gamma.size(), static_cast<size_t>(hidden_size), "gamma");
+    require_size(beta.size(), static_cast<size_t>(hidden_size), "beta");
     hpc::reduction::layer_norm<float>(
         input.data(), gamma.data(), beta.data(), output.data(),
         batch, hidden_size, eps, nullptr);
-    cudaDeviceSynchronize();
 }
 
 void rms_norm_wrapper(nb::tensor<float, nb::device::cuda>& input,
                       nb::tensor<float, nb::device::cuda>& gamma,
                       nb::tensor<float, nb::device::cuda>& output,
                       int batch, int hidden_size, float eps) {
+    const size_t expected = require_positive_product(batch, hidden_size, "batch", "hidden_size");
+    require_finite_positive(eps, "eps");
+    require_size(input.size(), expected, "input");
+    require_size(output.size(), expected, "output");
+    require_size(gamma.size(), static_cast<size_t>(hidden_size), "gamma");
     hpc::reduction::rms_norm<float>(
         input.data(), gamma.data(), output.data(),
         batch, hidden_size, eps, nullptr);
-    cudaDeviceSynchronize();
 }
 
-// GEMM
 void matmul_wrapper(nb::tensor<float, nb::device::cuda>& A,
                     nb::tensor<float, nb::device::cuda>& B,
                     nb::tensor<float, nb::device::cuda>& C,
                     int M, int N, int K,
                     float alpha, float beta) {
+    const size_t a_expected = require_positive_product(M, K, "M", "K");
+    const size_t b_expected = require_positive_product(K, N, "K", "N");
+    const size_t c_expected = require_positive_product(M, N, "M", "N");
+    require_finite(alpha, "alpha");
+    require_finite(beta, "beta");
+    require_size(A.size(), a_expected, "A");
+    require_size(B.size(), b_expected, "B");
+    require_size(C.size(), c_expected, "C");
     hpc::gemm::gemm<float, hpc::gemm::GemmOpt::SharedMemTiling>(
         A.data(), B.data(), C.data(), M, N, K, alpha, beta, nullptr);
-    cudaDeviceSynchronize();
 }
 
 NB_MODULE(hpc_ai_opt, m) {
     m.doc() = "HPC-AI-Optimization-Lab CUDA Kernels";
-
-    // Elementwise submodule
+
     auto elementwise = m.def_submodule("elementwise", "Elementwise operations");
     elementwise.def("relu", &relu_wrapper, "ReLU activation");
     elementwise.def("sigmoid", &sigmoid_wrapper, "Sigmoid activation");
     elementwise.def("transpose", &transpose_wrapper, "Matrix transpose");
-
-    // Reduction submodule
+
     auto reduction = m.def_submodule("reduction", "Reduction operations");
     reduction.def("softmax", &softmax_wrapper, "Softmax");
     reduction.def("layer_norm", &layer_norm_wrapper, "Layer normalization");
     reduction.def("rms_norm", &rms_norm_wrapper, "RMS normalization");
-
-    // GEMM submodule
+
     auto gemm = m.def_submodule("gemm", "Matrix multiplication");
     gemm.def("matmul", &matmul_wrapper, "Matrix multiplication");
 }
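
For the GEMM path, a hypothetical usage sketch under the same assumptions (the wrappers validate element counts and scalar arguments only, so contiguous row-major float32 CUDA tensors are assumed):

    import torch
    import hpc_ai_opt

    M, N, K = 128, 64, 256
    A = torch.randn(M, K, device="cuda", dtype=torch.float32)
    B = torch.randn(K, N, device="cuda", dtype=torch.float32)
    C = torch.zeros(M, N, device="cuda", dtype=torch.float32)

    # C = alpha * A @ B + beta * C; dims must be positive, alpha/beta finite.
    hpc_ai_opt.gemm.matmul(A, B, C, M, N, K, 1.0, 0.0)
    torch.cuda.synchronize()  # launches are now asynchronous

    # A non-finite scalar is rejected up front instead of launching the kernel:
    # hpc_ai_opt.gemm.matmul(A, B, C, M, N, K, float("nan"), 0.0)  # ValueError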
