Skip to content

Commit 6aeec87

Browse files
committed
Update IR codegen for 13.2.
1 parent dee5390 commit 6aeec87

4 files changed

Lines changed: 50 additions & 26 deletions

File tree

ext/CUDAExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module CUDAExt
22

33
using cuTile
4-
using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code, sanitize_name,
5-
constant_eltype, constant_value, is_ghost_type
4+
using cuTile: TileArray, Constant, CGOpts, CuTileResults, DEFAULT_BYTECODE_VERSION,
5+
emit_code, sanitize_name, constant_eltype, constant_value, is_ghost_type
66

77
using CompilerCaching: CacheView, method_instance, results
88

@@ -38,6 +38,9 @@ function check_tile_ir_support()
3838
else
3939
error("Tile IR is not supported on compute capability $cap ($sm_arch)")
4040
end
41+
42+
# Return bytecode version matching the toolkit
43+
return VersionNumber(cuda_ver.major, cuda_ver.minor)
4144
end
4245

4346
"""
@@ -137,7 +140,7 @@ function cuTile.launch(@nospecialize(f), grid, args...;
137140
opt_level::Int=3,
138141
num_ctas::Union{Int, Nothing}=nothing,
139142
occupancy::Union{Int, Nothing}=nothing)
140-
check_tile_ir_support()
143+
bytecode_version = check_tile_ir_support()
141144

142145
# Convert CuArray -> TileArray (and other conversions)
143146
tile_args = map(to_tile_arg, args)
@@ -168,7 +171,8 @@ function cuTile.launch(@nospecialize(f), grid, args...;
168171
end
169172

170173
# Create cache view with compilation options as sharding keys
171-
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy)
174+
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy,
175+
bytecode_version=bytecode_version)
172176
cache = CacheView{CuTileResults}((:cuTile, opts), world)
173177

174178
# Run cached compilation

src/bytecode/encodings.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,9 +1122,14 @@ Example:
11221122
function encode_ForOp!(body::Function, cb::CodeBuilder,
11231123
result_types::Vector{TypeId}, iv_type::TypeId,
11241124
lower::Value, upper::Value, step::Value,
1125-
init_values::Vector{Value})
1125+
init_values::Vector{Value};
1126+
unsigned_cmp::Bool=false)
11261127
encode_varint!(cb.buf, Opcode.ForOp)
11271128
encode_typeid_seq!(cb.buf, result_types)
1129+
# Flags
1130+
if cb.version >= v"13.2"
1131+
encode_varint!(cb.buf, unsigned_cmp ? 1 : 0)
1132+
end
11281133
# Operands: lower, upper, step, init_values...
11291134
encode_varint!(cb.buf, 3 + length(init_values))
11301135
encode_operand!(cb.buf, lower)
@@ -1558,7 +1563,9 @@ function encode_NegIOp!(cb::CodeBuilder, result_type::TypeId, source::Value;
15581563
overflow::IntegerOverflow=OverflowNone)
15591564
encode_varint!(cb.buf, Opcode.NegIOp)
15601565
encode_typeid!(cb.buf, result_type)
1561-
encode_enum!(cb.buf, overflow)
1566+
if cb.version >= v"13.2"
1567+
encode_enum!(cb.buf, overflow)
1568+
end
15621569
encode_operand!(cb.buf, source)
15631570
return new_op!(cb)
15641571
end
@@ -1956,9 +1963,13 @@ end
19561963
Element-wise hyperbolic tangent.
19571964
Opcode: 106
19581965
"""
1959-
function encode_TanHOp!(cb::CodeBuilder, result_type::TypeId, source::Value)
1966+
function encode_TanHOp!(cb::CodeBuilder, result_type::TypeId, source::Value;
1967+
rounding_mode::RoundingMode=RoundingFull)
19601968
encode_varint!(cb.buf, Opcode.TanHOp)
19611969
encode_typeid!(cb.buf, result_type)
1970+
if cb.version >= v"13.2"
1971+
encode_enum!(cb.buf, rounding_mode)
1972+
end
19621973
encode_operand!(cb.buf, source)
19631974
return new_op!(cb)
19641975
end

src/bytecode/writer.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Bytecode file writer - handles sections and overall structure
22

33
# Bytecode version
4-
const BYTECODE_VERSION = (13, 1, 0)
4+
const DEFAULT_BYTECODE_VERSION = v"13.1"
55

66
# Magic number
77
const MAGIC = UInt8[0x7f, 0x54, 0x69, 0x6c, 0x65, 0x49, 0x52, 0x00] # "\x7fTileIR\x00"
@@ -97,9 +97,11 @@ mutable struct CodeBuilder
9797
next_value_id::Int
9898
cur_debug_attr::DebugAttrId
9999
num_ops::Int
100+
version::VersionNumber
100101
end
101102

102-
function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, type_table::TypeTable)
103+
function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, type_table::TypeTable;
104+
version::VersionNumber=DEFAULT_BYTECODE_VERSION)
103105
CodeBuilder(
104106
UInt8[],
105107
string_table,
@@ -108,7 +110,8 @@ function CodeBuilder(string_table::StringTable, constant_table::ConstantTable, t
108110
DebugAttrId[],
109111
0,
110112
DebugAttrId(0), # No debug info
111-
0
113+
0,
114+
version
112115
)
113116
end
114117

@@ -374,9 +377,10 @@ mutable struct BytecodeWriter
374377
debug_attr_table::DebugAttrTable
375378
debug_info::Vector{Vector{DebugAttrId}}
376379
num_functions::Int
380+
version::VersionNumber
377381
end
378382

379-
function BytecodeWriter()
383+
function BytecodeWriter(; version::VersionNumber=DEFAULT_BYTECODE_VERSION)
380384
string_table = StringTable()
381385
BytecodeWriter(
382386
UInt8[],
@@ -385,21 +389,21 @@ function BytecodeWriter()
385389
TypeTable(),
386390
DebugAttrTable(string_table),
387391
Vector{Vector{DebugAttrId}}[],
388-
0
392+
0,
393+
version
389394
)
390395
end
391396

392397
"""
393398
Write the bytecode header.
394399
"""
395-
function write_header!(buf::Vector{UInt8})
400+
function write_header!(buf::Vector{UInt8}, version::VersionNumber)
396401
append!(buf, MAGIC)
397-
major, minor, tag = BYTECODE_VERSION
398-
push!(buf, UInt8(major))
399-
push!(buf, UInt8(minor))
400-
# Tag as 2-byte little-endian
401-
push!(buf, UInt8(tag & 0xff))
402-
push!(buf, UInt8((tag >> 8) & 0xff))
402+
push!(buf, UInt8(version.major))
403+
push!(buf, UInt8(version.minor))
404+
# Patch as 2-byte little-endian
405+
push!(buf, UInt8(version.patch & 0xff))
406+
push!(buf, UInt8((version.patch >> 8) & 0xff))
403407
end
404408

405409
"""
@@ -486,8 +490,9 @@ end
486490
Write complete bytecode to a buffer.
487491
Returns the buffer with all sections.
488492
"""
489-
function write_bytecode!(f::Function, num_functions::Int)
490-
writer = BytecodeWriter()
493+
function write_bytecode!(f::Function, num_functions::Int;
494+
version::VersionNumber=DEFAULT_BYTECODE_VERSION)
495+
writer = BytecodeWriter(; version)
491496

492497
# Function section content
493498
func_buf = UInt8[]
@@ -502,7 +507,7 @@ function write_bytecode!(f::Function, num_functions::Int)
502507

503508
# Build final output
504509
buf = UInt8[]
505-
write_header!(buf)
510+
write_header!(buf, version)
506511

507512
# Sections in order: Func, Global (if any), Constant, Debug, Type, String, End
508513
write_section!(buf, Section.Func, func_buf, 8)
@@ -574,7 +579,8 @@ function add_function!(writer::BytecodeWriter, func_buf::Vector{UInt8},
574579
end
575580

576581
# Create code builder for function body
577-
cb = CodeBuilder(writer.string_table, writer.constant_table, writer.type_table)
582+
cb = CodeBuilder(writer.string_table, writer.constant_table, writer.type_table;
583+
version=writer.version)
578584

579585
return cb
580586
end

src/compiler/interface.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ const CGOpts = @NamedTuple{
291291
sm_arch::Union{String, Nothing},
292292
opt_level::Int,
293293
num_ctas::Union{Int, Nothing},
294-
occupancy::Union{Int, Nothing}
294+
occupancy::Union{Int, Nothing},
295+
bytecode_version::VersionNumber
295296
}
296297

297298
# Results struct for caching compilation phases
@@ -394,7 +395,7 @@ function emit_code(cache::CacheView, mi::Core.MethodInstance;
394395
opts = cache.owner[2]
395396

396397
# Generate Tile IR bytecode
397-
bytecode = write_bytecode!(1) do writer, func_buf
398+
bytecode = write_bytecode!(1; version=opts.bytecode_version) do writer, func_buf
398399
emit_kernel!(writer, func_buf, sci, rettype;
399400
name = sanitize_name(string(mi.def.name)),
400401
sm_arch = opts.sm_arch,
@@ -508,6 +509,7 @@ function code_tiled(io::IO, @nospecialize(f), @nospecialize(argtypes);
508509
opt_level::Int=3,
509510
num_ctas::Union{Int, Nothing}=nothing,
510511
occupancy::Union{Int, Nothing}=nothing,
512+
bytecode_version::VersionNumber=DEFAULT_BYTECODE_VERSION,
511513
world::UInt=Base.get_world_counter())
512514
# Strip Constant types from argtypes for MI lookup, build const_argtypes
513515
stripped, const_argtypes = process_const_argtypes(f, argtypes)
@@ -518,7 +520,8 @@ function code_tiled(io::IO, @nospecialize(f), @nospecialize(argtypes);
518520
mi = @something(method_instance(f, stripped; world, method_table=cuTileMethodTable),
519521
method_instance(f, stripped; world),
520522
throw(MethodError(f, stripped)))
521-
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy)
523+
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy,
524+
bytecode_version=bytecode_version)
522525
cache = CacheView{CuTileResults}((:cuTile, opts), world)
523526
bytecode = emit_code(cache, mi; const_argtypes)
524527
print(io, disassemble_tileir(bytecode))

0 commit comments

Comments
 (0)