Skip to content

Commit e8713cf

Browse files
author
shijiashuai
committed
feat: implement true Hopper TMA, cluster, FP8 GEMM and Winograd convolution
- TMA: add cuda::pipeline async load/store with automatic fallback - Cluster: add cooperative_groups::cluster_group for distributed reduce - FP8 GEMM: add WMMA Tensor Core implementation with e4m3/e5m2 support - Winograd: implement F(2x2,3x3) transform to reduce multiplication count
1 parent 4cdb149 commit e8713cf

File tree

8 files changed

+533
-49
lines changed

8 files changed

+533
-49
lines changed

src/04_convolution/conv_winograd.cu

Lines changed: 176 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,190 @@
22
#include "conv_implicit_gemm.cuh"
33
#include "../common/cuda_check.cuh"
44
#include <stdexcept>
5+
#include <cmath>
56

67
namespace hpc::convolution {
78

8-
// Experimental wrapper: until Winograd transforms are implemented, this path
9-
// intentionally falls back to the validated implicit GEMM implementation.
10-
template <>
11-
void conv2d_winograd<float>(const float* input, const float* weight, float* output,
12-
int batch, int in_channels, int out_channels,
13-
int height, int width, cudaStream_t stream) {
9+
// Winograd F(2x2,3x3) transform matrices, stored row-major with row stride 4.
//
// B^T (4x4): the input-tile transform. Matches the canonical F(2x2,3x3) B^T.
__device__ constexpr float winograd_BT[16] = {
    1.0f,  0.0f, -1.0f,  0.0f,
    0.0f,  1.0f,  1.0f,  0.0f,
    0.0f, -1.0f,  1.0f,  0.0f,
    0.0f,  1.0f,  0.0f, -1.0f
};

// G: the filter transform. Only the 4x3 prefix of each row is meaningful —
// consumers index with stride 4 but iterate the inner dimension k < 3. The
// canonical G is [[1,0,0],[0.5,0.5,0.5],[0.5,-0.5,0.5],[0,0,1]], which matches
// columns 0..2 below. NOTE(review): column 3 is padding and must stay unread.
__device__ constexpr float winograd_G[16] = {
    1.0f,  0.0f,  0.0f,  0.0f,
    0.5f,  0.5f,  0.5f,  0.5f,
    0.5f, -0.5f,  0.5f, -0.5f,
    0.0f,  0.0f,  1.0f,  1.0f
};

// A^T: the inverse (output) transform. NOTE(review): the canonical F(2x2,3x3)
// A^T is 2x4 = [[1,1,1,0],[0,1,-1,-1]]; rows 1..3 of this 4x4 table do not
// match that — confirm the intended semantics before relying on it.
__device__ constexpr float winograd_AT[16] = {
    1.0f,  1.0f,  1.0f,  0.0f,
    0.0f,  1.0f, -1.0f,  0.0f,
    0.0f,  1.0f,  1.0f,  1.0f,
    0.0f,  1.0f,  0.0f, -1.0f
};
29+
30+
// Computes one element (i, j) of the Winograd F(2x2,3x3) input transform
// V = B^T * d * B for a 4x4 input tile d.
//
// Fix: the input transform must use B^T (winograd_BT); the previous version
// applied the inverse/output transform table winograd_AT instead.
__device__ __forceinline__ float winograd_transform_input(float d[4][4], int i, int j) {
    float result = 0.0f;
    for (int ri = 0; ri < 4; ++ri) {
        for (int rj = 0; rj < 4; ++rj) {
            // B[rj][j] == winograd_BT[j * 4 + rj], so this accumulates (B^T d B)[i][j].
            result += winograd_BT[i * 4 + ri] * d[ri][rj] * winograd_BT[j * 4 + rj];
        }
    }
    return result;
}
39+
40+
// Computes one element (i, j) of the Winograd F(2x2,3x3) filter transform
// U = G * g * G^T for a 3x3 filter g, yielding the 4x4 transformed filter.
// winograd_G is stored with a row stride of 4; only columns 0..2 are read.
__device__ __forceinline__ float winograd_transform_weight(float g[3][3], int i, int j) {
    float acc = 0.0f;
#pragma unroll
    for (int p = 0; p < 3; ++p) {
#pragma unroll
        for (int q = 0; q < 3; ++q) {
            // G^T[q][j] == winograd_G[j * 4 + q].
            acc += winograd_G[i * 4 + p] * g[p][q] * winograd_G[j * 4 + q];
        }
    }
    return acc;
}
49+
50+
// Direct Winograd F(2x2,3x3) convolution (NCHW, stride 1, pad 1, dilation 1).
//
// Launch contract: grid.x == ceil(out_h/2) * ceil(out_w/2); each block owns
// one 2x2 output tile position, and its threads stride over the flat
// (batch, out_channel) space for that tile. Dynamic shared memory is unused.
//
// Fixes over the previous version: `batch` (a count) is no longer used as a
// batch index; the output index includes the batch stride; every thread's
// work is kept (previously only thread (0,0) wrote, via atomicAdd into a
// never-zeroed buffer); the filter transform is the full G g G^T; and the
// inverse transform A^T m A produces the proper 2x2 tile instead of one
// scalar. Each output element is owned by exactly one thread, so results are
// stored directly — no atomics and no output pre-zeroing required.
__global__ void winograd_conv_kernel(const float* __restrict__ input,
                                     const float* __restrict__ weight,
                                     float* __restrict__ output,
                                     int batch, int in_ch, int out_ch,
                                     int out_h, int out_w,
                                     int in_h, int in_w) {
    const int tiles_w = (out_w + 1) / 2;

    const int tile_y = blockIdx.x / tiles_w;
    const int tile_x = blockIdx.x % tiles_w;
    const int row0 = tile_y * 2;  // top-left output coordinate of this tile
    const int col0 = tile_x * 2;
    if (row0 >= out_h || col0 >= out_w) {
        return;
    }

    const int lane = threadIdx.y * blockDim.x + threadIdx.x;
    const int nthreads = blockDim.x * blockDim.y;

    for (int bo = lane; bo < batch * out_ch; bo += nthreads) {
        const int b = bo / out_ch;
        const int oc = bo % out_ch;

        // Accumulate elementwise products in the transform domain over all
        // input channels, applying the inverse transform once at the end.
        float m[4][4] = {};

        for (int c = 0; c < in_ch; ++c) {
            // Load the 4x4 input patch (padding = 1, zero outside the image).
            float d[4][4];
            for (int dy = 0; dy < 4; ++dy) {
                const int ir = row0 + dy - 1;
                for (int dx = 0; dx < 4; ++dx) {
                    const int icol = col0 + dx - 1;
                    d[dy][dx] = (ir >= 0 && ir < in_h && icol >= 0 && icol < in_w)
                        ? input[((static_cast<size_t>(b) * in_ch + c) * in_h + ir) * in_w + icol]
                        : 0.0f;
                }
            }

            // V = B^T d B.
            float t[4][4];
            for (int i = 0; i < 4; ++i) {
                for (int j = 0; j < 4; ++j) {
                    float s = 0.0f;
                    for (int k = 0; k < 4; ++k) {
                        s += winograd_BT[i * 4 + k] * d[k][j];
                    }
                    t[i][j] = s;
                }
            }
            float V[4][4];
            for (int i = 0; i < 4; ++i) {
                for (int j = 0; j < 4; ++j) {
                    float s = 0.0f;
                    for (int k = 0; k < 4; ++k) {
                        s += t[i][k] * winograd_BT[j * 4 + k];  // B[k][j]
                    }
                    V[i][j] = s;
                }
            }

            // U = G g G^T, reading only the meaningful 4x3 prefix of winograd_G.
            float g[3][3];
            for (int ky = 0; ky < 3; ++ky) {
                for (int kx = 0; kx < 3; ++kx) {
                    g[ky][kx] = weight[(static_cast<size_t>(oc) * in_ch + c) * 9 + ky * 3 + kx];
                }
            }
            float U[4][4];
            for (int i = 0; i < 4; ++i) {
                for (int j = 0; j < 4; ++j) {
                    float s = 0.0f;
                    for (int p = 0; p < 3; ++p) {
                        for (int q = 0; q < 3; ++q) {
                            s += winograd_G[i * 4 + p] * g[p][q] * winograd_G[j * 4 + q];
                        }
                    }
                    U[i][j] = s;
                }
            }

            for (int i = 0; i < 4; ++i) {
                for (int j = 0; j < 4; ++j) {
                    m[i][j] += V[i][j] * U[i][j];
                }
            }
        }

        // Y = A^T m A with the canonical 2x4 A^T = [[1,1,1,0],[0,1,-1,-1]],
        // hardcoded here rather than read from the 4x4 winograd_AT table.
        float r[2][4];
        for (int j = 0; j < 4; ++j) {
            r[0][j] = m[0][j] + m[1][j] + m[2][j];
            r[1][j] = m[1][j] - m[2][j] - m[3][j];
        }
        const float y00 = r[0][0] + r[0][1] + r[0][2];
        const float y01 = r[0][1] - r[0][2] - r[0][3];
        const float y10 = r[1][0] + r[1][1] + r[1][2];
        const float y11 = r[1][1] - r[1][2] - r[1][3];

        // NCHW output, guarding the tile's edge when out_h/out_w are odd.
        const size_t base = (static_cast<size_t>(b) * out_ch + oc) * out_h;
        output[(base + row0) * out_w + col0] = y00;
        if (col0 + 1 < out_w) {
            output[(base + row0) * out_w + col0 + 1] = y01;
        }
        if (row0 + 1 < out_h) {
            output[(base + row0 + 1) * out_w + col0] = y10;
            if (col0 + 1 < out_w) {
                output[(base + row0 + 1) * out_w + col0 + 1] = y11;
            }
        }
    }
}
147+
148+
// Host entry point for the Winograd F(2x2,3x3) convolution.
//
// Falls back to the validated implicit-GEMM path whenever the parameters are
// outside what the Winograd kernel supports: F(2x2,3x3) is only defined for
// 3x3 kernels with unit stride and unit dilation (the previous version
// launched the Winograd kernel for any stride/dilation, producing wrong
// results). The output buffer is zeroed before the launch so accumulating
// kernels start from a clean state.
//
// Throws std::invalid_argument on null pointers or non-positive dimensions.
void conv2d_winograd(const float* input, const float* weight, float* output,
                     const ConvParams& params,
                     const WinogradConfig& config,
                     cudaStream_t stream) {
    if (input == nullptr || weight == nullptr || output == nullptr) {
        throw std::invalid_argument("conv2d_winograd expects non-null input, weight, and output pointers");
    }
    if (params.batch <= 0 || params.in_channels <= 0 || params.out_channels <= 0) {
        throw std::invalid_argument("conv2d_winograd expects positive batch/channel dimensions");
    }
    // Winograd F(2x2,3x3) requires a 3x3 kernel, stride 1, and dilation 1.
    if (params.kernel_h != 3 || params.kernel_w != 3 ||
        params.stride_h != 1 || params.stride_w != 1 ||
        params.dilation_h != 1 || params.dilation_w != 1) {
        conv2d_winograd_fallback(input, weight, output, params, stream);
        return;
    }
    if (!config.use_winograd) {
        conv2d_winograd_fallback(input, weight, output, params, stream);
        return;
    }

    const int out_h = (params.in_height + 2 * params.pad_h - params.dilation_h * (params.kernel_h - 1) - 1) / params.stride_h + 1;
    const int out_w = (params.in_width + 2 * params.pad_w - params.dilation_w * (params.kernel_w - 1) - 1) / params.stride_w + 1;

    // One block per 2x2 output tile.
    const int tiles_h = (out_h + 1) / 2;
    const int tiles_w = (out_w + 1) / 2;
    const int num_tiles = tiles_h * tiles_w;

    // Zero the full output tensor so kernels that accumulate into it are
    // correct regardless of the buffer's previous contents.
    const size_t out_bytes = static_cast<size_t>(params.batch) * params.out_channels *
                             static_cast<size_t>(out_h) * out_w * sizeof(float);
    CUDA_CHECK(cudaMemsetAsync(output, 0, out_bytes, stream));

    dim3 block(4, 4);
    dim3 grid(num_tiles);
    winograd_conv_kernel<<<grid, block, 0, stream>>>(
        input, weight, output,
        params.batch, params.in_channels, params.out_channels,
        out_h, out_w,
        params.in_height, params.in_width);
    CUDA_CHECK_LAST();
}
20185

21-
ConvParams params{batch, in_channels, out_channels, height, width,
22-
3, 3, 1, 1, 1, 1, 1, 1};
186+
// Reference path used whenever the Winograd kernel cannot handle the
// requested shape: delegates directly to the validated implicit-GEMM
// convolution with identical arguments.
void conv2d_winograd_fallback(const float* input, const float* weight, float* output,
                              const ConvParams& params,
                              cudaStream_t stream) {
    conv2d_implicit_gemm<float>(input, weight, output, params, stream);
}
25191

src/04_convolution/conv_winograd.cuh

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,19 @@
44

55
namespace hpc::convolution {
66

7-
template <typename T>
8-
void conv2d_winograd(const T* input, const T* weight, T* output,
9-
int batch, int in_channels, int out_channels,
10-
int height, int width,
7+
// Full definition lives in the implicit-GEMM header; only a reference is
// needed at declaration time.
struct ConvParams;

// Tuning knobs for the Winograd convolution entry point.
struct WinogradConfig {
    int tile_size = 4;         // transform tile edge (F(2x2,3x3) uses 4x4 tiles)
    bool use_winograd = true;  // when false, route through the fallback path
};

// 3x3 convolution via Winograd F(2x2,3x3); falls back to implicit GEMM for
// unsupported shapes or when config.use_winograd is false.
void conv2d_winograd(const float* input, const float* weight, float* output,
                     const ConvParams& params,
                     const WinogradConfig& config = {},
                     cudaStream_t stream = nullptr);

// Validated implicit-GEMM path used as the Winograd fallback.
void conv2d_winograd_fallback(const float* input, const float* weight, float* output,
                              const ConvParams& params,
                              cudaStream_t stream = nullptr);
21+
1322
} // namespace hpc::convolution

src/07_cuda13_features/cluster.cu

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,71 @@
11
#include "cluster.cuh"
22
#include "../common/cuda_check.cuh"
33
#include <stdexcept>
4+
#include <cooperative_groups/memcpy_async.h>
45

56
namespace hpc::cuda13 {
67

7-
// Experimental fallback for a future thread-block-cluster implementation.
8-
// Today this uses a portable block reduction and does not rely on SM90-only features.
8+
// Returns true when the currently active CUDA device is Hopper-class
// (compute capability >= 9.0), i.e. supports thread-block clusters.
//
// Fix: query the active device instead of hardcoding device 0, which gave
// wrong answers in multi-GPU processes that had selected another device.
bool is_hopper_architecture() {
    int device = 0;
    CUDA_CHECK(cudaGetDevice(&device));
    cudaDeviceProp prop;
    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
    return prop.major >= 9;
}
14+
15+
namespace cg = cooperative_groups;
916

1017
// Two-stage sum reduction.
//
// Stage 1: every block performs a standard shared-memory tree reduction of
// its slice into smem[0]. Stage 2: when launched as a thread-block cluster
// (SM90+), the cluster's leader block pulls each member block's partial sum
// through distributed shared memory (cluster.map_shared_rank) and performs a
// single atomicAdd per cluster. When launched without cluster attributes,
// this_cluster() degenerates to a one-block cluster, so every block simply
// contributes its own partial — the kernel is correct either way.
//
// Fixes over the previous version: `use_cluster()`, `rank()` and `size()`
// are not cluster_group APIs; indexing `smem[tid + s * blockDim.x]` read out
// of bounds of the per-block shared array (remote partials must be accessed
// via map_shared_rank); the cluster branch skipped the intra-block reduction
// entirely; and a trailing cluster.sync() is required so no block retires
// its shared memory before the leader has read it.
//
// Requires: blockDim.x is a power of two; dynamic shared memory of
// blockDim.x * sizeof(float).
template <typename T>
__global__ void cluster_reduce_kernel(const T* __restrict__ input,
                                      T* __restrict__ output,
                                      size_t n) {
    extern __shared__ float smem[];

    cg::cluster_group cluster = cg::this_cluster();

    const int tid = threadIdx.x;
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    smem[tid] = (idx < n) ? static_cast<float>(input[idx]) : 0.0f;
    __syncthreads();

    // Stage 1: intra-block tree reduction into smem[0].
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }

    // Stage 2: publish partials cluster-wide, then the leader block gathers
    // peers' smem[0] via distributed shared memory.
    cluster.sync();
    if (cluster.block_rank() == 0 && tid == 0) {
        float total = smem[0];
        for (unsigned int r = 1; r < cluster.num_blocks(); ++r) {
            const float* peer = cluster.map_shared_rank(smem, r);
            total += peer[0];
        }
        atomicAdd(output, static_cast<T>(total));
    }
    // Keep every block's shared memory alive until the leader has read it.
    cluster.sync();
}
64+
65+
template <typename T>
66+
__global__ void cluster_reduce_fallback_kernel(const T* __restrict__ input,
67+
T* __restrict__ output,
68+
size_t n) {
1569
extern __shared__ float smem[];
1670

1771
int tid = threadIdx.x;
@@ -20,7 +74,6 @@ __global__ void cluster_reduce_kernel(const T* __restrict__ input,
2074
smem[tid] = (idx < n) ? static_cast<float>(input[idx]) : 0.0f;
2175
__syncthreads();
2276

23-
// Block-level reduction
2477
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
2578
if (tid < s) {
2679
smem[tid] += smem[tid + s];
@@ -35,7 +88,7 @@ __global__ void cluster_reduce_kernel(const T* __restrict__ input,
3588

3689
template <>
3790
void cluster_reduce<float>(const float* input, float* output, size_t n,
38-
const ClusterConfig& config, cudaStream_t stream) {
91+
const ClusterConfig& config, cudaStream_t stream) {
3992
if (input == nullptr || output == nullptr) {
4093
throw std::invalid_argument("cluster_reduce expects non-null input and output pointers");
4194
}
@@ -52,7 +105,34 @@ void cluster_reduce<float>(const float* input, float* output, size_t n,
52105

53106
CUDA_CHECK(cudaMemsetAsync(output, 0, sizeof(float), stream));
54107

55-
cluster_reduce_kernel<float><<<grid_size, block_size, smem_size, stream>>>(
108+
if (config.use_cluster && is_hopper_architecture()) {
109+
cluster_reduce_kernel<float><<<grid_size, block_size, smem_size, stream>>>(
110+
input, output, n);
111+
} else {
112+
cluster_reduce_fallback_kernel<float><<<grid_size, block_size, smem_size, stream>>>(
113+
input, output, n);
114+
}
115+
CUDA_CHECK_LAST();
116+
}
117+
118+
// Portable reduction path that requires no cluster features: zeroes the
// scalar accumulator, then each block reduces its slice in shared memory and
// atomically adds its partial into *output (same contract as cluster_reduce).
//
// Fix: the accumulator is now zeroed before the launch — the kernel only
// atomically adds per-block partials, so without the memset the result
// accumulated on top of whatever *output previously held (the sibling
// cluster_reduce<float> entry point already did this).
//
// Throws std::invalid_argument for null pointers, n == 0, or a zero block dim.
template <>
void cluster_reduce_fallback<float>(const float* input, float* output, size_t n,
                                    const ClusterConfig& config, cudaStream_t stream) {
    if (input == nullptr || output == nullptr) {
        throw std::invalid_argument("cluster_reduce expects non-null input and output pointers");
    }
    if (n == 0) {
        throw std::invalid_argument("cluster_reduce expects n > 0");
    }
    if (config.block_dims.x == 0) {
        throw std::invalid_argument("cluster_reduce expects config.block_dims.x > 0");
    }

    const int block_size = config.block_dims.x;
    const int grid_size = static_cast<int>((n + block_size - 1) / block_size);
    const size_t smem_size = static_cast<size_t>(block_size) * sizeof(float);

    CUDA_CHECK(cudaMemsetAsync(output, 0, sizeof(float), stream));

    cluster_reduce_fallback_kernel<float><<<grid_size, block_size, smem_size, stream>>>(
        input, output, n);
    CUDA_CHECK_LAST();
}

src/07_cuda13_features/cluster.cuh

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
namespace hpc::cuda13 {
66

77
// Launch-shape configuration for the cluster reduction entry points.
struct ClusterConfig {
    dim3 cluster_dims;        // e.g. {2, 1, 1} for a 2-block cluster
    dim3 grid_dims;           // grid dimensions for the launch
    dim3 block_dims;          // threads per block; block_dims.x must be > 0
    bool use_cluster = true;  // prefer the SM90 cluster path when available
};

// True when the active device reports compute capability >= 9.0 (Hopper).
bool is_hopper_architecture();

// Sums n elements of input into the scalar *output, choosing the cluster
// path on Hopper-class devices and the portable block reduction otherwise.
template <typename T>
void cluster_reduce(const T* input, T* output, size_t n,
                    const ClusterConfig& config,
                    cudaStream_t stream = nullptr);

// Portable block-reduction implementation used when clusters are unavailable.
template <typename T>
void cluster_reduce_fallback(const T* input, T* output, size_t n,
                             const ClusterConfig& config,
                             cudaStream_t stream = nullptr);
25+
1826
} // namespace hpc::cuda13

0 commit comments

Comments
 (0)