|
| 1 | +/************************************************************************* |
| 2 | + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | + * |
| 4 | + * See LICENSE for license information. |
| 5 | + ************************************************************************/ |
| 6 | + |
| 7 | +#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ |
| 8 | +#define TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ |
| 9 | + |
| 10 | +#include <cuda_pipeline.h> |
| 11 | + |
| 12 | +#include <type_traits> |
| 13 | + |
| 14 | +#include "../utils.cuh" |
| 15 | +#include "utils.h" |
| 16 | + |
| 17 | +namespace transformer_engine { |
| 18 | +namespace fused_router { |
| 19 | + |
| 20 | +// ============================================================================ |
| 21 | +// Persistent kernel grid size computation |
| 22 | +// ============================================================================ |
| 23 | + |
// Picks the grid size for a persistent-kernel launch.
//
//   kernel_func  - pointer to the __global__ function that will be launched.
//   block_size   - threads per block used for the launch.
//   shmem_bytes  - dynamic shared memory requested per block, in bytes.
//   total_blocks - number of blocks needed to cover all of the work.
//
// Returns min(total_blocks, num_SMs * max_active_blocks_per_SM) for the current device.
// If the occupancy query reports that no block can be resident, `total_blocks` is
// returned unchanged.
template <typename KernelFunc>
inline size_t compute_persistent_grid(KernelFunc kernel_func, int block_size, size_t shmem_bytes,
                                      size_t total_blocks) {
  int max_active_per_sm = 0;
  NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_per_sm, kernel_func,
                                                                block_size, shmem_bytes));

  if (max_active_per_sm > 0) {
    int current_device = 0;
    NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
    int sm_count = 0;
    NVTE_CHECK_CUDA(
        cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, current_device));

    // Upper bound on the number of blocks that can be resident on the device at once.
    const size_t resident_cap = static_cast<size_t>(sm_count) * max_active_per_sm;
    if (resident_cap <= total_blocks) {
      return resident_cap;
    }
  }
  return total_blocks;
}
| 46 | + |
| 47 | +// ============================================================================ |
| 48 | +// Occupancy-aware double-buffer decision |
| 49 | +// ============================================================================ |
| 50 | + |
// Chooses between single buffering (1) and double buffering (2) from the shared-memory
// budget of the current device.
//
//   single_buf_shmem  - shared memory, in bytes, of one buffer of async-loaded data.
//   other_shmem_bytes - shared memory, in bytes, used by everything else in the block.
//
// Returns 2 only when at least kMinBlocksPerSM blocks, each holding two buffers plus the
// other shared memory, fit in the per-SM shared memory reported by the device. In every
// other case (including when even a single-buffered block does not reach kMinBlocksPerSM)
// it returns 1, since the single-buffer footprint is never larger than the double-buffer one.
inline int choose_num_buffers(size_t single_buf_shmem, size_t other_shmem_bytes) {
  constexpr int kMinBlocksPerSM = 4;

  int current_device = 0;
  NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
  int smem_per_sm = 0;
  NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor,
                                         current_device));

  // Per-block footprint when two data buffers are allocated.
  const size_t double_footprint = 2 * single_buf_shmem + other_shmem_bytes;
  if (double_footprint == 0) {
    // Nothing to divide by; treated as "double buffering does not qualify".
    return 1;
  }

  const int double_fit = static_cast<int>(smem_per_sm / double_footprint);
  return (double_fit >= kMinBlocksPerSM) ? 2 : 1;
}
| 76 | + |
| 77 | +// ============================================================================ |
| 78 | +// Vectorized global store/fill helpers (vector type obtained via BytesToType<>) |
| 79 | +// ============================================================================ |
| 80 | + |
// Compile-time vector width for element type T.
//
// kVecSize is the number of T elements that make up exactly one 16-byte vector.
// It is 16 / sizeof(T) when sizeof(T) evenly divides 16 (1, 2, 4, 8 or 16 bytes).
// For any other size (larger than 16 bytes, or not a divisor of 16) it is 1, so that
// sizeof(T) * kVecSize is never a partial/odd vector width such as 15 bytes.
template <typename T>
struct VecTraits {
  static constexpr int kVecSize =
      (sizeof(T) <= 16 && 16 % sizeof(T) == 0) ? static_cast<int>(16 / sizeof(T)) : 1;
};
| 85 | + |
// Warp-cooperative copy of `count` elements from `src` to `dst`.
//
// Each lane handles the indices lane_id, lane_id + kThreadsPerWarp, ... . When `dst` is
// aligned to a full vector (sizeof(T) * kVecSize bytes) and at least one full vector fits
// in `count`, whole vectors are assembled element by element from `src` and written to
// `dst` with a single vector-wide store; the leftover elements are written one at a time.
// Otherwise every element is written one at a time. Only the alignment of `dst` is
// checked, because `src` is always read element-wise.
template <typename T>
__device__ inline void vec_store_global(T *__restrict__ dst, const T *__restrict__ src, int count,
                                        int lane_id) {
  constexpr int kVecSize = VecTraits<T>::kVecSize;
  using VecType = typename BytesToType<sizeof(T) * kVecSize>::Type;

  const int num_vecs = count / kVecSize;
  const bool dst_vec_aligned = (reinterpret_cast<uint64_t>(dst) % (sizeof(T) * kVecSize)) == 0;

  // Scalar path: destination not vector-aligned, or fewer elements than one vector.
  if (!dst_vec_aligned || num_vecs <= 0) {
    for (int idx = lane_id; idx < count; idx += kThreadsPerWarp) {
      dst[idx] = src[idx];
    }
    return;
  }

  // Vector path: pack kVecSize elements, then issue one wide store per vector.
  VecType *dst_vecs = reinterpret_cast<VecType *>(dst);
  for (int v = lane_id; v < num_vecs; v += kThreadsPerWarp) {
    VecType packed;
    T *packed_elts = reinterpret_cast<T *>(&packed);
    const T *chunk = src + v * kVecSize;
#pragma unroll
    for (int e = 0; e < kVecSize; e++) {
      packed_elts[e] = chunk[e];
    }
    dst_vecs[v] = packed;
  }

  // Tail: elements past the last full vector.
  const int tail_begin = num_vecs * kVecSize;
  for (int idx = tail_begin + lane_id; idx < count; idx += kThreadsPerWarp) {
    dst[idx] = src[idx];
  }
}
| 116 | + |
// Warp-cooperative fill: writes `val` into dst[0 .. count).
//
// Each lane handles the indices lane_id, lane_id + kThreadsPerWarp, ... . When `dst` is
// aligned to a full vector (sizeof(T) * kVecSize bytes) and at least one full vector fits
// in `count`, a vector holding kVecSize copies of `val` is built once and stored
// vector-wide; the leftover elements are written one at a time. Otherwise every element
// is written one at a time.
template <typename T>
__device__ inline void vec_fill_global(T *__restrict__ dst, T val, int count, int lane_id) {
  constexpr int kVecSize = VecTraits<T>::kVecSize;
  using VecType = typename BytesToType<sizeof(T) * kVecSize>::Type;

  const int num_vecs = count / kVecSize;
  const bool dst_vec_aligned = (reinterpret_cast<uint64_t>(dst) % (sizeof(T) * kVecSize)) == 0;

  // Scalar path: destination not vector-aligned, or fewer elements than one vector.
  if (!dst_vec_aligned || num_vecs <= 0) {
    for (int idx = lane_id; idx < count; idx += kThreadsPerWarp) {
      dst[idx] = val;
    }
    return;
  }

  // Build the broadcast vector once; it is reused for every vector store.
  VecType splat;
  T *splat_elts = reinterpret_cast<T *>(&splat);
#pragma unroll
  for (int e = 0; e < kVecSize; e++) {
    splat_elts[e] = val;
  }

  VecType *dst_vecs = reinterpret_cast<VecType *>(dst);
  for (int v = lane_id; v < num_vecs; v += kThreadsPerWarp) {
    dst_vecs[v] = splat;
  }

  // Tail: elements past the last full vector.
  const int tail_begin = num_vecs * kVecSize;
  for (int idx = tail_begin + lane_id; idx < count; idx += kThreadsPerWarp) {
    dst[idx] = val;
  }
}
| 146 | + |
| 147 | +// ============================================================================ |
| 148 | +// cp.async wrappers — hardware async copy on sm_80+; on older archs the copy is a synchronous 16-byte copy and commit/wait are no-ops. |
| 149 | +// Always defined so callers don't need #if guards. |
| 150 | +// ============================================================================ |
| 151 | + |
// Copies exactly 16 bytes from `src` to `dst`.
//
// When compiled for __CUDA_ARCH__ >= 800 the copy is issued with
// __pipeline_memcpy_async and must later be committed and waited on.
// Otherwise the copy is a plain int4 load followed by an int4 store, which has
// already completed when this function returns.
__device__ __forceinline__ void cp_async_16B(void *__restrict__ dst, const void *__restrict__ src) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  // Scalar fallback — callers must not rely on this being async.
  const int4 chunk = *static_cast<const int4 *>(src);
  *static_cast<int4 *>(dst) = chunk;
#else
  __pipeline_memcpy_async(dst, src, 16);
#endif
}
| 160 | + |
// Commits the async copies issued so far by calling __pipeline_commit().
// Compiles to nothing unless __CUDA_ARCH__ >= 800, where cp_async_16B is asynchronous.
__device__ __forceinline__ void cp_async_commit() {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  // Nothing to commit: the fallback copy in cp_async_16B is synchronous.
#else
  __pipeline_commit();
#endif
}
| 166 | + |
// Waits for all committed async copies by calling __pipeline_wait_prior(0).
// Compiles to nothing unless __CUDA_ARCH__ >= 800, where cp_async_16B is asynchronous.
__device__ __forceinline__ void cp_async_wait_all() {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  // Nothing to wait for: the fallback copy in cp_async_16B is synchronous.
#else
  __pipeline_wait_prior(0);
#endif
}
| 172 | + |
| 173 | +// ============================================================================ |
| 174 | +// RawAsyncLoader<T> — double-buffered loader storing data in original type |
| 175 | +// |
| 176 | +// Enables cp.async for ALL data types (bf16, fp16, fp32) since no type |
| 177 | +// conversion is needed during the copy. The kernel reads from shmem and |
| 178 | +// casts to CompType during compute. |
| 179 | +// ============================================================================ |
| 180 | + |
// Per-warp loader that copies `count` elements of type T from a source pointer into a
// warp-private slice of a caller-provided buffer, without any type conversion.
//
// Buffer layout (established by the constructor): buffer 0 occupies elements
// [0, count * num_warps) of `buf_base`, with warp w owning the slice starting at
// w * count. When num_buffers == 2, buffer 1 follows immediately with the same layout.
// With a single buffer, both internal slots alias the same slice, so next_buf() and
// current_buf() return the same pointer and flip() does nothing.
template <typename T>
class RawAsyncLoader {
  // raw_load() moves data in 16-byte units of 16 / sizeof(T) elements. That is only
  // well-defined when sizeof(T) evenly divides 16: larger types would make the
  // elements-per-copy constant zero (division by zero), and non-divisors would make
  // consecutive 16-byte copies overlap.
  static_assert(sizeof(T) <= 16 && 16 % sizeof(T) == 0,
                "RawAsyncLoader<T>: sizeof(T) must evenly divide the 16-byte copy size");

 public:
  // Total bytes needed for `num_buffers` buffers of `count` elements for each of
  // `num_warps` warps (usable from both host and device).
  static __host__ __device__ inline size_t shmem_bytes(int count, int num_warps, int num_buffers) {
    return static_cast<size_t>(num_buffers) * count * num_warps * sizeof(T);
  }

  // Device-side construction.
  //   buf_base    - start of the storage sized by shmem_bytes().
  //   warp_id     - index of the calling warp; selects its slice in each buffer.
  //   count       - elements per warp per buffer.
  //   num_warps   - number of warps sharing `buf_base`.
  //   num_buffers - 2 enables double buffering; any other value behaves as a single buffer.
  __device__ RawAsyncLoader(T *buf_base, int warp_id, int count, int num_warps, int num_buffers)
      : phase_(0), double_buf_(num_buffers == 2) {
    int per_buffer = count * num_warps;
    buf_[0] = buf_base + warp_id * count;
    // Single-buffer mode: alias slot 1 to slot 0 so next_buf() is always valid.
    buf_[1] = (num_buffers == 2) ? buf_base + per_buffer + warp_id * count : buf_[0];
  }

  // Slice currently selected by the phase.
  __device__ __forceinline__ T *current_buf() { return buf_[phase_]; }
  // The other slice (same as current_buf() in single-buffer mode).
  __device__ __forceinline__ T *next_buf() { return buf_[phase_ ^ 1]; }
  // Swaps the roles of the current and next buffers; no-op in single-buffer mode.
  __device__ __forceinline__ void flip() {
    if (double_buf_) phase_ ^= 1;
  }

  // Load into the NEXT buffer (for prefetching). In single-buffer mode this writes
  // into the same slice that current_buf() returns.
  __device__ void start_load(const T *__restrict__ src, int count, int lane_id) {
    raw_load(src, next_buf(), count, lane_id);
  }

  // Load into the CURRENT buffer (for the first load before the main loop).
  __device__ void load_current(const T *__restrict__ src, int count, int lane_id) {
    raw_load(src, current_buf(), count, lane_id);
  }

  // Waits for this thread's pending async copies, then synchronizes the warp so that
  // lanes can read elements that were copied by other lanes.
  __device__ __forceinline__ void wait() {
    cp_async_wait_all();
    __syncwarp();
  }

 private:
  T *buf_[2];        // Per-warp slices of buffer 0 and buffer 1 (aliased when single-buffered).
  int phase_;        // Index into buf_ of the current buffer (0 or 1).
  bool double_buf_;  // True when two distinct buffers are in use.

  // Raw copy of `count` elements from `src` to `dst`, no type conversion.
  // Lane `lane_id` handles indices lane_id, lane_id + kThreadsPerWarp, ... .
  // Uses 16-byte copies via cp_async_16B (cp.async on sm_80+, int4 on older archs)
  // when both pointers are 16-byte aligned and at least one full 16-byte unit fits,
  // with element-wise copies for the tail; otherwise copies everything element-wise.
  __device__ void raw_load(const T *__restrict__ src, T *__restrict__ dst, int count, int lane_id) {
    constexpr int kBytesPerCopy = 16;
    constexpr int kEltsPerCopy = kBytesPerCopy / sizeof(T);

    bool src_aligned = (reinterpret_cast<uint64_t>(src) % kBytesPerCopy == 0);
    bool dst_aligned = (reinterpret_cast<uint64_t>(dst) % kBytesPerCopy == 0);
    // Number of elements covered by whole 16-byte units.
    int aligned_count = (count / kEltsPerCopy) * kEltsPerCopy;

    if (src_aligned && dst_aligned && aligned_count > 0) {
      int vec_count = aligned_count / kEltsPerCopy;
      for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) {
        cp_async_16B(dst + vi * kEltsPerCopy, src + vi * kEltsPerCopy);
      }
      // Tail elements that do not fill a 16-byte unit are copied synchronously.
      for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) {
        dst[i] = src[i];
      }
      // Commit the async copies issued above so that wait() can wait on them.
      cp_async_commit();
    } else {
      // Unaligned pointers or fewer elements than one 16-byte unit: plain copy.
      for (int i = lane_id; i < count; i += kThreadsPerWarp) {
        dst[i] = src[i];
      }
    }
  }
};
| 251 | + |
| 252 | +} // namespace fused_router |
| 253 | +} // namespace transformer_engine |
| 254 | + |
| 255 | +#endif // TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ |
0 commit comments