import cuda_llm_ops

# List all available functions
dir(cuda_llm_ops)
# ['flash_attention', 'tiled_attention', 'naive_attention',
#  'gemm', 'tensor_core_gemm', 'tensor_core_gemm_int8', '__version__']
Attention Functions
flash_attention
FlashAttention with O(N) memory complexity, using the online softmax algorithm.
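For intuition, the snippet below is a minimal single-head PyTorch sketch of the online softmax idea: scores are processed one K/V block at a time, and a running max and running denominator are carried along so the full score matrix is never materialized. It is purely illustrative and is not how the CUDA kernel is implemented.

import torch

def online_softmax_attention(q, k, v, block_size=64, scale=None):
    # q: [seq_q, d], k/v: [seq_k, d]; single head, CPU-friendly illustration
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    seq_q, d = q.shape
    m = torch.full((seq_q, 1), float('-inf'))   # running row-wise max
    l = torch.zeros(seq_q, 1)                   # running softmax denominator
    acc = torch.zeros(seq_q, d)                 # unnormalized output accumulator
    for start in range(0, k.shape[0], block_size):
        kb, vb = k[start:start + block_size], v[start:start + block_size]
        s = (q @ kb.T) * scale                  # scores for this block only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)
        correction = torch.exp(m - m_new)       # rescale previous partial results
        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ vb
        m = m_new
    return acc / l                              # normalize once at the end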
Parameters

q (Tensor): Query tensor, shape [batch, heads, seq_len, head_dim]
k (Tensor): Key tensor, shape [batch, heads, seq_len, head_dim]
v (Tensor): Value tensor, shape [batch, heads, seq_len, head_dim]
scale (float, default 0.0): Attention scale factor. If 0.0, uses 1/√head_dim
is_causal (bool, default False): Enable causal mask for autoregressive models
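A basic call using the parameters above (the shapes are illustrative):

import torch
from cuda_llm_ops import flash_attention

q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# scale=0.0 (the default) falls back to 1/sqrt(head_dim)
output = flash_attention(q, k, v, is_causal=True)
print(output.shape)  # torch.Size([2, 8, 1024, 64])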
naive_attention
Returns
Same as flash_attention.
Warning
Memory Alert: This implementation stores the full N×N attention matrix. For long sequences (N > 1024), this may cause out-of-memory errors. Use flash_attention for production workloads.
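As a rough back-of-the-envelope check (assuming the scores are stored in FP16, i.e. 2 bytes per element; the illustrative helper below is not part of the library), the score matrix alone grows quadratically with sequence length:

def attention_matrix_bytes(batch, heads, seq_len, dtype_bytes=2):
    # Full [batch, heads, seq_len, seq_len] score matrix
    return batch * heads * seq_len * seq_len * dtype_bytes

print(attention_matrix_bytes(8, 16, 4096) / 2**30)  # 4.0 GiB for the scores alone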
Examples
import torch
from cuda_llm_ops import naive_attention

# Only recommended for short sequences or testing
q = torch.randn(2, 4, 64, 32, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

output = naive_attention(q, k, v)

# Verify correctness against PyTorch reference
reference = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(output, reference, rtol=1e-3, atol=1e-3)
GEMM Functions
gemm
High-performance general matrix multiplication with register tiling.
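A sketch of typical usage, assuming gemm takes two 2D CUDA tensors and returns their product (the exact dtype and shape requirements are not documented here, so this is an assumption):

import torch
from cuda_llm_ops import gemm

a = torch.randn(1024, 512, device='cuda', dtype=torch.float32)
b = torch.randn(512, 2048, device='cuda', dtype=torch.float32)
c = gemm(a, b)  # assumed call signature: gemm(a, b) -> a @ b

# Sanity-check against PyTorch's matmul
assert torch.allclose(c, a @ b, rtol=1e-3, atol=1e-3)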
Common Errors

# Example: Wrong dimensions
q = torch.randn(64, 32, device='cuda')  # 2D instead of 4D
flash_attention(q, k, v)
# RuntimeError: Q must be 4D tensor [batch, heads, seq_len, head_dim]

# Example: CPU tensor
q = torch.randn(2, 4, 64, 32)  # CPU tensor
flash_attention(q, k, v)
# RuntimeError: Q must be on CUDA device

# Example: Non-contiguous tensor
q = torch.randn(2, 4, 64, 32, device='cuda').transpose(1, 2)
flash_attention(q, k, v)
# RuntimeError: Q must be contiguous

# Example: Shape mismatch
q = torch.randn(2, 4, 64, 32, device='cuda')
v = torch.randn(2, 4, 128, 32, device='cuda')  # Different seq_len
flash_attention(q, k, v)
# RuntimeError: K and V must have same shape

# Example: Unsupported dtype
q = torch.zeros(2, 4, 64, 32, device='cuda', dtype=torch.int32)  # int32 tensor (torch.randn has no integer dtypes)
flash_attention(q, k, v)
# RuntimeError: Only float32 and float16 are supported
Error Handling Pattern
import torch
from cuda_llm_ops import flash_attention

def safe_flash_attention(q, k, v, **kwargs):
    try:
        return flash_attention(q, k, v, **kwargs)
    except RuntimeError as e:
        error_msg = str(e)
        if "must be on CUDA device" in error_msg:
            print("Error: Please move tensors to CUDA using .cuda()")
        elif "must be 4D tensor" in error_msg:
            print("Error: Input shapes must be [batch, heads, seq_len, head_dim]")
        elif "must have same shape" in error_msg:
            print("Error: Q, K, V tensors must have identical shapes")
        else:
            print(f"Error: {error_msg}")
        raise
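For example, with the wrapper above:

q = torch.randn(2, 4, 64, 32, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Valid inputs pass straight through to flash_attention
out = safe_flash_attention(q, k, v, is_causal=True)

# A CPU tensor prints the hint above and then re-raises
try:
    safe_flash_attention(q.cpu(), k, v)
except RuntimeError:
    pass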
Performance Tips
Memory Optimization
# Use FlashAttention for long sequences
seq_len = 1024

if seq_len >= 512:
    output = flash_attention(q, k, v)   # O(N) memory
else:
    output = naive_attention(q, k, v)   # May be faster for short sequences
Precision Selection
# FP16 for inference (recommended)
q_fp16 = q.half()
k_fp16 = k.half()
v_fp16 = v.half()
output = flash_attention(q_fp16, k_fp16, v_fp16)

# FP32 for training or when precision is critical
output = flash_attention(q.float(), k.float(), v.float())

# Tensor Core GEMM for FP16 inputs with FP32 accumulation
c = tensor_core_gemm(a.half(), b.half())  # Returns FP32
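A quick way to gauge the FP16/FP32 gap on your own inputs (q, k, v as in the examples above; the tolerance you need is model-dependent, and the asserts earlier on this page use rtol=atol=1e-3):

out_fp16 = flash_attention(q.half(), k.half(), v.half())
out_fp32 = flash_attention(q.float(), k.float(), v.float())

# Maximum elementwise difference between the two precisions
max_err = (out_fp16.float() - out_fp32).abs().max().item()
print(f"max abs difference: {max_err:.2e}")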
Optimal Dimensions
# For best Tensor Core performance, use multiples of 16
def round_up_to_16(x):
    return ((x + 15) // 16) * 16

M = round_up_to_16(1000)  # 1008
N = round_up_to_16(500)   # 512
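One way to apply this with tensor_core_gemm is to zero-pad odd-sized operands up to multiples of 16 and slice the result back afterwards. The sketch below assumes that approach (it reuses round_up_to_16 from above); whether padding actually beats the unpadded path depends on your shapes.

import torch
import torch.nn.functional as F
from cuda_llm_ops import tensor_core_gemm

def padded_tensor_core_gemm(a, b):
    # a: [M, K], b: [K, N] in FP16; zero padding leaves the top-left block of the product unchanged
    M, K = a.shape
    _, N = b.shape
    Mp, Kp, Np = round_up_to_16(M), round_up_to_16(K), round_up_to_16(N)
    a_pad = F.pad(a, (0, Kp - K, 0, Mp - M))   # pad last dim, then first dim
    b_pad = F.pad(b, (0, Np - N, 0, Kp - K))
    c_pad = tensor_core_gemm(a_pad, b_pad)     # FP32 result, shape [Mp, Np]
    return c_pad[:M, :N]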
Batch Processing
# Process multiple sequences together for better GPU utilization
batch_size = 8  # Adjust based on available memory
q = torch.randn(batch_size, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
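To complete the pattern (heads, seq_len, and head_dim are whatever your model uses):

k = torch.randn_like(q)
v = torch.randn_like(q)

# All batch_size * heads attention problems run in a single call
output = flash_attention(q, k, v)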