Commit 020797c: rocmlir-gen changes

1 parent 034180c

10 files changed

Lines changed: 765 additions & 27 deletions

mlir/test/e2e/AttentionNonPowerOfTwoTileSize.toml

Lines changed: 4 additions & 0 deletions

@@ -81,3 +81,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"
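Each suite below gains the same kind of test: the preceding split-kv KV-cache case rerun with `--paged-attention --num-pages 6 --page-size 8192`, i.e. with K and V stored in fixed-size pages that the kernel reaches through a page table instead of one contiguous buffer. As a mental model only (a minimal NumPy sketch assuming table entries index a page pool and a page holds a contiguous run of elements; rocmlir-gen's real tables hold i64 device addresses, and these names are illustrative):

import numpy as np

def gather_paged(page_table: np.ndarray, page_pool: np.ndarray) -> np.ndarray:
    """Rebuild a contiguous K (or V) buffer from fixed-size pages.

    page_table: (num_pages,) indices into page_pool.
    page_pool:  (pool_size, page_size) storage for all pages.
    """
    return page_pool[page_table].reshape(-1)  # contiguous elements, page by page

# Mirrors the flags in these configs: --num-pages 6 --page-size 8192.
table = np.arange(6)
pool = np.zeros((6, 8192), dtype=np.float16)
assert gather_paged(table, pool).size == 6 * 8192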

mlir/test/e2e/AttentionSchedule.toml

Lines changed: 4 additions & 0 deletions

@@ -116,3 +116,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias -schedule_version 2 --paged-attention --num-pages 6 --page-size 8192"

mlir/test/e2e/PrAttentionBF16.toml

Lines changed: 4 additions & 0 deletions

@@ -119,3 +119,7 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"

mlir/test/e2e/PrAttentionDirectToLDS.toml

Lines changed: 4 additions & 0 deletions

@@ -26,3 +26,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 --prefix_offset=16,14,12 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 --prefix_offset=16,14,12 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"

mlir/test/e2e/PrAttentionF16.toml

Lines changed: 4 additions & 0 deletions

@@ -119,3 +119,7 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"

mlir/test/e2e/PrAttentionF32.toml

Lines changed: 4 additions & 0 deletions

@@ -91,3 +91,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"

mlir/test/e2e/PrAttentionI8.toml

Lines changed: 4 additions & 0 deletions

@@ -96,3 +96,7 @@ config = "-rand 1 -return_lse -split_kv 4 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias"
+
+# GQA + causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --paged-attention --num-pages 6 --page-size 8192"

mlir/test/e2e/PrAttentionSchedule.toml

Lines changed: 4 additions & 0 deletions

@@ -28,3 +28,7 @@ config = "-rand 1 -return_lse -split_kv 8 -current_seq_len=17,1,32 -g 3 -num_hea
 # GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding)
 [[suite.test]]
 config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2"
+
+# GQA + prefix causal + KV Cache batch=3 + return LSE + split-kv (padding) + paged attention
+[[suite.test]]
+config = "-rand 1 -return_lse -split_kv 8 -prefix_offset=18,5,16 -current_seq_len=17,1,32 -g 3 -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 -causal --with-attn-scale --with-attn-bias --schedule_version 2 --paged-attention --num-pages 6 --page-size 8192"
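All of these new cases combine paging with `-split_kv 8` and `-return_lse`: the K/V sequence is processed in partitions, each partition yields an output normalized only over its own keys plus that partition's log-sum-exp, and the partials are merged afterwards. The merge is the standard flash-decoding reduction; a minimal NumPy sketch of that math (illustrative only, not the generated kernel's code):

import numpy as np

def merge_splits(partials, lses):
    """Merge split-KV partial attention outputs using per-split LSE.

    partials: list of (seq_q, head_dim_v) outputs, each softmax-normalized
              over its own K/V split.
    lses:     list of (seq_q,) log-sum-exp values for the same splits.
    """
    lse = np.stack(lses)                          # (splits, seq_q)
    global_lse = np.logaddexp.reduce(lse, axis=0)  # combined normalizer
    weights = np.exp(lse - global_lse)             # each split's share of the total mass
    out = np.stack(partials)                       # (splits, seq_q, head_dim_v)
    return np.einsum("sq,sqd->qd", weights, out)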
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f16 -pv --apply-bufferization-pipeline=false --paged-attention --num-pages 32 --page-size 1024 | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK

// CHECK: module attributes {mhal.arch = "[[$ARCH:.*]]"}

// Verify paged attention kernel signature:
// - Q input: memref<32768xf16> (flattened [1, 1024, 32])
// - K page table: memref<32xi64> (page pointers, numPages = 32)
// - V page table: memref<32xi64> (page pointers, numPages = 32)
// - Output: memref<32768xf16> (flattened [1, 1024, 32])
// CHECK-LABEL: func.func @rock_attention
// CHECK-SAME: (%[[queriesRaw:.*0]]: memref<32768xf16>,
// CHECK-SAME: %[[keysPageTable:.*1]]: memref<32xi64>,
// CHECK-SAME: %[[valuesPageTable:.*2]]: memref<32xi64>,
// CHECK-SAME: %[[outputRaw:.*3]]: memref<32768xf16>)
// CHECK-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"}
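The flattened sizes follow directly from the RUN line: Q and the output carry seq_len_q × head_dim = 1024 × 32 = 32768 f16 elements, and the 32-entry page tables cover seq_len_k × head_dim = 32768 K (or V) elements, i.e. 32 pages of 1024. A quick consistency check (note that `--page-size` counting elements rather than bytes is inferred from these shapes, not from documented behavior):

seq_len_q = seq_len_k = 1024
head_dim_qk = head_dim_v = 32
num_pages, page_size = 32, 1024

assert seq_len_q * head_dim_v == 32768                    # Q/output: memref<32768xf16>
assert num_pages * page_size == seq_len_k * head_dim_qk   # K/V elements reachable via the table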
// Transform Q to [G, seq_q, head_qk]
// CHECK-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<32768xf16> to memref<1x1024x32xf16>

// Transform K page table to [batch, numPages, 1]
// CHECK-NEXT: %[[keysPageTableTransformed:.*]] = rock.transform %[[keysPageTable]] {{.*}} : memref<32xi64> to memref<1x32x1xi64>

// Transform V page table to [batch, numPages, 1]
// CHECK-NEXT: %[[valuesPageTableTransformed:.*]] = rock.transform %[[valuesPageTable]] {{.*}} : memref<32xi64> to memref<1x32x1xi64>

// Transform output to [G, seq_q, head_v]
// CHECK-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<32768xf16> to memref<1x1024x32xf16>

// rock.deref: dereference K page table to get actual K data
// CHECK-NEXT: %[[keyDeref:.*]] = rock.deref %[[keysPageTableTransformed]] : memref<1x32x1xi64> -> memref<1x32x1024xf16>

// rock.deref: dereference V page table to get actual V data
// CHECK-NEXT: %[[valueDeref:.*]] = rock.deref %[[valuesPageTableTransformed]] : memref<1x32x1xi64> -> memref<1x32x1024xf16>
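rock.deref is the paging-specific step: it swaps each i64 entry of the [1, 32, 1] page table for the 1024 f16 elements it points to, producing a [1, 32, 1024] buffer; the transforms checked next only re-lay that buffer out for the attention GEMMs. A NumPy analogue, with pool indices standing in for device pointers and a transpose assumed for the final [g, head_dim, seq_k] step (both are illustrative assumptions, since the real layouts are encoded in the transforms' affine maps):

import numpy as np

pages = np.random.rand(32, 1024).astype(np.float16)       # one row per physical page
table = np.arange(32, dtype=np.int64).reshape(1, 32, 1)   # memref<1x32x1xi64>

deref = pages[table[..., 0]]        # rock.deref analogue: (1, 32, 1024)
flat = deref.reshape(1, 32768)      # memref<1x32768xf16>
k = flat.reshape(1, 1, 1024, 32)    # memref<1x1x1024x32xf16>
keys = k[0].transpose(0, 2, 1)      # (1, 32, 1024) = [g, head_dim_qk, seq_len_k]
assert keys.shape == (1, 32, 1024)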
// Transform deref'd K to intermediate shapes for attention GEMM
// CHECK-NEXT: %[[keyTransform1:.*]] = rock.transform %[[keyDeref]] {{.*}} : memref<1x32x1024xf16> to memref<1x32768xf16>
// CHECK-NEXT: %[[keyTransform2:.*]] = rock.transform %[[keyTransform1]] {{.*}} : memref<1x32768xf16> to memref<1x1x1024x32xf16>
// CHECK-NEXT: %[[keys:.*]] = rock.transform %[[keyTransform2]] {{.*}} : memref<1x1x1024x32xf16> to memref<1x32x1024xf16>

// Transform deref'd V to intermediate shapes for attention GEMM
// CHECK-NEXT: %[[valueTransform1:.*]] = rock.transform %[[valueDeref]] {{.*}} : memref<1x32x1024xf16> to memref<1x32768xf16>
// CHECK-NEXT: %[[valueTransform2:.*]] = rock.transform %[[valueTransform1]] {{.*}} : memref<1x32768xf16> to memref<1x1x1024x32xf16>
// CHECK-NEXT: %[[values:.*]] = rock.transform %[[valueTransform2]] {{.*}} : memref<1x1x1024x32xf16> to memref<1x1024x32xf16>

// Verify rock.attention op with keyAddresses and valueAddresses attributes
// CHECK-NEXT: rock.attention
// CHECK-NEXT: qk = %[[queries]] * %[[keys]]
// CHECK-NEXT: keyAddresses = (%[[keyDeref]] : memref<1x32x1024xf16>)
// CHECK-NEXT: valueAddresses = (%[[valueDeref]] : memref<1x32x1024xf16>)
// CHECK: %[[output]] = softmax(qk) * %[[values]]
// CHECK: return
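Per the checked op body (this RUN line passes no --with-attn-scale, so no extra scale factor appears), the kernel computes the plain attention

$$O = \operatorname{softmax}(QK^{\top})\,V, \qquad \operatorname{softmax}(x)_j = \frac{e^{x_j - \max_i x_i}}{\sum_i e^{x_i - \max_i x_i}},$$

applied row-wise over the K sequence; the max-subtracted form matches the numerically stable host reference validated below.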
// =============================================================================
// CPU host function validation for paged attention
// =============================================================================

// CHECK-LABEL: func.func @host_naive_attention
// CHECK-SAME: (%[[hostQ:.*0]]: memref<32768xf16>,
// CHECK-SAME: %[[hostK:.*1]]: memref<32768xf16>,
// CHECK-SAME: %[[hostV:.*2]]: memref<32768xf16>,
// CHECK-SAME: %[[hostOut:.*3]]: memref<32768xf16>)

// Convert Q memref to tensor and reshape to [1, 1024, 32]
// CHECK: bufferization.to_tensor %[[hostQ]]
// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x1024x32xf16>

// Reshape K to [1, 32, 1024] for Q*K^T matmul
// CHECK: bufferization.to_tensor %[[hostK]]
// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x32x1024xf16>

// Reshape V to [1, 1024, 32]
// CHECK: bufferization.to_tensor %[[hostV]]
// CHECK: tosa.reshape {{.*}} : (tensor<32768xf16>, !tosa.shape<3>) -> tensor<1x1024x32xf16>

// First matmul: Q * K^T -> [1, 1024, 1024]
// CHECK: tosa.matmul {{.*}} {acc_type = f32} : (tensor<1x1024x32xf16>, tensor<1x32x1024xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1024x1024xf16>

// Softmax: cast to f32, reduce_max, sub, exp, reduce_sum, reciprocal, mul, cast back
// CHECK: tosa.cast {{.*}} : (tensor<1x1024x1024xf16>) -> tensor<1x1024x1024xf32>
// CHECK: tosa.reduce_max {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1xf32>
// CHECK: tosa.sub {{.*}} : (tensor<1x1024x1024xf32>, tensor<1x1024x1xf32>) -> tensor<1x1024x1024xf32>
// CHECK: tosa.exp {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1024xf32>
// CHECK: tosa.reduce_sum {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1xf32>
// CHECK: tosa.reciprocal {{.*}} : (tensor<1x1024x1xf32>) -> tensor<1x1024x1xf32>
// CHECK: tosa.mul {{.*}} : (tensor<1x1024x1024xf32>, tensor<1x1024x1xf32>, tensor<1xi8>) -> tensor<1x1024x1024xf32>
// CHECK: tosa.cast {{.*}} : (tensor<1x1024x1024xf32>) -> tensor<1x1024x1024xf16>

// Second matmul: softmax(Q*K^T) * V -> [1, 1024, 32]
// CHECK: tosa.matmul {{.*}} {acc_type = f32} : (tensor<1x1024x1024xf16>, tensor<1x1024x32xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1024x32xf16>

// Reshape output and copy to result
// CHECK: tosa.reshape {{.*}} : (tensor<1x1024x32xf16>, !tosa.shape<1>) -> tensor<32768xf16>
// CHECK: bufferization.to_buffer
// CHECK: memref.copy
// CHECK: return
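The checked TOSA chain is exactly a max-subtracted softmax computed in f32 between two f16 matmuls with f32 accumulation. A NumPy transcription of that sequence, for readability (a sketch of what the ops compute, not generated code):

import numpy as np

def host_naive_attention(q, kt, v):
    """q: (1, 1024, 32), kt: (1, 32, 1024), v: (1, 1024, 32), all float16."""
    qk = (q.astype(np.float32) @ kt.astype(np.float32)).astype(np.float16)  # tosa.matmul, acc_type = f32
    x = qk.astype(np.float32)                      # tosa.cast to f32
    x = x - x.max(axis=-1, keepdims=True)          # tosa.reduce_max + tosa.sub
    e = np.exp(x)                                  # tosa.exp
    s = e * (1.0 / e.sum(axis=-1, keepdims=True))  # tosa.reduce_sum + tosa.reciprocal + tosa.mul
    p = s.astype(np.float16)                       # tosa.cast back to f16
    return (p.astype(np.float32) @ v.astype(np.float32)).astype(np.float16)  # second matmul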
// ----

// Test paged attention with GQA (grouped query attention)
// RUN: rocmlir-gen --arch gfx90a:sramecc+:xnack- --operation attention -num_heads_q 4 -num_heads_kv 2 -seq_len_q 1024 -seq_len_k 1024 -head_dim_qk 32 -head_dim_v 32 -t f16 -pv --apply-bufferization-pipeline=false --paged-attention --num-pages 64 --page-size 1024 | rocmlir-opt | FileCheck %s --enable-var-scope --check-prefixes=CHECK_GQA

// CHECK_GQA: module attributes {mhal.arch = "[[$ARCH:.*]]"}

// Verify GQA paged attention kernel signature:
// - Q input: memref<131072xf16> (flattened [4, 1024, 32])
// - K page table: memref<64xi64> (page pointers for 2 heads)
// - V page table: memref<64xi64> (page pointers for 2 heads)
// - Output: memref<131072xf16> (flattened [4, 1024, 32])
// CHECK_GQA-LABEL: func.func @rock_attention
// CHECK_GQA-SAME: (%[[queriesRaw:.*0]]: memref<131072xf16>,
// CHECK_GQA-SAME: %[[keysPageTable:.*1]]: memref<64xi64>,
// CHECK_GQA-SAME: %[[valuesPageTable:.*2]]: memref<64xi64>,
// CHECK_GQA-SAME: %[[outputRaw:.*3]]: memref<131072xf16>)
// CHECK_GQA-SAME: attributes {kernel, mhal.arch = "[[$ARCH]]"}

// Transform Q to [G, seq_q, head_qk] with G = num_heads_q = 4
// CHECK_GQA-NEXT: %[[queries:.*]] = rock.transform %[[queriesRaw]] {{.*}} : memref<131072xf16> to memref<4x1024x32xf16>

// Transform K page table
// CHECK_GQA-NEXT: %[[keysPageTableTransformed:.*]] = rock.transform %[[keysPageTable]] {{.*}} : memref<64xi64> to memref<1x64x1xi64>

// Transform V page table
// CHECK_GQA-NEXT: %[[valuesPageTableTransformed:.*]] = rock.transform %[[valuesPageTable]] {{.*}} : memref<64xi64> to memref<1x64x1xi64>

// Transform output
// CHECK_GQA-NEXT: %[[output:.*]] = rock.transform %[[outputRaw]] {{.*}} : memref<131072xf16> to memref<4x1024x32xf16>

// rock.deref K
// CHECK_GQA-NEXT: %[[keyDeref:.*]] = rock.deref %[[keysPageTableTransformed]] : memref<1x64x1xi64> -> memref<1x64x1024xf16>

// rock.deref V
// CHECK_GQA-NEXT: %[[valueDeref:.*]] = rock.deref %[[valuesPageTableTransformed]] : memref<1x64x1xi64> -> memref<1x64x1024xf16>

// K transforms to [G, head_dim_qk, seq_k] with G = num_heads_kv = 2
// CHECK_GQA: %[[keys:.*]] = rock.transform %{{.*}} {{.*}} to memref<2x32x1024xf16>

// V transforms to [G, seq_k, head_dim_v] with G = num_heads_kv = 2
// CHECK_GQA: %[[values:.*]] = rock.transform %{{.*}} {{.*}} to memref<2x1024x32xf16>

// Verify rock.attention op with GQA and paged attention
// CHECK_GQA: rock.attention
// CHECK_GQA-NEXT: qk = %[[queries]] * %[[keys]]
// CHECK_GQA-NEXT: keyAddresses = (%[[keyDeref]] : memref<1x64x1024xf16>)
// CHECK_GQA-NEXT: valueAddresses = (%[[valueDeref]] : memref<1x64x1024xf16>)
// CHECK_GQA: %[[output]] = softmax(qk) * %[[values]]
// CHECK_GQA-NEXT: numHeadsKV = 2 : i32, numHeadsQ = 4 : i32
// CHECK_GQA: return
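With num_heads_q = 4 and num_heads_kv = 2, each KV head serves num_heads_q / num_heads_kv = 2 query heads. That is why Q and the output carry G = 4 while the deref'd K/V and the GEMM operands carry G = 2, and why 64 pages × 1024 elements = 65536 = 2 × 1024 × 32 covers both KV heads. The implied head mapping is the usual grouped-query one; a sketch (rock.attention encodes this through its numHeadsQ/numHeadsKV attributes rather than explicit host-side indexing):

num_heads_q, num_heads_kv = 4, 2
group = num_heads_q // num_heads_kv   # query heads per KV head

for q_head in range(num_heads_q):
    kv_head = q_head // group
    print(f"query head {q_head} reads K/V head {kv_head}")
# query heads 0,1 -> KV head 0; query heads 2,3 -> KV head 1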
