Skip to content

Commit 51f1fde

Browse files
committed
Support type ctors/trunc/round through unsafe_trunc.
1 parent 6bcd51c commit 51f1fde

4 files changed

Lines changed: 58 additions & 9 deletions

File tree

README.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,11 +416,17 @@ b = ct.load(B, (expert_id, k, bid_n), (1, TILE_K, TILE_N))
416416

417417
## Differences from Julia
418418

419-
### Float-to-integer conversion truncates
419+
### Some operations are non-throwing
420420

421-
Inside cuTile kernels, `Int32(x::Float32)` and similar float-to-integer constructors
422-
truncate toward zero (like C-style casts), rather than throwing `InexactError` as in
423-
standard Julia. This matches the behavior of GPU hardware and cuTile Python's `ct.astype`.
421+
cuTile kernels cannot throw Julia exceptions. Operations that would throw in
422+
standard Julia silently produce truncated or wrapped results instead:
423+
424+
- **Float-to-integer conversions:** `Int32(x)`, `trunc(Int32, x)`, and
425+
`round(Int32, x, RoundToZero)` silently truncate toward zero rather than
426+
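throwing `InexactError` (in standard Julia the constructor throws for non-integer or out-of-range values, while `trunc` and `round` throw only for out-of-range or non-finite values). Use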
427+
`unsafe_trunc` for the explicit non-throwing primitive.
428+
429+
Assertions that check for these conditions may be added in the future for testing purposes.
424430

425431

426432
## Limitations

src/language/overlays.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ end
5959
sizeof(S) > sizeof(T) ? Intrinsics.exti(x, S, SignednessUnsigned) :
6060
sizeof(S) < sizeof(T) ? Intrinsics.trunci(x, S) : x
6161

62-
# Float to float (specific type pairs)
62+
# Float to float
6363
for T in Floats, S in Floats
6464
T === S && continue
6565
@eval @overlay $T(x::$S) = Intrinsics.ftof(x, $T)
6666
end
6767

68-
# Integer to float (specific type pairs)
68+
# Integer to float
6969
for F in Floats
7070
for I in SignedInts
7171
@eval @overlay $F(x::$I) = Intrinsics.itof(x, $F, SignednessSigned)
@@ -86,12 +86,26 @@ for F in Floats
8686
end
8787
end
8888

89-
# Float to integer (direct constructor - truncates like C-style cast)
89+
# Float to integer (round with RoundToZero)
90+
for F in Floats, I in (SignedInts..., UnsignedInts...)
91+
@eval @overlay function Base.round(::Type{$I}, x::$F, ::Base.Rounding.RoundingMode{:ToZero})
92+
# TODO: assert that x is within bounds etc
93+
unsafe_trunc($I, x)
94+
end
95+
end
96+
97+
# Float to integer (direct constructor)
9098
for F in Floats
9199
for I in SignedInts
92-
@eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, SignednessSigned)
100+
@eval @overlay function $I(x::$F)
101+
# TODO: assert that x is within bounds etc
102+
unsafe_trunc($I, x)
103+
end
93104
end
94105
for I in UnsignedInts
95-
@eval @overlay $I(x::$F) = Intrinsics.ftoi(x, $I, SignednessUnsigned)
106+
@eval @overlay function $I(x::$F)
107+
# TODO: assert that x is within bounds etc
108+
unsafe_trunc($I, x)
109+
end
96110
end
97111
end

test/codegen/operations.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,18 @@
875875
end
876876
end
877877

878+
# unsafe_trunc.(Int32, float32_tile) — ftoi via Type arg
879+
@test @filecheck begin
880+
@check_label "entry"
881+
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Int32,1,spec1d}}) do a, b
882+
pid = ct.bid(1)
883+
tile = ct.load(a, pid, (16,))
884+
@check "ftoi"
885+
ct.store(b, pid, unsafe_trunc.(Int32, tile))
886+
return
887+
end
888+
end
889+
878890
end
879891
end
880892

test/execution/broadcast.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,23 @@ end
721721
@test Array(b) == Float32.(Array(a))
722722
end
723723

724+
@testset "unsafe_trunc.(Int32, float_tile)" begin
725+
function unsafe_trunc_i32_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Int32,1})
726+
pid = ct.bid(1)
727+
tile = ct.load(a, pid, (16,))
728+
ct.store(b, pid, unsafe_trunc.(Int32, tile))
729+
return
730+
end
731+
732+
n = 1024
733+
a = CuArray(Float32.(rand(-100:100, n)) .+ 0.7f0)
734+
b = CUDA.zeros(Int32, n)
735+
736+
ct.launch(unsafe_trunc_i32_kernel, cld(n, 16), a, b)
737+
738+
@test Array(b) == unsafe_trunc.(Int32, Array(a))
739+
end
740+
724741
end # type argument broadcasting
725742

726743
@testset "multi-arg map" begin

0 commit comments

Comments
 (0)