Skip to content

Commit 3fe82c6

Browse files
authored
Fix FMHA index order and check (#231)
1 parent 70f3731 commit 3fe82c6

3 files changed

Lines changed: 7 additions & 5 deletions

File tree

examples/fmha.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function fmha_kernel(Q::ct.TileArray{T, 4}, K::ct.TileArray{T, 4},
 81  81   # QK product
 82  82   # K is (D_k, SeqLen_KV, KVH, Batch)
 83  83   # Load (TILE_N, TILE_D, 1, 1) with order=(2,1,3,4) to transpose D and N
 84     - k = reshape(ct.load(K; index=(Int32(1), j + Int32(1), off_kv_h, batch_idx),
     84 + k = reshape(ct.load(K; index=(j + Int32(1), Int32(1), off_kv_h, batch_idx),
 85  85   shape=(TILE_N, TILE_D, 1, 1), order=(2, 1, 3, 4),
 86  86   latency=2),
 87  87   (TILE_N, TILE_D)) # (TILE_N, TILE_D)
@@ -279,7 +279,7 @@ function verify(data, result)
279 279   expected = ref_fmha(data.Q, data.K, data.V; causal=data.causal)
280 280   actual = Float32.(Array(result.out))
281 281   max_diff = maximum(abs.(actual .- expected))
282     - @assert max_diff < 1e-2 "FMHA incorrect! max diff: $max_diff"
    282 + @assert isapprox(actual, expected, rtol=1e-2, atol=1e-3) "FMHA incorrect! max diff: $max_diff"
283 283   end
284 284
285 285

examples/fmha.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def verify(data, result):
263 263   expected = ref_fmha(data["Q"], data["K"], data["V"],
264 264   causal=data["causal"])
265 265   actual = cp.asnumpy(result["out"]).astype(np.float32)
266     - assert np.allclose(actual, expected, rtol=1e-2, atol=1e-2), \
    266 + assert np.allclose(actual, expected, rtol=1e-2, atol=1e-3), \
267 267   f"FMHA incorrect! max diff: {np.max(np.abs(actual - expected))}"
268 268
269 269

src/language/operations.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ Index is 1-indexed. Shape must be compile-time constant.
345 345   - `order`: Optional tuple specifying the logical-to-physical dimension mapping (1-indexed).
346 346   For example, `order=(2, 1)` indicates dimension 2 is contiguous in memory,
347 347   enabling coalesced loads from transposed/permuted arrays.
    348 + `index[i]` and `shape[i]` describe tile dim `i`, which maps to source dim `order[i]`.
348 349   Default: `nothing` → identity `(1, 2, ..., N)`.
349 350
350 351   # Padding Modes
@@ -366,10 +367,10 @@ outside the array, the behavior is undefined regardless of `padding_mode`.
366 367
367 368   # Example
368 369   ```julia
369     - tile = ct.load(arr, (bid,), (TILE_N[],); padding_mode=ct.PaddingMode.Zero, latency=3)
    370 + tile = ct.load(arr, (bid,), (TILE_N,); padding_mode=ct.PaddingMode.Zero, latency=3)
370 371
371 372   # Load from a transposed array with coalesced access
372     - tile = ct.load(arr, (bidx, bidy), (TM, TN); order=(2, 1))
    373 + tile = ct.load(arr, (bidy, bidx), (TN, TM); order=(2, 1))
373 374   ```
374 375   """
375 376   @inline function load(arr::TileArray, index, shape::NTuple{<:Any, Int};
@@ -442,6 +443,7 @@ behavior is undefined.
442 443   # Dimension Ordering
443 444   - `order`: Optional tuple specifying the logical-to-physical dimension mapping (1-indexed).
444 445   Must match the `order` used in the corresponding `load` for permuted arrays.
    446 + `index[i]` and `shape[i]` describe tile dim `i`, which maps to destination dim `order[i]`.
445 447   Default: `nothing` → identity `(1, 2, ..., N)`.
446 448
447 449   # Optimization Hints

0 commit comments

Comments
 (0)