Skip to content

Commit 532bcc9

Browse files
Support Float8 types through DLFP8Types.jl (#36)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 6010752 commit 532bcc9

7 files changed

Lines changed: 56 additions & 3 deletions

File tree

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "cuTile"
22
uuid = "0dea8319-8c4a-4662-a73d-20234d115b9a"
3-
authors = ["Tim Besard <tim.besard@gmail.com>"]
43
version = "0.1.0"
4+
authors = ["Tim Besard <tim.besard@gmail.com>"]
55

66
[deps]
77
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
@@ -12,13 +12,15 @@ IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"
1212

1313
[weakdeps]
1414
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
15+
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
1516

1617
[sources]
1718
CompilerCaching = {url = "https://github.com/maleadt/CompilerCaching.jl", rev="main"}
1819
IRStructurizer = {url = "https://github.com/maleadt/IRStructurizer.jl", rev = "main"}
1920

2021
[extensions]
2122
CUDAExt = "CUDA"
23+
DLFP8TypesExt = "DLFP8Types"
2224

2325
[compat]
2426
julia = "1.11"

ext/DLFP8TypesExt.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module DLFP8TypesExt
2+
3+
import cuTile as ct
4+
5+
using DLFP8Types: Float8_E4M3FN, Float8_E5M2
6+
7+
function ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E4M3FN})
8+
return ct.F8E4M3FN(table)
9+
end
10+
11+
function ct.julia_to_tile_dtype!(table::ct.TypeTable, ::Type{Float8_E5M2})
12+
return ct.F8E5M2(table)
13+
end
14+
15+
end

src/bytecode/types.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ BF16(table::TypeTable) = simple_type!(table, SimpleType.BF16)
127127
F32(table::TypeTable) = simple_type!(table, SimpleType.F32)
128128
TF32(table::TypeTable) = simple_type!(table, SimpleType.TF32)
129129
F64(table::TypeTable) = simple_type!(table, SimpleType.F64)
130+
F8E4M3FN(table::TypeTable) = simple_type!(table, SimpleType.F8E4M3FN)
131+
F8E5M2(table::TypeTable) = simple_type!(table, SimpleType.F8E5M2)
130132
Token(table::TypeTable) = simple_type!(table, SimpleType.Token)
131133

132134
function tile_type!(table::TypeTable, dtype::TypeId, shape::AbstractVector{<:Integer})

src/compiler/codegen/values.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ function constant_to_bytes(@nospecialize(value), @nospecialize(T::Type))
146146
return collect(reinterpret(UInt8, [Int32(value)]))
147147
elseif T === Int64 || T === UInt64
148148
return collect(reinterpret(UInt8, [Int64(value)]))
149+
elseif T === Float16
150+
return collect(reinterpret(UInt8, [Float16(value)]))
149151
elseif T === Float32
150152
return collect(reinterpret(UInt8, [Float32(value)]))
151153
elseif T === Float64

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
cuTile = "0dea8319-8c4a-4662-a73d-20234d115b9a"
33
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4+
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
45
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
56
FileCheck = "4e644321-382b-4b05-b0b6-5d23c3d944fb"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/ext/DLFP8TypesExt.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using DLFP8Types: Float8_E4M3FN, Float8_E5M2
2+
3+
@testset "DLFP8Types extension" begin
4+
5+
spec1d = ct.ArraySpec{1}(16, true)
6+
7+
# Float32 -> Float8_E4M3FN
8+
@test @filecheck begin
9+
@check_label "entry"
10+
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
11+
pid = ct.bid(1)
12+
tile = ct.load(a, pid, (16,))
13+
@check "ftof"
14+
converted = convert(ct.Tile{Float8_E4M3FN}, tile)
15+
ct.store(b, pid, ct.astype(converted, Float32))
16+
return
17+
end
18+
end
19+
20+
# Float32 -> Float8_E5M2
21+
@test @filecheck begin
22+
@check_label "entry"
23+
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
24+
pid = ct.bid(1)
25+
tile = ct.load(a, pid, (16,))
26+
@check "ftof"
27+
converted = convert(ct.Tile{Float8_E5M2}, tile)
28+
ct.store(b, pid, ct.astype(converted, Float32))
29+
return
30+
end
31+
end
32+
33+
end

test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ end
4141
# Only include executing tests when CUDA is functional
4242
args = parse_args(ARGS)
4343
if filter_tests!(testsuite, args)
44-
delete!(testsuite, "filecheck")
45-
4644
cuda_functional = CUDA.functional()
4745
filter!(testsuite) do (test, _)
4846
if in(test, ["execution"]) || startswith(test, "examples/")

0 commit comments

Comments
 (0)