Skip to content

Commit 166a496

Browse files
authored
Merge branch 'main' into fix-unfused-padding-causal-sdpa
2 parents 82f0e0e + 21ba49c commit 166a496

5 files changed

Lines changed: 1423 additions & 617 deletions

File tree

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
/*************************************************************************
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
*
4+
* See LICENSE for license information.
5+
************************************************************************/
6+
7+
#ifndef TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_
8+
#define TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_
9+
10+
#include <cuda_pipeline.h>
11+
12+
#include <type_traits>
13+
14+
#include "../utils.cuh"
15+
#include "utils.h"
16+
17+
namespace transformer_engine {
18+
namespace fused_router {
19+
20+
// ============================================================================
21+
// Persistent kernel grid size computation
22+
// ============================================================================
23+
24+
// Compute a persistent grid size: min(total_blocks_needed, SMs * max_blocks_per_SM).
25+
// `kernel_func` is a pointer to the __global__ function.
26+
// `block_size` is kThreadsPerBlock.
27+
// `shmem_bytes` is the dynamic shared memory per block.
28+
// `total_blocks` is ceil(num_tokens / tokens_per_block).
29+
template <typename KernelFunc>
inline size_t compute_persistent_grid(KernelFunc kernel_func, int block_size, size_t shmem_bytes,
                                      size_t total_blocks) {
  // Query how many blocks of this kernel can be active on one SM for the
  // requested block size and dynamic shared memory.
  int occupancy = 0;
  NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_func,
                                                                block_size, shmem_bytes));

  if (occupancy > 0) {
    int device_id;
    NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
    int sm_count;
    NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id));

    // Cap the grid at the number of blocks the whole device can hold at once.
    const size_t resident_cap = static_cast<size_t>(sm_count) * occupancy;
    if (resident_cap <= total_blocks) {
      return resident_cap;
    }
  }

  // Either the occupancy query reported nothing usable, or the work is smaller
  // than the device capacity: launch exactly the requested number of blocks.
  return total_blocks;
}
46+
47+
// ============================================================================
48+
// Occupancy-aware double-buffer decision
49+
// ============================================================================
50+
51+
// Decide whether to use 1 or 2 buffers based on shmem budget.
52+
// `single_buf_shmem` is the per-buffer shmem for the async-loaded data.
53+
// `other_shmem_bytes` is shmem for everything else (scratch, work buffers).
54+
// Returns 2 only when at least kMinBlocksPerSM double-buffered blocks fit in one
// SM's shared memory; otherwise returns 1.
55+
inline int choose_num_buffers(size_t single_buf_shmem, size_t other_shmem_bytes) {
  // Minimum number of co-resident blocks per SM required to enable double buffering.
  constexpr int kMinBlocksPerSM = 4;

  int device_id;
  NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
  int smem_per_sm;
  NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&smem_per_sm,
                                         cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id));

  // Per-block shared memory footprint when two async buffers are allocated.
  const size_t double_footprint = other_shmem_bytes + 2 * single_buf_shmem;
  if (double_footprint == 0) {
    return 1;
  }

  // Single buffering is the fallback whether or not it reaches the minimum
  // itself, so only the double-buffered block count influences the result.
  const int resident_double = static_cast<int>(smem_per_sm / double_footprint);
  return (resident_double >= kMinBlocksPerSM) ? 2 : 1;
}
76+
77+
// ============================================================================
78+
// Vectorized global store/fill helpers (using Vec<> from utils.cuh)
79+
// ============================================================================
80+
81+
template <typename T>
struct VecTraits {
  // Number of T elements covered by one 16-byte vector access. Types wider
  // than 16 bytes are processed one element at a time.
  static constexpr int kVecSize = (sizeof(T) > 16) ? 1 : (16 / sizeof(T));
};
85+
86+
// Vectorized store: write `count` elements from shmem/registers to global memory.
87+
template <typename T>
__device__ inline void vec_store_global(T *__restrict__ dst, const T *__restrict__ src, int count,
                                        int lane_id) {
  constexpr int kVecSize = VecTraits<T>::kVecSize;
  constexpr size_t kVecBytes = sizeof(T) * kVecSize;
  using VecType = typename BytesToType<kVecBytes>::Type;

  // Only the destination is written with wide stores, so only its alignment matters;
  // the source is always read element by element.
  const int num_vecs = count / kVecSize;
  const bool dst_is_aligned = (reinterpret_cast<uint64_t>(dst) % kVecBytes) == 0;

  // Index of the first element left for the scalar loop below.
  int scalar_begin = 0;

  if (dst_is_aligned && num_vecs > 0) {
    VecType *dst_vec = reinterpret_cast<VecType *>(dst);
    for (int v = lane_id; v < num_vecs; v += kThreadsPerWarp) {
      // Gather kVecSize elements into one vector, then emit a single wide store.
      VecType packed;
      T *lanes = reinterpret_cast<T *>(&packed);
      const T *chunk = src + v * kVecSize;
#pragma unroll
      for (int e = 0; e < kVecSize; ++e) {
        lanes[e] = chunk[e];
      }
      dst_vec[v] = packed;
    }
    scalar_begin = num_vecs * kVecSize;
  }

  // Scalar path: the tail after the vectorized part, or the whole range when
  // the vectorized path was not taken.
  for (int i = scalar_begin + lane_id; i < count; i += kThreadsPerWarp) {
    dst[i] = src[i];
  }
}
116+
117+
// Vectorized fill: write `val` to `count` elements of global memory.
118+
template <typename T>
__device__ inline void vec_fill_global(T *__restrict__ dst, T val, int count, int lane_id) {
  constexpr int kVecSize = VecTraits<T>::kVecSize;
  constexpr size_t kVecBytes = sizeof(T) * kVecSize;
  using VecType = typename BytesToType<kVecBytes>::Type;

  const int num_vecs = count / kVecSize;
  const bool dst_is_aligned = (reinterpret_cast<uint64_t>(dst) % kVecBytes) == 0;

  // Index of the first element left for the scalar loop below.
  int scalar_begin = 0;

  if (dst_is_aligned && num_vecs > 0) {
    // Build one vector with `val` replicated in every slot, then reuse it for all wide stores.
    VecType splat;
    T *lanes = reinterpret_cast<T *>(&splat);
#pragma unroll
    for (int e = 0; e < kVecSize; ++e) {
      lanes[e] = val;
    }

    VecType *dst_vec = reinterpret_cast<VecType *>(dst);
    for (int v = lane_id; v < num_vecs; v += kThreadsPerWarp) {
      dst_vec[v] = splat;
    }
    scalar_begin = num_vecs * kVecSize;
  }

  // Scalar path: the tail after the vectorized part, or the whole range when
  // the vectorized path was not taken.
  for (int i = scalar_begin + lane_id; i < count; i += kThreadsPerWarp) {
    dst[i] = val;
  }
}
146+
147+
// ============================================================================
148+
// cp.async wrappers — hardware async copy on sm_80+; on older archs the copy
// falls back to a synchronous 16-byte store and commit/wait compile to nothing.
149+
// Always defined so callers don't need #if guards.
150+
// ============================================================================
151+
152+
__device__ __forceinline__ void cp_async_16B(void *__restrict__ dst, const void *__restrict__ src) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  // No cp.async available: perform the 16-byte copy synchronously as a single int4.
  // Callers must not assume the copy is still in flight after this returns.
  const int4 *from = static_cast<const int4 *>(src);
  int4 *to = static_cast<int4 *>(dst);
  *to = *from;
#else
  // sm_80+: issue an asynchronous 16-byte copy through the pipeline primitive.
  __pipeline_memcpy_async(dst, src, 16);
#endif
}
160+
161+
// Commit the async copies issued so far as one batch (sm_80+ only).
__device__ __forceinline__ void cp_async_commit() {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  // Nothing to commit: the fallback copy in cp_async_16B is synchronous.
#else
  __pipeline_commit();
#endif
}
166+
167+
// Block until every committed async copy batch has completed (sm_80+ only).
__device__ __forceinline__ void cp_async_wait_all() {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  // Nothing to wait for: the fallback copy in cp_async_16B is synchronous.
#else
  // Passing 0 leaves no committed batch outstanding.
  __pipeline_wait_prior(0);
#endif
}
172+
173+
// ============================================================================
174+
// RawAsyncLoader<T> — double-buffered loader storing data in original type
175+
//
176+
// Enables cp.async for ALL data types (bf16, fp16, fp32) since no type
177+
// conversion is needed during the copy. The kernel reads from shmem and
178+
// casts to CompType during compute.
179+
// ============================================================================
180+
181+
template <typename T>
class RawAsyncLoader {
 public:
  // Shmem size calculation (usable from both host and device).
  // Returns the bytes needed for `num_buffers` buffers, each holding `count`
  // elements of T for each of `num_warps` warps.
  static __host__ __device__ inline size_t shmem_bytes(int count, int num_warps, int num_buffers) {
    return static_cast<size_t>(num_buffers) * count * num_warps * sizeof(T);
  }

  // Device-side construction.
  //
  // `buf_base`    start of the region sized by shmem_bytes().
  // `warp_id`     index of the calling warp; selects its `count`-element slice.
  // `count`       elements reserved per warp in each buffer.
  // `num_warps`   number of warps sharing the region.
  // `num_buffers` 2 enables double buffering; any other value aliases both
  //               buffer slots to the same slice.
  __device__ RawAsyncLoader(T *buf_base, int warp_id, int count, int num_warps, int num_buffers)
      : phase_(0), double_buf_(num_buffers == 2) {
    int per_buffer = count * num_warps;
    buf_[0] = buf_base + warp_id * count;
    // The second buffer starts after all warps' slices of the first one.
    buf_[1] = (num_buffers == 2) ? buf_base + per_buffer + warp_id * count : buf_[0];
  }

  // Buffer currently being consumed.
  __device__ __forceinline__ T *current_buf() { return buf_[phase_]; }
  // Buffer targeted by prefetches (same as current_buf() when single-buffered).
  __device__ __forceinline__ T *next_buf() { return buf_[phase_ ^ 1]; }
  // Swap current/next; does nothing when single-buffered.
  __device__ __forceinline__ void flip() {
    if (double_buf_) phase_ ^= 1;
  }

  // Async load into the NEXT buffer (for prefetching).
  // Slices are spaced by the constructor's `count`, so loading more elements
  // than that would run into the neighbouring warp's slice.
  __device__ void start_load(const T *__restrict__ src, int count, int lane_id) {
    raw_load(src, next_buf(), count, lane_id);
  }

  // Load into the CURRENT buffer (for the first load before the main loop).
  __device__ void load_current(const T *__restrict__ src, int count, int lane_id) {
    raw_load(src, current_buf(), count, lane_id);
  }

  // Wait for pending async loads to complete.
  // The __syncwarp() also orders the scalar-path writes of other lanes before
  // any lane reads the buffer.
  __device__ __forceinline__ void wait() {
    cp_async_wait_all();
    __syncwarp();
  }

 private:
  T *buf_[2];        // Per-warp slices of buffer 0 and buffer 1.
  int phase_;        // Index of the current buffer (0 or 1).
  bool double_buf_;  // True when two distinct buffers were allocated.

  // Raw copy: global → shmem, no type conversion.
  // Uses 16-byte vectorised copies (cp.async on sm_80+, int4 on older archs)
  // when both pointers are 16-byte aligned, with a scalar tail / fallback.
  __device__ void raw_load(const T *__restrict__ src, T *__restrict__ dst, int count, int lane_id) {
    constexpr int kBytesPerCopy = 16;
    // Each 16-byte copy must cover a whole number of elements. Otherwise
    // kEltsPerCopy is 0 (division by zero below) or the copies overrun
    // element boundaries.
    static_assert(sizeof(T) <= kBytesPerCopy && kBytesPerCopy % sizeof(T) == 0,
                  "RawAsyncLoader requires sizeof(T) to evenly divide 16 bytes");
    constexpr int kEltsPerCopy = kBytesPerCopy / sizeof(T);

    bool src_aligned = (reinterpret_cast<uint64_t>(src) % kBytesPerCopy == 0);
    bool dst_aligned = (reinterpret_cast<uint64_t>(dst) % kBytesPerCopy == 0);
    int aligned_count = (count / kEltsPerCopy) * kEltsPerCopy;

    if (src_aligned && dst_aligned && aligned_count > 0) {
      int vec_count = aligned_count / kEltsPerCopy;
      // Lanes stride over the 16-byte chunks.
      for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) {
        cp_async_16B(dst + vi * kEltsPerCopy, src + vi * kEltsPerCopy);
      }
      // Elements past the last full 16-byte chunk are copied synchronously.
      for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) {
        dst[i] = src[i];
      }
      cp_async_commit();
    } else {
      // Misaligned pointers or fewer elements than one chunk: plain scalar copy.
      for (int i = lane_id; i < count; i += kThreadsPerWarp) {
        dst[i] = src[i];
      }
    }
  }
};
251+
252+
} // namespace fused_router
253+
} // namespace transformer_engine
254+
255+
#endif // TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_

transformer_engine/common/fused_router/fused_moe_aux_loss.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
6363
const int warp_id = threadIdx.x / kThreadsPerWarp;
6464
const int lane_id = threadIdx.x % kThreadsPerWarp;
6565
if (warp_id == 0) {
66-
CompType block_sum = warp_reduce_on_shmem(shmem_block, static_cast<int>(blockDim.x),
67-
ReduceFuncType::SUM, lane_id);
66+
CompType block_sum = warp_reduce_on_shmem<CompType, ReduceFuncType::SUM>(
67+
shmem_block, static_cast<int>(blockDim.x), lane_id);
6868
if (lane_id == 0) {
6969
atomicAdd(&Coeff_buf[1], static_cast<float>(block_sum * coeff));
7070
}

0 commit comments

Comments
 (0)