Skip to content

Commit 44c4823

Browse files
authored
Merge pull request #111 from JuliaGPU/tb/cuda_13.2
Fixes for CUDA 13.2
2 parents 29e2157 + b85ebc6 commit 44c4823

8 files changed

Lines changed: 112 additions & 38 deletions

File tree

.buildkite/pipeline.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@ steps:
66
- JuliaCI/julia#v1:
77
version: "{{matrix.julia}}"
88
- JuliaCI/julia-test#v1:
9-
test_args: "--quickfail"
9+
coverage: false
10+
commands: |
11+
unset LD_LIBRARY_PATH
1012
agents:
1113
queue: "juliagpu"
1214
cuda: "*"
1315
gpu: "a100"
14-
timeout_in_minutes: 90
16+
timeout_in_minutes: 15
1517
matrix:
1618
setup:
1719
julia:
1820
- "1.11"
1921
- "1.12"
20-
- "1.13"

ext/CUDAExt.jl

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module CUDAExt
22

33
using cuTile
4-
using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code, sanitize_name,
5-
constant_eltype, constant_value, is_ghost_type
4+
using cuTile: TileArray, Constant, CGOpts, CuTileResults, DEFAULT_BYTECODE_VERSION,
5+
emit_code, sanitize_name, constant_eltype, constant_value, is_ghost_type
66

77
using CompilerCaching: CacheView, method_instance, results
88

@@ -13,6 +13,16 @@ using CUDA_Compiler_jll
1313

1414
public launch
1515

16+
function run_and_collect(cmd)
17+
stdout = Pipe()
18+
proc = run(pipeline(ignorestatus(cmd); stdout, stderr=stdout), wait=false)
19+
close(stdout.in)
20+
reader = Threads.@spawn String(read(stdout))
21+
Base.wait(proc)
22+
log = strip(fetch(reader))
23+
return proc, log
24+
end
25+
1626
"""
1727
check_tile_ir_support()
1828
@@ -38,6 +48,9 @@ function check_tile_ir_support()
3848
else
3949
error("Tile IR is not supported on compute capability $cap ($sm_arch)")
4050
end
51+
52+
# Return bytecode version matching the toolkit
53+
return VersionNumber(cuda_ver.major, cuda_ver.minor)
4154
end
4255

4356
"""
@@ -58,12 +71,29 @@ function emit_binary(cache::CacheView, mi::Core.MethodInstance;
5871
# Run tileiras to produce CUBIN
5972
input_path = tempname() * ".tile"
6073
output_path = tempname() * ".cubin"
74+
compiled = false
6175
try
6276
write(input_path, bytecode)
63-
run(`$(CUDA_Compiler_jll.tileiras()) $input_path -o $output_path --gpu-name $(opts.sm_arch) -O$(opts.opt_level)`)
77+
cmd = addenv(`$(CUDA_Compiler_jll.tileiras()) $input_path -o $output_path --gpu-name $(opts.sm_arch) -O$(opts.opt_level)`,
78+
"CUDA_ROOT" => CUDA_Compiler_jll.artifact_dir)
79+
proc, log = run_and_collect(cmd)
80+
if !success(proc)
81+
reason = proc.termsignal > 0 ? "tileiras received signal $(proc.termsignal)" :
82+
"tileiras exited with code $(proc.exitcode)"
83+
msg = "Failed to compile Tile IR ($reason)"
84+
if !isempty(log)
85+
msg *= "\n" * log
86+
end
87+
msg *= "\nIf you think this is a bug, please file an issue and attach $(input_path)"
88+
if parse(Bool, get(ENV, "BUILDKITE", "false"))
89+
run(`buildkite-agent artifact upload $(input_path)`)
90+
end
91+
error(msg)
92+
end
93+
compiled = true
6494
res.cuda_bin = read(output_path)
6595
finally
66-
rm(input_path, force=true)
96+
compiled && rm(input_path, force=true)
6797
rm(output_path, force=true)
6898
end
6999

@@ -135,7 +165,7 @@ function cuTile.launch(@nospecialize(f), grid, args...;
135165
opt_level::Int=3,
136166
num_ctas::Union{Int, Nothing}=nothing,
137167
occupancy::Union{Int, Nothing}=nothing)
138-
check_tile_ir_support()
168+
bytecode_version = check_tile_ir_support()
139169

140170
# Convert CuArray -> TileArray (and other conversions)
141171
tile_args = map(to_tile_arg, args)
@@ -166,7 +196,8 @@ function cuTile.launch(@nospecialize(f), grid, args...;
166196
end
167197

168198
# Create cache view with compilation options as sharding keys
169-
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy)
199+
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy,
200+
bytecode_version=bytecode_version)
170201
cache = CacheView{CuTileResults}((:cuTile, opts), world)
171202

172203
# Run cached compilation

src/bytecode/encodings.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,9 +1122,14 @@ Example:
11221122
function encode_ForOp!(body::Function, cb::CodeBuilder,
11231123
result_types::Vector{TypeId}, iv_type::TypeId,
11241124
lower::Value, upper::Value, step::Value,
1125-
init_values::Vector{Value})
1125+
init_values::Vector{Value};
1126+
unsigned_cmp::Bool=false)
11261127
encode_varint!(cb.buf, Opcode.ForOp)
11271128
encode_typeid_seq!(cb.buf, result_types)
1129+
# Flags
1130+
if cb.version >= v"13.2"
1131+
encode_varint!(cb.buf, unsigned_cmp ? 1 : 0)
1132+
end
11281133
# Operands: lower, upper, step, init_values...
11291134
encode_varint!(cb.buf, 3 + length(init_values))
11301135
encode_operand!(cb.buf, lower)
@@ -1558,7 +1563,9 @@ function encode_NegIOp!(cb::CodeBuilder, result_type::TypeId, source::Value;
15581563
overflow::IntegerOverflow=OverflowNone)
15591564
encode_varint!(cb.buf, Opcode.NegIOp)
15601565
encode_typeid!(cb.buf, result_type)
1561-
encode_enum!(cb.buf, overflow)
1566+
if cb.version >= v"13.2"
1567+
encode_enum!(cb.buf, overflow)
1568+
end
15621569
encode_operand!(cb.buf, source)
15631570
return new_op!(cb)
15641571
end
@@ -1956,9 +1963,13 @@ end
19561963
Element-wise hyperbolic tangent.
19571964
Opcode: 106
19581965
"""
1959-
function encode_TanHOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
1966+
function encode_TanHOp!(cb::CodeBuilder, result_type::TypeId, source::Value;
1967+
rounding_mode::RoundingMode=RoundingFull)
19601968
encode_varint!(cb.buf, Opcode.TanHOp)
19611969
encode_typeid!(cb.buf, result_type)
1970+
if cb.version >= v"13.2"
1971+
encode_enum!(cb.buf, rounding_mode)
1972+
end
19621973
encode_operand!(cb.buf, source)
19631974
return new_op!(cb)
19641975
end

src/bytecode/writer.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Bytecode file writer - handles sections and overall structure
22

33
# Bytecode version
4-
const BYTECODE_VERSION = (13, 1, 0)
4+
const DEFAULT_BYTECODE_VERSION = v"13.1"
55

66
# Magic number
77
const MAGIC = UInt8[0x7f, 0x54, 0x69, 0x6c, 0x65, 0x49, 0x52, 0x00] # "\x7fTileIR\x00"
@@ -97,9 +97,11 @@ mutable struct CodeBuilder
9797
next_value_id::Int
9898
cur_debug_attr::DebugAttrId
9999
num_ops::Int
100+
version::VersionNumber
100101
end
101102

102-
function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, type_table::TypeTable)
103+
function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, type_table::TypeTable;
104+
version::VersionNumber=DEFAULT_BYTECODE_VERSION)
103105
CodeBuilder(
104106
UInt8[],
105107
string_table,
@@ -108,7 +110,8 @@ function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, t
108110
DebugAttrId[],
109111
0,
110112
DebugAttrId(0), # No debug info
111-
0
113+
0,
114+
version
112115
)
113116
end
114117

@@ -374,9 +377,10 @@ mutable struct BytecodeWriter
374377
debug_attr_table::DebugAttrTable
375378
debug_info::Vector{Vector{DebugAttrId}}
376379
num_functions::Int
380+
version::VersionNumber
377381
end
378382

379-
function BytecodeWriter()
383+
function BytecodeWriter(; version::VersionNumber=DEFAULT_BYTECODE_VERSION)
380384
string_table = StringTable()
381385
BytecodeWriter(
382386
UInt8[],
@@ -385,21 +389,21 @@ function BytecodeWriter()
385389
TypeTable(),
386390
DebugAttrTable(string_table),
387391
Vector{Vector{DebugAttrId}}[],
388-
0
392+
0,
393+
version
389394
)
390395
end
391396

392397
"""
393398
Write the bytecode header.
394399
"""
395-
function write_header!(buf::Vector{UInt8})
400+
function write_header!(buf::Vector{UInt8}, version::VersionNumber)
396401
append!(buf, MAGIC)
397-
major, minor, tag = BYTECODE_VERSION
398-
push!(buf, UInt8(major))
399-
push!(buf, UInt8(minor))
400-
# Tag as 2-byte little-endian
401-
push!(buf, UInt8(tag & 0xff))
402-
push!(buf, UInt8((tag >> 8) & 0xff))
402+
push!(buf, UInt8(version.major))
403+
push!(buf, UInt8(version.minor))
404+
# Patch as 2-byte little-endian
405+
push!(buf, UInt8(version.patch & 0xff))
406+
push!(buf, UInt8((version.patch >> 8) & 0xff))
403407
end
404408

405409
"""
@@ -486,8 +490,9 @@ end
486490
Write complete bytecode to a buffer.
487491
Returns the buffer with all sections.
488492
"""
489-
function write_bytecode!(f::Function, num_functions::Int)
490-
writer = BytecodeWriter()
493+
function write_bytecode!(f::Function, num_functions::Int;
494+
version::VersionNumber=DEFAULT_BYTECODE_VERSION)
495+
writer = BytecodeWriter(; version)
491496

492497
# Function section content
493498
func_buf = UInt8[]
@@ -502,7 +507,7 @@ function write_bytecode!(f::Function, num_functions::Int)
502507

503508
# Build final output
504509
buf = UInt8[]
505-
write_header!(buf)
510+
write_header!(buf, version)
506511

507512
# Sections in order: Func, Global (if any), Constant, Debug, Type, String, End
508513
write_section!(buf, Section.Func, func_buf, 8)
@@ -574,7 +579,8 @@ function add_function!(writer::BytecodeWriter, func_buf::Vector{UInt8},
574579
end
575580

576581
# Create code builder for function body
577-
cb = CodeBuilder(writer.string_table, writer.constant_table, writer.type_table)
582+
cb = CodeBuilder(writer.string_table, writer.constant_table, writer.type_table;
583+
version=writer.version)
578584

579585
return cb
580586
end

src/compiler/interface.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ const CGOpts = @NamedTuple{
291291
sm_arch::Union{String, Nothing},
292292
opt_level::Int,
293293
num_ctas::Union{Int, Nothing},
294-
occupancy::Union{Int, Nothing}
294+
occupancy::Union{Int, Nothing},
295+
bytecode_version::VersionNumber
295296
}
296297

297298
# Results struct for caching compilation phases
@@ -394,7 +395,7 @@ function emit_code(cache::CacheView, mi::Core.MethodInstance;
394395
opts = cache.owner[2]
395396

396397
# Generate Tile IR bytecode
397-
bytecode = write_bytecode!(1) do writer, func_buf
398+
bytecode = write_bytecode!(1; version=opts.bytecode_version) do writer, func_buf
398399
emit_kernel!(writer, func_buf, sci, rettype;
399400
name = sanitize_name(string(mi.def.name)),
400401
sm_arch = opts.sm_arch,
@@ -508,6 +509,7 @@ function code_tiled(io::IO, @nospecialize(f), @nospecialize(argtypes);
508509
opt_level::Int=3,
509510
num_ctas::Union{Int, Nothing}=nothing,
510511
occupancy::Union{Int, Nothing}=nothing,
512+
bytecode_version::VersionNumber=DEFAULT_BYTECODE_VERSION,
511513
world::UInt=Base.get_world_counter())
512514
# Strip Constant types from argtypes for MI lookup, build const_argtypes
513515
stripped, const_argtypes = process_const_argtypes(f, argtypes)
@@ -518,7 +520,8 @@ function code_tiled(io::IO, @nospecialize(f), @nospecialize(argtypes);
518520
mi = @something(method_instance(f, stripped; world, method_table=cuTileMethodTable),
519521
method_instance(f, stripped; world),
520522
throw(MethodError(f, stripped)))
521-
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy)
523+
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy,
524+
bytecode_version=bytecode_version)
522525
cache = CacheView{CuTileResults}((:cuTile, opts), world)
523526
bytecode = emit_code(cache, mi; const_argtypes)
524527
print(io, disassemble_tileir(bytecode))

src/language/atomics.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ end
107107
S === () ? Intrinsics.to_scalar(result) : result
108108
end
109109

110+
# Convert mismatched scalar/tile types to match array element type
111+
@inline function atomic_cas(array::TileArray{T}, indices,
112+
expected::TileOrScalar, desired::TileOrScalar;
113+
memory_order::Int=MemoryOrder.AcqRel,
114+
memory_scope::Int=MemScope.Device) where {T}
115+
atomic_cas(array, indices, T(expected), T(desired); memory_order, memory_scope)
116+
end
117+
110118
# ============================================================================
111119
# Atomic RMW operations (atomic_add, atomic_xchg)
112120
# ============================================================================
@@ -150,4 +158,11 @@ for op in (:add, :xchg)
150158
result = Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
151159
S === () ? Intrinsics.to_scalar(result) : result
152160
end
161+
162+
# Convert mismatched scalar/tile types to match array element type
163+
@eval @inline function $fname(array::TileArray{T}, indices, val::TileOrScalar;
164+
memory_order::Int=MemoryOrder.AcqRel,
165+
memory_scope::Int=MemScope.Device) where {T}
166+
$fname(array, indices, T(val); memory_order, memory_scope)
167+
end
153168
end

test/execution/atomics.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ end
4343
# Test atomic_xchg: each thread exchanges, last one wins
4444
function atomic_xchg_kernel(arr::ct.TileArray{Int,1})
4545
bid = ct.bid(1)
46-
ct.atomic_xchg(arr, 1, bid + 1;
46+
# bid is 1-indexed (1..n_blocks), val is auto-converted from Int32 to Int
47+
ct.atomic_xchg(arr, 1, bid;
4748
memory_order=ct.MemoryOrder.AcqRel)
4849
return
4950
end

test/execution/hints.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@ using CUDA
1818
b = CUDA.ones(Float32, n) .* 2
1919
c = CUDA.zeros(Float32, n)
2020

21-
ct.launch(vadd_kernel_num_ctas, 64, a, b, c; num_ctas=2)
22-
23-
@test Array(c) ≈ ones(Float32, n) .* 3
21+
if capability(device()) >= v"10"
22+
ct.launch(vadd_kernel_num_ctas, 64, a, b, c; num_ctas=2)
23+
@test Array(c) ≈ ones(Float32, n) .* 3
24+
else
25+
@test_throws "num_cta_in_cga" ct.launch(vadd_kernel_num_ctas, 64, a, b, c; num_ctas=2)
26+
end
2427
end
2528

2629
@testset "launch with occupancy" begin
@@ -60,9 +63,12 @@ end
6063
b = CUDA.ones(Float32, n) .* 2
6164
c = CUDA.zeros(Float32, n)
6265

63-
ct.launch(vadd_kernel_both_hints, 64, a, b, c; num_ctas=4, occupancy=8)
64-
65-
@test Array(c) ≈ ones(Float32, n) .* 3
66+
if capability(device()) >= v"10"
67+
ct.launch(vadd_kernel_both_hints, 64, a, b, c; num_ctas=4, occupancy=8)
68+
@test Array(c) ≈ ones(Float32, n) .* 3
69+
else
70+
@test_throws "num_cta_in_cga" ct.launch(vadd_kernel_both_hints, 64, a, b, c; num_ctas=4, occupancy=8)
71+
end
6672
end
6773

6874
end

0 commit comments

Comments
 (0)