Skip to content

Commit 19e3eb3

Browse files
authored
Allow early returns (#92)
1 parent b7c28aa commit 19e3eb3

3 files changed

Lines changed: 119 additions & 0 deletions

File tree

src/compiler/codegen/control_flow.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,59 @@ function emit_terminator!(ctx::CGCtx, ::ConditionOp)
453453
# ConditionOp is handled specially by emit_while_op!, not emitted as a terminator
454454
end
455455

456+
#=============================================================================
457+
Early Return Hoisting
458+
459+
tileiras rejects ReturnNode (cuda_tile.return) inside IfOp (cuda_tile.if)
460+
regions. This pre-pass rewrites the structured IR so that ReturnNode only
461+
appears at the top level, replacing nested returns with YieldOp.
462+
=============================================================================#
463+
"""
    hoist_returns!(block::Block)

Lift `ReturnNode` terminators out of `IfOp` regions and into `block`.

The pass works bottom-up: every region nested under a control-flow op in
`block.body` (`IfOp`, `ForOp`, `WhileOp`, `LoopOp`) is processed first, so a
chain of successive `if ... return end` patterns is collapsed one level at a
time. Afterwards, each `IfOp` directly in `block` whose `then_region` and
`else_region` BOTH end in a `ReturnNode` has those terminators replaced by a
void `YieldOp()`, and `block.terminator` is set to `ReturnNode(nothing)`.

Limitations: only the both-branches-return shape is rewritten
(REGION_TERMINATION with 3 children). The 2-child shape (an early return
inside a loop) is left untouched.

The IR is mutated in place; the function returns `nothing`.
"""
function hoist_returns!(block::Block)
    # Pass 1: descend into every nested region before inspecting this level,
    # so inner returns have already been hoisted when pass 2 runs.
    for (_, entry) in block.body
        op = entry.stmt
        regions = if op isa IfOp
            (op.then_region, op.else_region)
        elseif op isa ForOp
            (op.body,)
        elseif op isa WhileOp
            (op.before, op.after)
        elseif op isa LoopOp
            (op.body,)
        else
            ()
        end
        foreach(hoist_returns!, regions)
    end

    # Pass 2: an IfOp whose two branches both return is turned into a plain
    # void-yielding IfOp, and the return becomes this block's terminator.
    for (_, entry) in block.body
        op = entry.stmt
        if op isa IfOp &&
           op.then_region.terminator isa ReturnNode &&
           op.else_region.terminator isa ReturnNode
            # Void yield: the branches carry no values out of the IfOp.
            op.then_region.terminator = YieldOp()
            op.else_region.terminator = YieldOp()
            block.terminator = ReturnNode(nothing)
        end
    end

    return nothing
end
508+
456509
"""
457510
emit_getfield!(ctx, args) -> Union{CGVal, Nothing}
458511

src/compiler/codegen/kernel.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
157157
ctx.token_type = token_type
158158
ctx.token = encode_MakeTokenOp!(cb, token_type)
159159

160+
# Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp)
161+
hoist_returns!(ctx.sci.entry)
162+
160163
# Emit the structured IR (uses original Julia SSA indices everywhere)
161164
emit_block!(ctx, ctx.sci.entry)
162165

test/execution/basic.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,3 +1154,66 @@ end
11541154
end
11551155
end
11561156

1157+
@testset "early return — taken" begin
    # Kernel with a guarded early return ahead of its only store. The
    # `if ... return ... end` shape is kept as-is: it is the pattern under test.
    function early_return_skip(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                               flag::Int32)
        pid = ct.bid(1)
        tile = ct.load(a, pid, (16,))
        if flag == Int32(0)
            return nothing
        end
        doubled = tile .* 2.0f0
        ct.store(b, pid, doubled)
        return nothing
    end

    input = CUDA.rand(Float32, 64)
    output = CUDA.zeros(Float32, 64)

    # flag == 0 takes the early return, so the store must never run and the
    # zero-initialised output has to stay all zeros.
    ct.launch(early_return_skip, 4, input, output, Int32(0))
    @test all(iszero, Array(output))
end
1173+
1174+
@testset "early return — not taken" begin
    # Same kernel shape as the "taken" case: a guarded early return followed
    # by a store of the loaded tile scaled by 2.
    function early_return_store(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                                flag::Int32)
        pid = ct.bid(1)
        tile = ct.load(a, pid, (16,))
        if flag == Int32(0)
            return nothing
        end
        ct.store(b, pid, tile .* 2.0f0)
        return nothing
    end

    a = CUDA.rand(Float32, 64)
    b = CUDA.zeros(Float32, 64)

    # flag != 0 skips the early return, so the store runs.
    # NOTE(review): presumably the launch argument 4 with 16-element tiles
    # covers all 64 elements — confirm against ct.launch semantics.
    ct.launch(early_return_store, 4, a, b, Int32(1))

    # Floating-point results are compared approximately; the `≈` operator was
    # missing here, which left `@test` with two bare expressions.
    @test Array(b) ≈ Array(a) .* 2.0f0
end
1190+
1191+
@testset "multiple early returns" begin
    # Two successive guarded early returns ahead of the store; exercises the
    # recursive hoisting of nested `if ... return end` patterns.
    function multi_early_return(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                                flag1::Int32, flag2::Int32)
        pid = ct.bid(1)
        tile = ct.load(a, pid, (16,))
        if flag1 == Int32(0)
            return nothing
        end
        if flag2 == Int32(0)
            return nothing
        end
        ct.store(b, pid, tile .* 2.0f0)
        return nothing
    end

    a = CUDA.rand(Float32, 64)

    # Neither guard fires: the store runs. Compared approximately; the `≈`
    # operator was missing here, leaving `@test` with two bare expressions.
    b1 = CUDA.zeros(Float32, 64)
    ct.launch(multi_early_return, 4, a, b1, Int32(1), Int32(1))
    @test Array(b1) ≈ Array(a) .* 2.0f0

    # First guard fires: the output stays at its zero initialisation.
    b2 = CUDA.zeros(Float32, 64)
    ct.launch(multi_early_return, 4, a, b2, Int32(0), Int32(1))
    @test all(Array(b2) .== 0.0f0)

    # Second guard fires: the output stays at its zero initialisation.
    b3 = CUDA.zeros(Float32, 64)
    ct.launch(multi_early_return, 4, a, b3, Int32(1), Int32(0))
    @test all(Array(b3) .== 0.0f0)
end

0 commit comments

Comments
 (0)