@@ -1127,6 +1127,200 @@ INSTANTIATE_HADAMARD(256)
11271127
11281128#undef INSTANTIATE_HADAMARD
11291129
1130+ // ===========================================================================
1131+ // Full-dimension Hadamard rotation kernel.
1132+ // One thread block processes one row of DIM elements using 3-4 butterfly levels:
1133+ // 1. In-thread butterfly (strides 1..kNElts/2)
1134+ // 2. Warp shuffle butterfly (strides kNElts..kNElts*16)
1135+ // 3. Cross-warp butterfly via shared memory (strides across warps)
1136+ // 4. Cross-chunk butterfly in registers (when kNChunks > 1)
1137+ //
1138+ // Grid: (num_rows,). Signs: DIM/32 uint32 words (one per full row, not per block).
1139+ // ===========================================================================
1140+
1141+ template <int kLogDim , int kNThreads , typename T>
1142+ __global__ void kHadamardRotateFull (T* __restrict__ data, const int num_rows, const unsigned int * __restrict__ signs) {
1143+ constexpr int DIM = 1 << kLogDim ;
1144+ constexpr int kNElts = 8 ; // elements per thread per chunk
1145+ constexpr int kNChunks = DIM / (kNThreads * kNElts );
1146+ constexpr int kNWarps = kNThreads / 32 ;
1147+
1148+ static_assert (DIM == kNThreads * kNElts * kNChunks , " dimension decomposition mismatch" );
1149+ static_assert (kNElts == 8 , " kNElts must be 8" );
1150+ static_assert ((kNThreads & (kNThreads - 1 )) == 0 , " kNThreads must be power of 2" );
1151+
1152+ const int row = blockIdx .x ;
1153+ if (row >= num_rows)
1154+ return ;
1155+
1156+ T* row_data = data + (long long )row * DIM;
1157+
1158+ // Shared memory for cross-warp butterfly (only needed when kNWarps > 1).
1159+ // Use char[] to match other kernels in this TU, then cast to float*.
1160+ extern __shared__ char smem_raw[];
1161+ float * smem = reinterpret_cast <float *>(smem_raw);
1162+
1163+ const int tid = threadIdx .x ;
1164+ const int warp_id = tid / 32 ;
1165+ const int lane_id = tid % 32 ;
1166+
1167+ // ---- Load elements (contiguous per thread) ----
1168+ float vals[kNChunks ][kNElts ];
1169+ #pragma unroll
1170+ for (int c = 0 ; c < kNChunks ; c++) {
1171+ const int base = c * kNThreads * kNElts + tid * kNElts ;
1172+ #pragma unroll
1173+ for (int i = 0 ; i < kNElts ; i++) {
1174+ vals[c][i] = (float )row_data[base + i];
1175+ }
1176+ }
1177+
1178+ // ---- Apply sign flips (D matrix) before butterfly ----
1179+ // 8 contiguous elements at position 'base' always fit within one uint32 word
1180+ // since base is always a multiple of 8.
1181+ if (signs != nullptr ) {
1182+ #pragma unroll
1183+ for (int c = 0 ; c < kNChunks ; c++) {
1184+ const int linear = c * kNThreads + tid; // which group of 8
1185+ const int word_idx = linear / 4 ;
1186+ const int byte_pos = (linear % 4 ) * 8 ;
1187+ const unsigned int byte_bits = (signs[word_idx] >> byte_pos) & 0xFFu ;
1188+ #pragma unroll
1189+ for (int i = 0 ; i < kNElts ; i++) {
1190+ if (byte_bits & (1u << i))
1191+ vals[c][i] = -vals[c][i];
1192+ }
1193+ }
1194+ }
1195+
1196+ // ---- Level 1: In-thread butterfly (strides 1, 2, 4) ----
1197+ #pragma unroll
1198+ for (int c = 0 ; c < kNChunks ; c++) {
1199+ #pragma unroll
1200+ for (int s = 1 ; s < kNElts ; s <<= 1 ) {
1201+ #pragma unroll
1202+ for (int i = 0 ; i < kNElts ; i++) {
1203+ int partner = i ^ s;
1204+ if (partner > i) {
1205+ float a = vals[c][i], b = vals[c][partner];
1206+ vals[c][i] = a + b;
1207+ vals[c][partner] = a - b;
1208+ }
1209+ }
1210+ }
1211+ }
1212+
1213+ // ---- Level 2: Warp shuffle butterfly (shfl_xor s=1..16) ----
1214+ #pragma unroll
1215+ for (int s = 1 ; s <= 16 ; s <<= 1 ) {
1216+ #pragma unroll
1217+ for (int c = 0 ; c < kNChunks ; c++) {
1218+ #pragma unroll
1219+ for (int i = 0 ; i < kNElts ; i++) {
1220+ float other = __shfl_xor_sync (0xFFFFFFFF , vals[c][i], s);
1221+ vals[c][i] = (lane_id & s) ? (other - vals[c][i]) : (vals[c][i] + other);
1222+ }
1223+ }
1224+ }
1225+
1226+ // ---- Level 3: Cross-warp butterfly via shared memory ----
1227+ if constexpr (kNWarps > 1 ) {
1228+ constexpr int VALS_PER_THREAD = kNChunks * kNElts ;
1229+ // smem layout: smem[tid * VALS_PER_THREAD + c * kNElts + i]
1230+ #pragma unroll
1231+ for (int ws = 1 ; ws < kNWarps ; ws <<= 1 ) {
1232+ // Write my values to shared memory
1233+ #pragma unroll
1234+ for (int c = 0 ; c < kNChunks ; c++) {
1235+ #pragma unroll
1236+ for (int i = 0 ; i < kNElts ; i++) {
1237+ smem[tid * VALS_PER_THREAD + c * kNElts + i] = vals[c][i];
1238+ }
1239+ }
1240+ __syncthreads ();
1241+
1242+ // Read partner warp's values
1243+ const int partner_tid = (warp_id ^ ws) * 32 + lane_id;
1244+ const bool negate = (warp_id & ws) != 0 ;
1245+ #pragma unroll
1246+ for (int c = 0 ; c < kNChunks ; c++) {
1247+ #pragma unroll
1248+ for (int i = 0 ; i < kNElts ; i++) {
1249+ float pval = smem[partner_tid * VALS_PER_THREAD + c * kNElts + i];
1250+ vals[c][i] = negate ? (pval - vals[c][i]) : (vals[c][i] + pval);
1251+ }
1252+ }
1253+ __syncthreads ();
1254+ }
1255+ }
1256+
1257+ // ---- Level 4: Cross-chunk butterfly (in-register, no communication) ----
1258+ if constexpr (kNChunks > 1 ) {
1259+ #pragma unroll
1260+ for (int cs = 1 ; cs < kNChunks ; cs <<= 1 ) {
1261+ #pragma unroll
1262+ for (int c = 0 ; c < kNChunks ; c++) {
1263+ int pc = c ^ cs;
1264+ if (pc > c) {
1265+ #pragma unroll
1266+ for (int i = 0 ; i < kNElts ; i++) {
1267+ float a = vals[c][i], b = vals[pc][i];
1268+ vals[c][i] = a + b;
1269+ vals[pc][i] = a - b;
1270+ }
1271+ }
1272+ }
1273+ }
1274+ }
1275+
1276+ // ---- Normalize by 1/sqrt(DIM) ----
1277+ const float norm = rsqrtf ((float )DIM);
1278+ #pragma unroll
1279+ for (int c = 0 ; c < kNChunks ; c++) {
1280+ #pragma unroll
1281+ for (int i = 0 ; i < kNElts ; i++)
1282+ vals[c][i] *= norm;
1283+ }
1284+
1285+ // ---- Store back ----
1286+ #pragma unroll
1287+ for (int c = 0 ; c < kNChunks ; c++) {
1288+ const int base = c * kNThreads * kNElts + tid * kNElts ;
1289+ #pragma unroll
1290+ for (int i = 0 ; i < kNElts ; i++) {
1291+ row_data[base + i] = (T)vals[c][i];
1292+ }
1293+ }
1294+ }
1295+
1296+ // ---- Full-dimension Hadamard launch wrapper ----
1297+ // kLogDim must match the dimension. kNThreads is the thread block size.
1298+
1299+ template <int kLogDim , int kNThreads , typename T>
1300+ void hadamardRotateFull (T* data, int num_rows, const unsigned int * signs, cudaStream_t stream) {
1301+ constexpr int DIM = 1 << kLogDim ;
1302+ constexpr int kNElts = 8 ;
1303+ constexpr int kNChunks = DIM / (kNThreads * kNElts );
1304+ constexpr int smem_bytes = kNThreads * kNChunks * kNElts * sizeof (float );
1305+ kHadamardRotateFull <kLogDim , kNThreads , T><<<num_rows, kNThreads , smem_bytes, stream>>> (data, num_rows, signs);
1306+ CUDA_CHECK_RETURN (cudaPeekAtLastError ());
1307+ }
1308+
1309+ // Explicit instantiations: dim 512..8192, 2 dtypes
1310+ #define INSTANTIATE_HADAMARD_FULL (LOG_DIM, NTHREADS ) \
1311+ template void hadamardRotateFull<LOG_DIM, NTHREADS, half>(half*, int , const unsigned int *, cudaStream_t); \
1312+ template void hadamardRotateFull<LOG_DIM, NTHREADS, __nv_bfloat16>( \
1313+ __nv_bfloat16*, int , const unsigned int *, cudaStream_t \
1314+ );
1315+
1316+ INSTANTIATE_HADAMARD_FULL (9 , 64 ) // dim=512
1317+ INSTANTIATE_HADAMARD_FULL(10 , 128 ) // dim=1024
1318+ INSTANTIATE_HADAMARD_FULL(11 , 256 ) // dim=2048
1319+ INSTANTIATE_HADAMARD_FULL(12 , 256 ) // dim=4096
1320+ INSTANTIATE_HADAMARD_FULL(13 , 256 ) // dim=8192
1321+
1322+ #undef INSTANTIATE_HADAMARD_FULL
1323+
11301324// Datacenter GPU detection: Hopper (sm_90) and Blackwell datacenter (sm_100).
11311325// NOTE: sm_120 (RTX 5090, Blackwell consumer) lacks TMA/wgmma — must NOT match.
11321326#if defined(__CUDA_ARCH__)
0 commit comments