Skip to content

Commit 8627ec3

Browse files
authored
Merge pull request #88 from JuliaGPU/tb/broadcast
Support broadcasting unsafe_trunc and trunc
2 parents 03583b4 + 51f1fde commit 8627ec3

5 files changed

Lines changed: 179 additions & 18 deletions

File tree

README.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,11 +416,17 @@ b = ct.load(B, (expert_id, k, bid_n), (1, TILE_K, TILE_N))
416416

417417
## Differences from Julia
418418

419-
### Float-to-integer conversion truncates
419+
### Some operations are non-throwing
420420

421-
Inside cuTile kernels, `Int32(x::Float32)` and similar float-to-integer constructors
422-
truncate toward zero (like C-style casts), rather than throwing `InexactError` as in
423-
standard Julia. This matches the behavior of GPU hardware and cuTile Python's `ct.astype`.
421+
cuTile kernels cannot throw Julia exceptions. Operations that would throw in
422+
standard Julia silently produce truncated or wrapped results instead:
423+
424+
- **Float-to-integer conversions:** `Int32(x)`, `trunc(Int32, x)`, and
425+
`round(Int32, x, RoundToZero)` silently truncate toward zero rather than
426+
throwing `InexactError` for non-integer or out-of-range values. Use
427+
`unsafe_trunc` for the explicit non-throwing primitive.
428+
429+
Assertions may be added in the future for testing purposes.
424430

425431

426432
## Limitations

src/language/broadcast.jl

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Broadcasting Infrastructure for Tiles
22
#
33
# Defines the broadcast style and shape computation for Tile types.
4-
# All broadcasted operations are materialized via copy → map.
4+
# All broadcasted operations are materialized via copy.
55

66
import Base.Broadcast: BroadcastStyle, Broadcasted, broadcastable, broadcast_shape
77

@@ -24,26 +24,38 @@ Base.Broadcast.broadcastable(t::Tile) = t
2424

2525

2626
#=============================================================================
27-
Broadcast materialization via copy + map
27+
Ghost wrapper for Type values in broadcasting
28+
=============================================================================#
29+
30+
# Replaces Julia's RefValue{Type{T}} wrapping which the cuTile compiler can't construct.
31+
# The value is encoded in the type parameter — no runtime representation needed.
32+
struct TypeRef{T} end
33+
34+
Base.Broadcast.BroadcastStyle(::Type{<:TypeRef}) = Base.Broadcast.DefaultArrayStyle{0}()
35+
Base.Broadcast.broadcastable(a::TypeRef) = a
36+
37+
38+
#=============================================================================
39+
Broadcast materialization via copy
2840
=============================================================================#
2941

3042
# Tile is a ghost type with no storage, so axes/size are meaningless.
3143
# Skip instantiate (which calls axes) by returning the Broadcasted as-is.
3244
@inline Base.Broadcast.instantiate(bc::Broadcasted{TileStyle}) = bc
3345

3446
# Recursively materialize nested Broadcasted nodes,
35-
# promote scalars to Tiles, broadcast to a common shape, then apply via map.
47+
# promote scalars to Tiles, broadcast to a common shape, then apply f.
3648
# This handles all element-wise operations: scalar @overlay methods provide
3749
# the implementation for overlaid ops, while Julia's native scalar functions
3850
# (compiled to Core intrinsics) handle the rest. Mixed-type and type-changing
3951
# operations (comparisons, ifelse) are supported by the mixed-type map methods
4052
# in operations.jl.
4153
@inline function Base.copy(bc::Broadcasted{TileStyle})
4254
args = _materialize_args(bc.args)
43-
tiles = _promote_to_tiles(args...)
44-
S = _broadcast_shapes(tiles...)
45-
broadcasted = _broadcast_all(S, tiles...)
46-
map(bc.f, broadcasted...)
55+
promoted = _promote_to_tiles(args...)
56+
S = _broadcast_shapes(promoted...)
57+
broadcasted = _broadcast_all(S, promoted...)
58+
_apply_broadcast(bc.f, broadcasted...)
4759
end
4860

4961
# Recursively materialize nested Broadcasted nodes into concrete Tiles.
@@ -63,19 +75,46 @@ end
6375
# using its own type (e.g., 0.0f0 → Tile(Float32(0.0))), preserving the
6476
# type that Julia's broadcast promotion chose. This avoids the pitfall of
6577
# using the first Tile's eltype (which could be Bool for ifelse conditions).
78+
# TypeRef arguments pass through unchanged — they carry no tile shape.
6679
@inline _promote_to_tiles() = ()
6780
@inline _promote_to_tiles(a::Tile, rest...) = (a, _promote_to_tiles(rest...)...)
6881
@inline _promote_to_tiles(a::T, rest...) where {T <: Number} =
6982
(Tile(a), _promote_to_tiles(rest...)...)
83+
@inline _promote_to_tiles(a::TypeRef, rest...) = (a, _promote_to_tiles(rest...)...)
7084

7185
# Compute combined broadcast shape across all Tile arguments via tuple peeling.
7286
# Shape is always a tuple TYPE (e.g., Tuple{16, 32}). Convert to value for broadcast_shape.
87+
# TypeRef arguments are skipped — they have no shape.
7388
@inline _tile_shape(t::Tile) = size(t)
7489
@inline _broadcast_shapes(t::Tile) = _tile_shape(t)
75-
@inline _broadcast_shapes(t::Tile, rest::Tile...) =
90+
@inline _broadcast_shapes(t::Tile, rest...) =
7691
broadcast_shape(_tile_shape(t), _broadcast_shapes(rest...))
92+
@inline _broadcast_shapes(::TypeRef, rest...) = _broadcast_shapes(rest...)
93+
@inline _broadcast_shapes(::TypeRef) = ()
7794

7895
# Broadcast all tiles to shape S via tuple peeling.
96+
# TypeRef arguments pass through unchanged.
7997
@inline _broadcast_all(S::Tuple) = ()
80-
@inline _broadcast_all(S::Tuple, a::Tile, rest::Tile...) =
98+
@inline _broadcast_all(S::Tuple, a::Tile, rest...) =
8199
(broadcast_to(a, S), _broadcast_all(S, rest...)...)
100+
@inline _broadcast_all(S::Tuple, a::TypeRef, rest...) =
101+
(a, _broadcast_all(S, rest...)...)
102+
103+
# Convert args to scalars, apply f, wrap result back into a Tile.
104+
@inline function _apply_broadcast(f, args...)
105+
scalar_args, S = _to_scalars(args...)
106+
Intrinsics.from_scalar(f(scalar_args...), S)
107+
end
108+
109+
# Reinterpret Tile arguments as scalars for broadcast application.
110+
# Skip and extract TypeRef arguments.
111+
# Returns (scalar_args_tuple, S) where S is the shape from the first Tile.
112+
@inline _to_scalars(t::Tile{<:Any,S}) where S = ((Intrinsics.to_scalar(t),), S)
113+
@inline function _to_scalars(t::Tile{<:Any,S}, rest...) where S
114+
rest_scalars, _ = _to_scalars(rest...)
115+
((Intrinsics.to_scalar(t), rest_scalars...), S)
116+
end
117+
@inline function _to_scalars(::TypeRef{T}, rest...) where T
118+
rest_scalars, S = _to_scalars(rest...)
119+
((T, rest_scalars...), S)
120+
end

src/language/overlays.jl

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ macro overlay(ex)
88
end
99

1010

11+
#=============================================================================
12+
Broadcasting
13+
=============================================================================#
14+
15+
# Route Type values through TypeRef instead of RefValue (which can't be constructed in Tile IR).
16+
@overlay Base.Broadcast.broadcastable(::Type{T}) where T = TypeRef{T}()
17+
18+
1119
#=============================================================================
1220
Type Conversions
1321
=============================================================================#
@@ -51,13 +59,13 @@ end
5159
sizeof(S) > sizeof(T) ? Intrinsics.exti(x, S, SignednessUnsigned) :
5260
sizeof(S) < sizeof(T) ? Intrinsics.trunci(x, S) : x
5361

54-
# Float to float (specific type pairs)
62+
# Float to float
5563
for T in Floats, S in Floats
5664
T === S && continue
5765
@eval @overlay $T(x::$S) = Intrinsics.ftof(x, $T)
5866
end
5967

60-
# Integer to float (specific type pairs)
68+
# Integer to float
6169
for F in Floats
6270
for I in SignedInts
6371
@eval @overlay $F(x::$I) = Intrinsics.itof(x, $F, SignednessSigned)
@@ -78,12 +86,26 @@ for F in Floats
7886
end
7987
end
8088

81-
# Float to integer (direct constructor - truncates like C-style cast)
89+
# Float to integer (round with RoundToZero)
90+
for F in Floats, I in (SignedInts..., UnsignedInts...)
91+
@eval @overlay function Base.round(::Type{$I}, x::$F, ::Base.Rounding.RoundingMode{:ToZero})
92+
# TODO: assert that x is within bounds etc
93+
unsafe_trunc($I, x)
94+
end
95+
end
96+
97+
# Float to integer (direct constructor)
8298
for F in Floats
8399
for I in SignedInts
84-
@eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, SignednessSigned)
100+
@eval @overlay function $I(x::$F)
101+
# TODO: assert that x is within bounds etc
102+
unsafe_trunc($I, x)
103+
end
85104
end
86105
for I in UnsignedInts
87-
@eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, SignednessUnsigned)
106+
@eval @overlay function $I(x::$F)
107+
# TODO: assert that x is within bounds etc
108+
unsafe_trunc($I, x)
109+
end
88110
end
89111
end

test/codegen/operations.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,45 @@
849849
end
850850
end
851851
end
852+
853+
@testset "Type broadcasting" begin
854+
# convert.(Float16, tile) — Type arg via TypeRef
855+
@test @filecheck begin
856+
@check_label "entry"
857+
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float16,1,spec1d}}) do a, b
858+
pid = ct.bid(1)
859+
tile = ct.load(a, pid, (16,))
860+
@check "ftof"
861+
ct.store(b, pid, convert.(Float16, tile))
862+
return
863+
end
864+
end
865+
866+
# convert.(Float32, float16_tile) — upcast via Type arg
867+
@test @filecheck begin
868+
@check_label "entry"
869+
code_tiled(Tuple{ct.TileArray{Float16,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do a, b
870+
pid = ct.bid(1)
871+
tile = ct.load(a, pid, (16,))
872+
@check "ftof"
873+
ct.store(b, pid, convert.(Float32, tile))
874+
return
875+
end
876+
end
877+
878+
# unsafe_trunc.(Int32, float32_tile) — ftoi via Type arg
879+
@test @filecheck begin
880+
@check_label "entry"
881+
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Int32,1,spec1d}}) do a, b
882+
pid = ct.bid(1)
883+
tile = ct.load(a, pid, (16,))
884+
@check "ftoi"
885+
ct.store(b, pid, unsafe_trunc.(Int32, tile))
886+
return
887+
end
888+
end
889+
890+
end
852891
end
853892

854893
#=========================================================================

test/execution/broadcast.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,61 @@ end
685685

686686
end # fma broadcasting
687687

688+
@testset "type argument broadcasting" begin
689+
690+
@testset "convert.(Float16, tile)" begin
691+
function convert_f16_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float16,1})
692+
pid = ct.bid(1)
693+
tile = ct.load(a, pid, (16,))
694+
ct.store(b, pid, convert.(Float16, tile))
695+
return
696+
end
697+
698+
n = 1024
699+
a = CUDA.rand(Float32, n)
700+
b = CUDA.zeros(Float16, n)
701+
702+
ct.launch(convert_f16_kernel, cld(n, 16), a, b)
703+
704+
@test Array(b) == Float16.(Array(a))
705+
end
706+
707+
@testset "convert.(Float32, float16_tile)" begin
708+
function convert_f32_kernel(a::ct.TileArray{Float16,1}, b::ct.TileArray{Float32,1})
709+
pid = ct.bid(1)
710+
tile = ct.load(a, pid, (16,))
711+
ct.store(b, pid, convert.(Float32, tile))
712+
return
713+
end
714+
715+
n = 1024
716+
a = CUDA.rand(Float16, n)
717+
b = CUDA.zeros(Float32, n)
718+
719+
ct.launch(convert_f32_kernel, cld(n, 16), a, b)
720+
721+
@test Array(b) == Float32.(Array(a))
722+
end
723+
724+
@testset "unsafe_trunc.(Int32, float_tile)" begin
725+
function unsafe_trunc_i32_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Int32,1})
726+
pid = ct.bid(1)
727+
tile = ct.load(a, pid, (16,))
728+
ct.store(b, pid, unsafe_trunc.(Int32, tile))
729+
return
730+
end
731+
732+
n = 1024
733+
a = CuArray(Float32.(rand(-100:100, n)) .+ 0.7f0)
734+
b = CUDA.zeros(Int32, n)
735+
736+
ct.launch(unsafe_trunc_i32_kernel, cld(n, 16), a, b)
737+
738+
@test Array(b) == unsafe_trunc.(Int32, Array(a))
739+
end
740+
741+
end # type argument broadcasting
742+
688743
@testset "multi-arg map" begin
689744
@testset "binary map(+, ...)" begin
690745
function map_add_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},

0 commit comments

Comments
 (0)