Skip to content

Add Fused Multi-Head Attention example#16

Closed
AntonOresten wants to merge 13 commits into
JuliaGPU:mainfrom
AntonOresten:fmha
Closed

Add Fused Multi-Head Attention example#16
AntonOresten wants to merge 13 commits into
JuliaGPU:mainfrom
AntonOresten:fmha

Conversation

@AntonOresten

@AntonOresten AntonOresten commented Jan 10, 2026

Copy link
Copy Markdown
Collaborator
See outdated

Seems to fall slightly short of my NNop / ONIONop baseline (no WMMA), although I haven't compared it to the Python version. On my GPU, it compiles and runs fastest with tile_n=32 and tile_m=32:

julia> begin
           T = Float32
           D, QL, KL, H, B = 64, 4096, 4096, 4, 4
           q = CUDA.randn(T, D, QL, H, B)
           k = CUDA.randn(T, D, KL, H, B)  
           v = CUDA.randn(T, D, KL, H, B)
       end;

julia> @b CUDA.@sync ONIONop.flash_attention(q, k, v, causal=false)
9.559 ms (339 allocs: 7.875 KiB)

julia> @b CUDA.@sync cutile_fmha(q, k, v, causal=false, tile_m=32, tile_n=32)
11.058 ms (540 allocs: 23.109 KiB)

EDIT: this is without tensor cores. simply switching the compute type to TFloat32 / BFloat16 and exploring the optimization and entry hint landscape makes forward and backward passes ~10x faster.

Notably, cutile-python has a latency argument for ct.load, as well as num_ctas and occupancy arguments for the kernel, which might affect performance. The python version also does a kernel config autotune by searching a space of hand-picked configurations. EDIT: fixed in #32 and #27.

Another thing that might be important for correctness or covering edge cases is exposing flush_to_zero? Used in e.g. exp2. Haven't thought about in which cases this matters.

@AntonOresten

AntonOresten commented Jan 17, 2026

Copy link
Copy Markdown
Collaborator Author

Seeing some weird erroring when branching (being fixed in #53):

Click to see snippets

This works:

        qk = if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            mask = mask .& (offs_n .<= k_seqlen)
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk .+ mask
        else
            qk
        end

but this doesn't:

        if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            mask = mask .& (offs_n .<= k_seqlen)
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk = qk .+ mask
        end

nor does this:

        qk = if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            if !EVEN_K[]
                mask .& (offs_n .<= k_seqlen)
            end
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk .+ mask
        else
            qk
        end

In the second and third block, I get "ERROR: SSAValue %___ not found in context"

after removing the second condition, I can suddenly have a nested if block, and I don't need the outer else block:

        if !EVEN_K[]
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            if !EVEN_K[]
                mask = mask .& (offs_n .<= k_seqlen)
            end
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk = qk .+ mask
        end

Does the if block need to depend on compile time constants?

I'd need this to make the padding and causal mask properly.

@maleadt

maleadt commented Jan 19, 2026

Copy link
Copy Markdown
Member

In the second and third block, I get "ERROR: SSAValue %___ not found in context"

That's an IRStructurizer error. Can you provide an MWE?

@AntonOresten

Copy link
Copy Markdown
Collaborator Author

That's an IRStructurizer error. Can you provide an MWE?

I was able to narrow it down and believe it is covered by #53. See the added tests for MWE.

With #51 and #53, I now have forward and backward passes working locally!

@AntonOresten AntonOresten marked this pull request as ready for review February 5, 2026 18:28
@AntonOresten

AntonOresten commented Feb 5, 2026

Copy link
Copy Markdown
Collaborator Author

Currently needing to wrap outside Float32 constants in Float32 within the kernel because MulF otherwise sees it as nothing:

qk_scale = Float32(qk_scale) * Float32(INV_LOG_2)

@AntonOresten

Copy link
Copy Markdown
Collaborator Author

Another concern is whether I should convert to Int32 or not, essentially every time I do index arithmetic.

@maleadt

maleadt commented Feb 6, 2026

Copy link
Copy Markdown
Member

Currently needing to wrap outside Float32 constants in Float32 within the kernel because MulF otherwise sees it as nothing:

qk_scale = Float32(qk_scale) * Float32(INV_LOG_2)

Can you elaborate?

Another concern is whether I should convert to Int32 or not, essentially every time I do index arithmetic.

Yeah that's a common Julia pain point. It's why we have One(), and in CUDA.jl you can do e.g. 1i32.

@AntonOresten

Copy link
Copy Markdown
Collaborator Author

Can you elaborate?

I define const INV_LOG_2 = Float32(1 / log(2)). If I use it without wrapping in Float32 within the kernel I get:

ERROR: LoadError: MethodError: no method matching encode_MulFOp!(::cuTile.CodeBuilder, ::cuTile.TypeId, ::cuTile.Value, ::Nothing)
The function `encode_MulFOp!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  encode_MulFOp!(::cuTile.CodeBuilder, ::cuTile.TypeId, ::cuTile.Value, ::cuTile.Value; rounding_mode, flush_to_zero)
   @ cuTile ~/.julia/dev/cuTile/src/bytecode/encodings.jl:720

in CUDA.jl you can do e.g. 1i32.

Oh, neat. I didn't know. I considered maybe a @32 macro (macro var"32" ... end) to make all 64-bit integers wrapped in their 32-bit counterparts. Found that it wouldn't work within curly brackets like for type parameters since e.g. Array{T,Int32(1)} won't count as a vector, but the macro doesn't need to descend into :curly expressions. Problem is still that some functions actually only have methods for Int so it can't be applied to the entire function.

@maleadt

maleadt commented Feb 6, 2026

Copy link
Copy Markdown
Member

Problem is still that some functions actually only have methods for Int so it can't be applied to the entire function.

In general, Julia's array indexing requires Int. In CUDA.jl we've added some additional methods to override part of the getindex chain to support Int32, but it's tricky...

@maleadt

maleadt commented Feb 8, 2026

Copy link
Copy Markdown
Member

Constants should work without the type conversion now.

@AntonOresten

Copy link
Copy Markdown
Collaborator Author

See #77 (comment)

@AntonOresten

AntonOresten commented Feb 10, 2026

Copy link
Copy Markdown
Collaborator Author

See #40 (comment). Required for full parity with the python example which has a function that uses autotuned launch.

@AntonOresten AntonOresten marked this pull request as draft February 21, 2026 08:19
@AntonOresten

Copy link
Copy Markdown
Collaborator Author

Superceded by #170

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants