@@ -1154,3 +1154,66 @@ end
11541154 end
11551155end
11561156
@testset " early return — taken" begin
    # Kernel under test: loads a 16-element tile from `a`, then bails out
    # before the store whenever `flag` is zero. Otherwise it writes the
    # doubled tile into `b` at the same index.
    function early_return_skip(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, flag::Int32)
        block = ct.bid(1)
        chunk = ct.load(a, block, (16,))
        # Guard clause: nothing must be written when the flag is zero.
        flag == Int32(0) && return nothing
        doubled = chunk .* 2.0f0
        ct.store(b, block, doubled)
        return nothing
    end

    src = CUDA.rand(Float32, 64)
    dst = CUDA.zeros(Float32, 64)

    # flag == 0 selects the early-return path.
    # NOTE(review): the `4` is presumably the block count (4 × 16 = 64 elements) — confirm.
    ct.launch(early_return_skip, 4, src, dst, Int32(0))

    # The store was skipped, so the output must still be all zeros.
    @test all(iszero, Array(dst))
end
1173+
@testset " early return — not taken" begin
    # Kernel under test: loads a 16-element tile from `a`; returns early
    # only when `flag` is zero, otherwise stores the doubled tile into `b`.
    function early_return_store(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, flag::Int32)
        block = ct.bid(1)
        chunk = ct.load(a, block, (16,))
        # Guard clause: skip the store when the flag is zero.
        flag == Int32(0) && return nothing
        doubled = chunk .* 2.0f0
        ct.store(b, block, doubled)
        return nothing
    end

    src = CUDA.rand(Float32, 64)
    dst = CUDA.zeros(Float32, 64)

    # flag == 1 means the early return is NOT taken and the store executes.
    # NOTE(review): the `4` is presumably the block count (4 × 16 = 64 elements) — confirm.
    ct.launch(early_return_store, 4, src, dst, Int32(1))

    # Every output element should be twice the corresponding input element.
    expected = Array(src) .* 2.0f0
    @test Array(dst) ≈ expected
end
1190+
@testset " multiple early returns" begin
    # Kernel under test: two independent guard clauses in sequence. The
    # doubled tile is stored into `b` only when neither flag is zero.
    function multi_early_return(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1},
                                flag1::Int32, flag2::Int32)
        block = ct.bid(1)
        chunk = ct.load(a, block, (16,))
        # First exit point.
        flag1 == Int32(0) && return nothing
        # Second exit point.
        flag2 == Int32(0) && return nothing
        doubled = chunk .* 2.0f0
        ct.store(b, block, doubled)
        return nothing
    end

    src = CUDA.rand(Float32, 64)

    # (flag1, flag2, whether the store is expected to have run)
    cases = (
        (Int32(1), Int32(1), true),   # no early return -> output is 2 .* input
        (Int32(0), Int32(1), false),  # first return taken -> output untouched
        (Int32(1), Int32(0), false),  # second return taken -> output untouched
    )

    for (f1, f2, stored) in cases
        # Fresh zeroed output per case so an earlier launch cannot mask a failure.
        dst = CUDA.zeros(Float32, 64)
        ct.launch(multi_early_return, 4, src, dst, f1, f2)
        if stored
            @test Array(dst) ≈ Array(src) .* 2.0f0
        else
            @test all(iszero, Array(dst))
        end
    end
end
0 commit comments