@@ -10,13 +10,13 @@ namespace py = pybind11;
1010
1111// Forward declarations
1212void naive_attention_fp32 (const float *, const float *, const float *, float *,
13- int , int , int , int , float , cudaStream_t);
13+ int , int , int , int , float , bool , cudaStream_t);
1414void naive_attention_fp16 (const half*, const half*, const half*, half*,
15- int , int , int , int , float , cudaStream_t);
15+ int , int , int , int , float , bool , cudaStream_t);
1616void tiled_attention_fp32 (const float *, const float *, const float *, float *,
17- int , int , int , int , float , cudaStream_t);
17+ int , int , int , int , float , bool , cudaStream_t);
1818void tiled_attention_fp16 (const half*, const half*, const half*, half*,
19- int , int , int , int , float , cudaStream_t);
19+ int , int , int , int , float , bool , cudaStream_t);
2020void flash_attention_fp32 (const float *, const float *, const float *, float *,
2121 int , int , int , int , float , bool , cudaStream_t);
2222void flash_attention_fp16 (const half*, const half*, const half*, half*,
@@ -96,30 +96,31 @@ torch::Tensor naive_attention(
9696 const torch::Tensor& q,
9797 const torch::Tensor& k,
9898 const torch::Tensor& v,
99- float scale = 0 .0f
99+ float scale = 0 .0f ,
100+ bool is_causal = false
100101) {
101102 validate_attention_inputs (q, k, v);
102-
103+
103104 int batch_size = q.size (0 );
104105 int num_heads = q.size (1 );
105106 int seq_len = q.size (2 );
106107 int head_dim = q.size (3 );
107-
108+
108109 if (scale == 0 .0f ) {
109110 scale = 1 .0f / sqrtf (static_cast <float >(head_dim));
110111 }
111-
112+
112113 auto output = torch::empty_like (q);
113114 cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
114-
115+
115116 if (q.scalar_type () == torch::kFloat32 ) {
116117 naive_attention_fp32 (
117118 q.data_ptr <float >(),
118119 k.data_ptr <float >(),
119120 v.data_ptr <float >(),
120121 output.data_ptr <float >(),
121122 batch_size, num_heads, seq_len, head_dim,
122- scale, stream
123+ scale, is_causal, stream
123124 );
124125 } else {
125126 naive_attention_fp16 (
@@ -128,10 +129,10 @@ torch::Tensor naive_attention(
128129 reinterpret_cast <const half*>(v.data_ptr <at::Half>()),
129130 reinterpret_cast <half*>(output.data_ptr <at::Half>()),
130131 batch_size, num_heads, seq_len, head_dim,
131- scale, stream
132+ scale, is_causal, stream
132133 );
133134 }
134-
135+
135136 return output;
136137}
137138
@@ -140,30 +141,31 @@ torch::Tensor tiled_attention(
140141 const torch::Tensor& q,
141142 const torch::Tensor& k,
142143 const torch::Tensor& v,
143- float scale = 0 .0f
144+ float scale = 0 .0f ,
145+ bool is_causal = false
144146) {
145147 validate_attention_inputs (q, k, v);
146-
148+
147149 int batch_size = q.size (0 );
148150 int num_heads = q.size (1 );
149151 int seq_len = q.size (2 );
150152 int head_dim = q.size (3 );
151-
153+
152154 if (scale == 0 .0f ) {
153155 scale = 1 .0f / sqrtf (static_cast <float >(head_dim));
154156 }
155-
157+
156158 auto output = torch::empty_like (q);
157159 cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
158-
160+
159161 if (q.scalar_type () == torch::kFloat32 ) {
160162 tiled_attention_fp32 (
161163 q.data_ptr <float >(),
162164 k.data_ptr <float >(),
163165 v.data_ptr <float >(),
164166 output.data_ptr <float >(),
165167 batch_size, num_heads, seq_len, head_dim,
166- scale, stream
168+ scale, is_causal, stream
167169 );
168170 } else {
169171 tiled_attention_fp16 (
@@ -172,10 +174,10 @@ torch::Tensor tiled_attention(
172174 reinterpret_cast <const half*>(v.data_ptr <at::Half>()),
173175 reinterpret_cast <half*>(output.data_ptr <at::Half>()),
174176 batch_size, num_heads, seq_len, head_dim,
175- scale, stream
177+ scale, is_causal, stream
176178 );
177179 }
178-
180+
179181 return output;
180182}
181183
@@ -331,33 +333,35 @@ torch::Tensor tensor_core_gemm_int8_wrapper(
331333
332334PYBIND11_MODULE (cuda_llm_ops, m) {
333335 m.doc () = " CUDA LLM Kernel Optimization - High-performance attention and GEMM kernels" ;
334-
336+
335337 // Attention functions
336338 m.def (" naive_attention" , &naive_attention,
337- py::arg (" q" ), py::arg (" k" ), py::arg (" v" ), py::arg (" scale" ) = 0 .0f ,
339+ py::arg (" q" ), py::arg (" k" ), py::arg (" v" ),
340+ py::arg (" scale" ) = 0 .0f , py::arg (" is_causal" ) = false ,
338341 " Naive attention implementation (baseline)" );
339-
342+
340343 m.def (" tiled_attention" , &tiled_attention,
341- py::arg (" q" ), py::arg (" k" ), py::arg (" v" ), py::arg (" scale" ) = 0 .0f ,
344+ py::arg (" q" ), py::arg (" k" ), py::arg (" v" ),
345+ py::arg (" scale" ) = 0 .0f , py::arg (" is_causal" ) = false ,
342346 " Tiled attention with shared memory optimization" );
343-
347+
344348 m.def (" flash_attention" , &flash_attention,
345- py::arg (" q" ), py::arg (" k" ), py::arg (" v" ),
349+ py::arg (" q" ), py::arg (" k" ), py::arg (" v" ),
346350 py::arg (" scale" ) = 0 .0f , py::arg (" is_causal" ) = false ,
347351 " FlashAttention with online softmax" );
348-
352+
349353 // GEMM functions
350354 m.def (" gemm" , &gemm,
351355 py::arg (" a" ), py::arg (" b" ),
352356 py::arg (" alpha" ) = 1 .0f , py::arg (" beta" ) = 0 .0f ,
353357 py::arg (" trans_a" ) = false , py::arg (" trans_b" ) = false ,
354358 " High-performance GEMM with register tiling" );
355-
359+
356360 m.def (" tensor_core_gemm" , &tensor_core_gemm,
357361 py::arg (" a" ), py::arg (" b" ),
358362 py::arg (" alpha" ) = 1 .0f , py::arg (" beta" ) = 0 .0f ,
359363 " Tensor Core GEMM (FP16 input, FP32 output)" );
360-
364+
361365 m.def (" tensor_core_gemm_int8" , &tensor_core_gemm_int8_wrapper,
362366 py::arg (" a" ), py::arg (" b" ),
363367 " Tensor Core GEMM (INT8 input, INT32 output, requires Turing+ SM>=7.2)" );
0 commit comments