Skip to content

Commit b7c28aa

Browse files
0xtaruhima, maleadt, and claude
authored
Support runtime values in ct.full (Intrinsics.constant) (#100)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e93c641 commit b7c28aa

5 files changed

Lines changed: 48 additions & 8 deletions

File tree

src/compiler/intrinsics/core.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,18 +181,22 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.constant), args)
181181
tile_shape = collect(Int, shape)
182182
validate_tile_shape(tile_shape, "full")
183183

184-
# Extract value
185-
value = @something get_constant(ctx, args[2]) throw(IRError("full() value must be a compile-time constant"))
186-
187184
# Extract dtype from Type{T} argument
188185
elem_type = @something get_constant(ctx, args[3]) throw(IRError("constant() requires a compile-time element type"))
189186

190187
dtype = julia_to_tile_dtype!(tt, elem_type)
191188
tile_type = tile_type!(tt, dtype, tile_shape)
192189

193-
# Create constant directly at target shape
194-
value_bytes = constant_to_bytes(value, elem_type)
195-
result = encode_ConstantOp!(cb, tile_type, value_bytes)
190+
tv = emit_value!(ctx, args[2])
191+
tv === nothing && throw(IRError("full() value must be a constant or a runtime scalar"))
192+
if tv.constant !== nothing
193+
# Compile-time constant: use ConstantOp directly
194+
value_bytes = constant_to_bytes(something(tv.constant), elem_type)
195+
result = encode_ConstantOp!(cb, tile_type, value_bytes)
196+
else
197+
# Runtime value: broadcast 0D tile to the target shape
198+
result = broadcast_tile_to_shape!(cb, tt, tv, tile_shape, dtype)
199+
end
196200

197201
CGVal(result, tile_type, Tile{elem_type, Tuple{tile_shape...}}, tile_shape)
198202
end

src/compiler/intrinsics/julia.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,5 @@ function emit_intrinsic!(ctx::CGCtx, func::Type{<:Tile}, args)
6464

6565
# Return as 0D tile type with element type from the constructor
6666
result_jltype = Tile{elem_type, Tuple{}}
67-
CGVal(source.v, source.type_id, result_jltype, source.shape)
67+
CGVal(source.v, source.type_id, result_jltype, source.shape, nothing, source.constant, nothing)
6868
end

src/language/operations.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,10 @@ Create a tile filled with a constant value.
434434
ones_tile = ct.full((32, 32), 1.0f0, Float32)
435435
```
436436
"""
437+
@inline full(shape::NTuple{N, Int}, value::Tile, ::Type{T}) where {N, T} =
438+
Intrinsics.constant(shape, convert(Tile{T}, value), T)
437439
@inline full(shape::NTuple{N, Int}, value, ::Type{T}) where {N, T} =
438-
Intrinsics.constant(shape, value, T)
440+
Intrinsics.constant(shape, Tile(T(value)), T)
439441

440442
"""
441443
zeros(shape::NTuple{N, Int}, dtype::Type{T}) -> Tile{T, shape}

test/codegen/operations.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,21 @@
428428
end
429429
end
430430

431+
@testset "constant with runtime value" begin
432+
@test @filecheck begin
433+
@check_label "entry"
434+
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, Int32}) do a, val
435+
pid = ct.bid(1)
436+
@check "itof"
437+
@check "reshape"
438+
@check "broadcast"
439+
tile = ct.full((16,), val, Float32)
440+
ct.store(a, pid, tile)
441+
return
442+
end
443+
end
444+
end
445+
431446
@testset "get_num_tile_blocks" begin
432447
@test @filecheck begin
433448
@check_label "entry"

test/execution/basic.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,25 @@ const _EXEC_TEST_GLOBAL_CONST = Float32(1 / log(2))
10581058
@test Array(b) ≈ Array(a) .* (scale * _EXEC_TEST_GLOBAL_CONST)
10591059
end
10601060

1061+
@testset "full with runtime value" begin
1062+
function full_runtime_kernel(src::ct.TileArray{Float32,1}, dst::ct.TileArray{Float32,1})
1063+
# Load a single-element tile to get a runtime scalar
1064+
val = ct.load(src, 1, (1,))
1065+
pid = ct.bid(1)
1066+
tile = ct.full((16,), val, Float32)
1067+
ct.store(dst, pid, tile)
1068+
return
1069+
end
1070+
1071+
n = 1024
1072+
src = CUDA.fill(3.14f0, 1)
1073+
dst = CUDA.zeros(Float32, n)
1074+
1075+
ct.launch(full_runtime_kernel, cld(n, 16), src, dst)
1076+
1077+
@test all(Array(dst) .≈ 3.14f0)
1078+
end
1079+
10611080
@testset "kernel name with !" begin
10621081
function kernel!()
10631082
return

0 commit comments

Comments
 (0)