// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f16 -pv --apply-bufferization-pipeline=false --paged-attention --num-pages 32 --page-size 1024 | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK

// CHECK: module attributes {mhal.arch = "[[$ARCH:.*]]"}

// Verify paged attention kernel signature:
// - Q input: memref<32768xf16> (flattened [1, 1024, 32])
// - K page table: memref<32xi64> (page pointers, numPages = 32)
// - V page table: memref<32xi64> (page pointers, numPages = 32)
// - Output: memref<32768xf16> (flattened [1, 1024, 32])
// CHECK-LABEL: func.func @rock_attention
// CHECK-SAME: (%[[queriesRaw:.*0]]: memref<32768xf16>,
// CHECK-SAME: %[[keysPageTable:.*1]]: memref<32xi64>,
// CHECK-SAME: %[[valuesPageTable:.*2]]: memref<32xi64>,
// CHECK-SAME: %[[outputRaw:.*3]]: memref<32768xf16>)
// CHECK-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"}
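// For reference, the sizes above follow from the test parameters (a sanity
// check, not matched by FileCheck):
//   Q/output elements: 1 (G) * 1024 (seq_len_q) * 32 (head_dim) = 32768
//   page table entries: num_pages = 32, one i64 page pointer each
//   K/V elements per tensor: 32 (num_pages) * 1024 (page_size) = 32768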

// Transform Q to [G, seq_q, head_qk]
// CHECK-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<32768xf16> to memref<1x1024x32xf16>

// Transform K page table to [batch, numPages, 1]
// CHECK-NEXT: %[[keysPageTableTransformed:.*]] = rock.transform %[[keysPageTable]] {{.*}} : memref<32xi64> to memref<1x32x1xi64>

// Transform V page table to [batch, numPages, 1]
// CHECK-NEXT: %[[valuesPageTableTransformed:.*]] = rock.transform %[[valuesPageTable]] {{.*}} : memref<32xi64> to memref<1x32x1xi64>

// Transform output to [G, seq_q, head_v]
// CHECK-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<32768xf16> to memref<1x1024x32xf16>

// rock.deref: dereference K page table to get actual K data
// CHECK-NEXT: %[[keyDeref:.*]] = rock.deref %[[keysPageTableTransformed]] : memref<1x32x1xi64> -> memref<1x32x1024xf16>

// rock.deref: dereference V page table to get actual V data
// CHECK-NEXT: %[[valueDeref:.*]] = rock.deref %[[valuesPageTableTransformed]] : memref<1x32x1xi64> -> memref<1x32x1024xf16>
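// Conceptually, rock.deref gathers each page through its table entry. A rough
// sketch of the indexing as hypothetical C-style pseudocode (an illustration,
// not the actual lowering):
//   for (b, p, i) in [1] x [numPages] x [pageSize]:
//     kData[b][p][i] = ((f16 *) keysPageTable[b][p][0])[i]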

// Transform the deref'd K through intermediate views into [G, head_dim_qk, seq_k] for the attention GEMM
// CHECK-NEXT: %[[keyTransform1:.*]] = rock.transform %[[keyDeref]] {{.*}} : memref<1x32x1024xf16> to memref<1x32768xf16>
// CHECK-NEXT: %[[keyTransform2:.*]] = rock.transform %[[keyTransform1]] {{.*}} : memref<1x32768xf16> to memref<1x1x1024x32xf16>
// CHECK-NEXT: %[[keys:.*]] = rock.transform %[[keyTransform2]] {{.*}} : memref<1x1x1024x32xf16> to memref<1x32x1024xf16>

// Transform the deref'd V through intermediate views into [G, seq_k, head_dim_v] for the attention GEMM
// CHECK-NEXT: %[[valueTransform1:.*]] = rock.transform %[[valueDeref]] {{.*}} : memref<1x32x1024xf16> to memref<1x32768xf16>
// CHECK-NEXT: %[[valueTransform2:.*]] = rock.transform %[[valueTransform1]] {{.*}} : memref<1x32768xf16> to memref<1x1x1024x32xf16>
// CHECK-NEXT: %[[values:.*]] = rock.transform %[[valueTransform2]] {{.*}} : memref<1x1x1024x32xf16> to memref<1x1024x32xf16>
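// Reading off the shapes, each chain flattens the paged layout
// [batch, numPages, pageSize] into a linear view, re-expands it as
// [batch, heads, seq_k, head_dim], and merges the group dimensions;
// K additionally swaps head_dim and seq_k so it arrives pre-transposed for Q * K.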

// Verify rock.attention op with keyAddresses and valueAddresses operands
// CHECK-NEXT: rock.attention
// CHECK-NEXT: qk = %[[queries]] * %[[keys]]
// CHECK-NEXT: keyAddresses = (%[[keyDeref]] : memref<1x32x1024xf16>)
// CHECK-NEXT: valueAddresses = (%[[valueDeref]] : memref<1x32x1024xf16>)
// CHECK: %[[output]] = softmax(qk) * %[[values]]
// CHECK: return
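// In summary, the paged dataflow checked above is:
//   page table (i64 pointers) -> rock.deref -> paged K/V buffers
//   -> rock.transform views -> rock.attention(qk = Q * K, then softmax * V)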

// =============================================================================
// CPU host function validation for paged attention
// =============================================================================
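// The host reference computes plain (non-paged) attention,
// out = softmax(Q * K^T) * V, so it takes K and V as flat f16 buffers rather
// than page tables.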

// CHECK-LABEL: func.func @host_naive_attention
// CHECK-SAME: (%[[hostQ:.*0]]: memref<32768xf16>,
// CHECK-SAME: %[[hostK:.*1]]: memref<32768xf16>,
// CHECK-SAME: %[[hostV:.*2]]: memref<32768xf16>,
// CHECK-SAME: %[[hostOut:.*3]]: memref<32768xf16>)

// Convert Q memref to tensor and reshape to [1, 1024, 32]
// CHECK: bufferization.to_tensor %[[hostQ]]
// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x1024x32xf16>

// Reshape K to [1, 32, 1024] for Q*K^T matmul
// CHECK: bufferization.to_tensor %[[hostK]]
// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x32x1024xf16>

// Reshape V to [1, 1024, 32]
// CHECK: bufferization.to_tensor %[[hostV]]
// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x1024x32xf16>

// First matmul: Q * K^T -> [1, 1024, 1024]
// CHECK: tosa.matmul {{.*}} {acc_type = f32} : (tensor<1x1024x32xf16>, tensor<1x32x1024xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1024x1024xf16>

// Softmax: cast to f32, reduce_max, sub, exp, reduce_sum, reciprocal, mul, cast back
// CHECK: tosa.cast {{.*}} : (tensor<1x1024x1024xf16>) -> tensor<1x1024x1024xf32>
// CHECK: tosa.reduce_max {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1xf32>
// CHECK: tosa.sub {{.*}} : (tensor<1x1024x1024xf32>, tensor<1x1024x1xf32>) -> tensor<1x1024x1024xf32>
// CHECK: tosa.exp {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1024xf32>
// CHECK: tosa.reduce_sum {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1xf32>
// CHECK: tosa.reciprocal {{.*}} : (tensor<1x1024x1xf32>) -> tensor<1x1024x1xf32>
// CHECK: tosa.mul {{.*}} : (tensor<1x1024x1024xf32>, tensor<1x1024x1xf32>, tensor<1xi8>) -> tensor<1x1024x1024xf32>
// CHECK: tosa.cast {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1024xf16>
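// This op sequence is the numerically stable softmax over the last (seq_k)
// dimension, evaluated in f32:
//   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))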

// Second matmul: softmax(Q*K^T) * V -> [1, 1024, 32]
// CHECK: tosa.matmul {{.*}} {acc_type = f32} : (tensor<1x1024x1024xf16>, tensor<1x1024x32xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1024x32xf16>

// Reshape output and copy to result
// CHECK: tosa.reshape {{.*}} : (tensor<1x1024x32xf16>, !tosa.shape<1>) -> tensor<32768xf16>
// CHECK: bufferization.to_buffer
// CHECK: memref.copy
// CHECK: return

// ----

// Test paged attention with GQA (grouped query attention)
// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f16 -pv --apply-bufferization-pipeline=false --paged-attention --num-pages 64 --page-size 1024 | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_GQA

// CHECK_GQA: module attributes {mhal.arch = "[[$ARCH:.*]]"}

// Verify GQA paged attention kernel signature:
// - Q input: memref<131072xf16> (flattened [4, 1024, 32])
// - K page table: memref<64xi64> (page pointers for 2 heads)
// - V page table: memref<64xi64> (page pointers for 2 heads)
// - Output: memref<131072xf16> (flattened [4, 1024, 32])
// CHECK_GQA-LABEL: func.func @rock_attention
// CHECK_GQA-SAME: (%[[queriesRaw:.*0]]: memref<131072xf16>,
// CHECK_GQA-SAME: %[[keysPageTable:.*1]]: memref<64xi64>,
// CHECK_GQA-SAME: %[[valuesPageTable:.*2]]: memref<64xi64>,
// CHECK_GQA-SAME: %[[outputRaw:.*3]]: memref<131072xf16>)
// CHECK_GQA-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"}
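// Size sanity check for the GQA case (not matched by FileCheck):
//   Q/output elements: 4 (num_heads_q) * 1024 * 32 = 131072
//   K/V elements: 2 (num_heads_kv) * 1024 * 32 = 65536
//               = 64 (num_pages) * 1024 (page_size)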

// Transform Q to [G, seq_q, head_qk] with G = num_heads_q = 4
// CHECK_GQA-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<131072xf16> to memref<4x1024x32xf16>

// Transform K page table
// CHECK_GQA-NEXT: %[[keysPageTableTransformed:.*]] = rock.transform %[[keysPageTable]] {{.*}} : memref<64xi64> to memref<1x64x1xi64>

// Transform V page table
// CHECK_GQA-NEXT: %[[valuesPageTableTransformed:.*]] = rock.transform %[[valuesPageTable]] {{.*}} : memref<64xi64> to memref<1x64x1xi64>

// Transform output
// CHECK_GQA-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<131072xf16> to memref<4x1024x32xf16>

// rock.deref K
// CHECK_GQA-NEXT: %[[keyDeref:.*]] = rock.deref %[[keysPageTableTransformed]] : memref<1x64x1xi64> -> memref<1x64x1024xf16>

// rock.deref V
// CHECK_GQA-NEXT: %[[valueDeref:.*]] = rock.deref %[[valuesPageTableTransformed]] : memref<1x64x1xi64> -> memref<1x64x1024xf16>

// K transforms to [G, head_dim_qk, seq_k] with G = num_heads_kv = 2
// CHECK_GQA: %[[keys:.*]] = rock.transform %{{.*}} {{.*}} to memref<2x32x1024xf16>

// V transforms to [G, seq_k, head_dim_v] with G = num_heads_kv = 2
// CHECK_GQA: %[[values:.*]] = rock.transform %{{.*}} {{.*}} to memref<2x1024x32xf16>
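// With num_heads_q = 4 and num_heads_kv = 2, each KV head is shared by
// 4 / 2 = 2 query heads; the kernel records this ratio via the
// numHeadsQ/numHeadsKV attributes checked below.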

// Verify rock.attention op with GQA and paged attention
// CHECK_GQA: rock.attention
// CHECK_GQA-NEXT: qk = %[[queries]] * %[[keys]]
// CHECK_GQA-NEXT: keyAddresses = (%[[keyDeref]] : memref<1x64x1024xf16>)
// CHECK_GQA-NEXT: valueAddresses = (%[[valueDeref]] : memref<1x64x1024xf16>)
// CHECK_GQA: %[[output]] = softmax(qk) * %[[values]]
// CHECK_GQA-NEXT: numHeadsKV = 2 : i32, numHeadsQ = 4 : i32
// CHECK_GQA: return