Skip to content

Commit e033d5e

Browse files
authored
Expose load/store optimization hints (#32)
1 parent be82b9b commit e033d5e

10 files changed

Lines changed: 765 additions & 88 deletions

File tree

src/bytecode/encodings.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ function encode_LoadViewTkoOp!(cb::CodeBuilder,
423423
token::Union{Value, Nothing}=nothing,
424424
memory_ordering::MemoryOrderingSemantics=MemoryWeak,
425425
memory_scope::Union{MemoryScope, Nothing}=nothing,
426-
optimization_hints::Union{Vector{UInt8}, Nothing}=nothing)
426+
optimization_hints::Union{OptimizationHints, Nothing}=nothing)
427427
encode_varint!(cb.buf, Opcode.LoadViewTkoOp)
428428
# Variadic result types
429429
encode_typeid_seq!(cb.buf, [tile_type, token_type])
@@ -447,7 +447,7 @@ function encode_LoadViewTkoOp!(cb::CodeBuilder,
447447
encode_enum!(cb.buf, memory_scope)
448448
end
449449
if optimization_hints !== nothing
450-
append!(cb.buf, optimization_hints)
450+
encode_opattr_optimization_hints!(cb, optimization_hints)
451451
end
452452

453453
# Operands
@@ -472,7 +472,7 @@ function encode_StoreViewTkoOp!(cb::CodeBuilder,
472472
token::Union{Value, Nothing}=nothing,
473473
memory_ordering::MemoryOrderingSemantics=MemoryWeak,
474474
memory_scope::Union{MemoryScope, Nothing}=nothing,
475-
optimization_hints::Union{Vector{UInt8}, Nothing}=nothing)
475+
optimization_hints::Union{OptimizationHints, Nothing}=nothing)
476476
encode_varint!(cb.buf, Opcode.StoreViewTkoOp)
477477
# Variadic result types (just token)
478478
encode_typeid_seq!(cb.buf, [token_type])
@@ -496,7 +496,7 @@ function encode_StoreViewTkoOp!(cb::CodeBuilder,
496496
encode_enum!(cb.buf, memory_scope)
497497
end
498498
if optimization_hints !== nothing
499-
append!(cb.buf, optimization_hints)
499+
encode_opattr_optimization_hints!(cb, optimization_hints)
500500
end
501501

502502
# Operands
@@ -541,7 +541,7 @@ function encode_LoadPtrTkoOp!(cb::CodeBuilder,
541541
token::Union{Value, Nothing}=nothing,
542542
memory_ordering::MemoryOrderingSemantics=MemoryWeak,
543543
memory_scope::Union{MemoryScope, Nothing}=nothing,
544-
optimization_hints::Union{Vector{UInt8}, Nothing}=nothing)
544+
optimization_hints::Union{OptimizationHints, Nothing}=nothing)
545545
encode_varint!(cb.buf, Opcode.LoadPtrTkoOp)
546546
# Result types
547547
encode_typeid!(cb.buf, result_type)
@@ -572,7 +572,7 @@ function encode_LoadPtrTkoOp!(cb::CodeBuilder,
572572
encode_enum!(cb.buf, memory_scope)
573573
end
574574
if optimization_hints !== nothing
575-
append!(cb.buf, optimization_hints)
575+
encode_opattr_optimization_hints!(cb, optimization_hints)
576576
end
577577

578578
# Operands
@@ -600,7 +600,7 @@ function encode_StorePtrTkoOp!(cb::CodeBuilder,
600600
token::Union{Value, Nothing}=nothing,
601601
memory_ordering::MemoryOrderingSemantics=MemoryWeak,
602602
memory_scope::Union{MemoryScope, Nothing}=nothing,
603-
optimization_hints::Union{Vector{UInt8}, Nothing}=nothing)
603+
optimization_hints::Union{OptimizationHints, Nothing}=nothing)
604604
encode_varint!(cb.buf, Opcode.StorePtrTkoOp)
605605
# Result type (token)
606606
encode_typeid!(cb.buf, token_type)
@@ -627,7 +627,7 @@ function encode_StorePtrTkoOp!(cb::CodeBuilder,
627627
encode_enum!(cb.buf, memory_scope)
628628
end
629629
if optimization_hints !== nothing
630-
append!(cb.buf, optimization_hints)
630+
encode_opattr_optimization_hints!(cb, optimization_hints)
631631
end
632632

633633
# Operands

src/bytecode/writer.jl

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -544,9 +544,75 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
544544
end
545545

546546
#=============================================================================
547-
EntryHints: Kernel-level compilation hints
547+
Optimization Hints
548548
=============================================================================#
549549

550+
"""
551+
encode_tagged_value!(cb, value)
552+
553+
Encode a value with its type tag.
554+
"""
555+
function encode_tagged_value!(buf::Vector{UInt8}, type_table::TypeTable, value::Bool)
556+
push!(buf, AttributeTag.Bool)
557+
push!(buf, value)
558+
end
559+
560+
function encode_tagged_value!(buf::Vector{UInt8}, type_table::TypeTable, value::Integer)
561+
push!(buf, AttributeTag.Integer)
562+
encode_typeid!(buf, I32(type_table))
563+
encode_varint!(buf, UInt32(value))
564+
end
565+
566+
"""
567+
Optimization hints for load/store operations.
568+
- `latency`: Optional latency hint (1-10), or nothing for default
569+
- `allow_tma`: Whether TMA (Tensor Memory Accelerator) is allowed (default: true)
570+
"""
571+
@kwdef struct LoadStoreHints
572+
latency::Union{Int, Nothing} = nothing
573+
allow_tma::Bool = true
574+
end
575+
576+
"""
577+
Optimization hints for load/store operations.
578+
- `hints_by_arch`: List of (SM architecture, load/store hints) pairs
579+
"""
580+
struct OptimizationHints
581+
hints_by_arch::Vector{Tuple{String, LoadStoreHints}}
582+
end
583+
584+
function make_load_store_hints(sm_arch::Union{String, Nothing}, hints::LoadStoreHints)
585+
isnothing(sm_arch) && throw(ArgumentError("sm_arch must be explicitly passed when load/store hints are present"))
586+
OptimizationHints([(sm_arch, hints)])
587+
end
588+
589+
function encode_opattr_optimization_hints!(cb::CodeBuilder, hints::OptimizationHints)
590+
# Outer dictionary: arch -> hints_dict
591+
encode_varint!(cb.buf, length(hints.hints_by_arch))
592+
for (arch, load_store_hints) in hints.hints_by_arch
593+
arch_id = cb.string_table[arch]
594+
encode_varint!(cb.buf, arch_id.id)
595+
# Encode hints as inner dictionary (tagged)
596+
encode_load_store_hints_dict!(cb, load_store_hints)
597+
end
598+
end
599+
600+
function encode_load_store_hints_dict!(cb::CodeBuilder, hints::LoadStoreHints)
601+
# Build list of (key, value) pairs for non-default hints
602+
items = Tuple{String, Any}[]
603+
hints.allow_tma || push!(items, ("allow_tma", false))
604+
isnothing(hints.latency) || push!(items, ("latency", hints.latency))
605+
606+
# Encode dictionary
607+
push!(cb.buf, AttributeTag.Dictionary)
608+
encode_varint!(cb.buf, length(items))
609+
for (key, value) in items
610+
key_id = cb.string_table[key]
611+
encode_varint!(cb.buf, key_id.id)
612+
encode_tagged_value!(cb.buf, cb.type_table, value)
613+
end
614+
end
615+
550616
"""
551617
Kernel-level compilation hints (num_ctas, occupancy).
552618
Encoded as a dictionary attribute in bytecode.
@@ -567,10 +633,6 @@ function validate_occupancy(occupancy::Union{Int, Nothing})
567633
1 <= occupancy <= 32 || throw(ArgumentError("occupancy must be between 1 and 32, got $occupancy"))
568634
end
569635

570-
"""
571-
Encode EntryHints as OptimizationHints format.
572-
Returns raw bytes for entry_hints parameter or nothing.
573-
"""
574636
function encode_entry_hints(writer::BytecodeWriter, sm_arch::Union{String, Nothing}, hints::EntryHints)
575637
validate_num_ctas(hints.num_ctas)
576638
validate_occupancy(hints.occupancy)
@@ -603,9 +665,7 @@ function encode_entry_hints(writer::BytecodeWriter, sm_arch::Union{String, Nothi
603665
for (key, value) in items
604666
key_id = writer.string_table[key]
605667
encode_varint!(buf, key_id.id)
606-
push!(buf, AttributeTag.Integer)
607-
encode_typeid!(buf, I32(writer.type_table))
608-
encode_varint!(buf, UInt32(value))
668+
encode_tagged_value!(buf, writer.type_table, value)
609669
end
610670

611671
return buf

src/compiler/codegen/kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
1212
is_entry::Bool = true,
1313
num_ctas::Union{Int, Nothing} = nothing,
1414
occupancy::Union{Int, Nothing} = nothing)
15-
ctx = CGCtx(writer, target)
15+
ctx = CGCtx(writer, target, sm_arch)
1616
tt = ctx.tt
1717

1818
# Validate non-ghost argument types are concrete

src/compiler/intrinsics.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ end
2121

2222
emit_intrinsic!(ctx::CGCtx, @nospecialize(func), args) = missing
2323

24+
# Shared helper for creating load/store optimization hints
25+
function create_optimization_hints(ctx::CGCtx, latency::Union{Int, Nothing}, allow_tma::Bool=true)
26+
isnothing(latency) && allow_tma && return nothing
27+
isnothing(latency) || 1 <= latency <= 10 || error("latency must be between 1 and 10, got $latency")
28+
hints = LoadStoreHints(; latency, allow_tma)
29+
return make_load_store_hints(ctx.sm_arch, hints)
30+
end
31+
2432
include("intrinsics/core.jl")
2533
include("intrinsics/conversions.jl")
2634
include("intrinsics/arithmetic.jl")

src/compiler/intrinsics/memory.jl

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,28 @@
55
# cuda_tile.load_ptr_tko
66
@eval Intrinsics begin
77
"""
8-
load_ptr_tko(ptrs, mask=nothing, padding=nothing)
8+
load_ptr_tko(ptrs, latency, mask=nothing, padding=nothing)
99
1010
Load values from a tile of pointers.
1111
If mask is provided, masked-out positions return the padding value.
1212
Compiled to cuda_tile.load_ptr_tko.
13+
14+
Note: TMA (allow_tma) is not applicable for pointer-based loads as they
15+
support irregular access patterns incompatible with TMA requirements.
1316
"""
1417
@noinline function load_ptr_tko(ptrs::Tile{Ptr{T}, S},
18+
latency::Union{Int, Nothing}=nothing,
1519
mask::Union{Tile{Bool, S}, Nothing}=nothing,
1620
padding::Union{Tile{T, S}, Nothing}=nothing) where {T, S}
17-
donotdelete(ptrs, mask, padding)
21+
donotdelete(ptrs, latency, mask, padding)
1822
Tile{T, S}()
1923
end
2024
end
2125
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args)
2226
cb = ctx.cb
2327
tt = ctx.tt
2428

29+
# args: (ptrs, latency, mask?, padding?)
2530
# Get pointer tile (arg 1)
2631
ptrs_tv = emit_value!(ctx, args[1])
2732
ptrs_tv === nothing && error("load_ptr_tko: cannot resolve pointer tile")
@@ -36,29 +41,37 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args)
3641
result_tile_type = tile_type!(tt, dtype, tile_shape)
3742
token_type = Token(tt)
3843

39-
# Check if mask is provided (arg 2 is not nothing)
40-
has_mask = length(args) >= 2 && get_constant(ctx, args[2]) !== nothing
44+
# Extract latency hint (args[2])
45+
latency = get_constant(ctx, args[2])
46+
47+
# Create optimization hints if provided
48+
optimization_hints = create_optimization_hints(ctx, latency)
49+
50+
# Check if mask is provided (arg 3 is not nothing)
51+
has_mask = length(args) >= 3 && get_constant(ctx, args[3]) !== nothing
4152

4253
if has_mask
43-
# Get mask tile (arg 2)
44-
mask_tv = emit_value!(ctx, args[2])
54+
# Get mask tile (arg 3)
55+
mask_tv = emit_value!(ctx, args[3])
4556
mask_tv === nothing && error("load_ptr_tko: cannot resolve mask tile")
4657
mask = mask_tv.v
4758

48-
# Get padding tile (arg 3)
49-
padding_tv = emit_value!(ctx, args[3])
59+
# Get padding tile (arg 4)
60+
padding_tv = emit_value!(ctx, args[4])
5061
padding_tv === nothing && error("load_ptr_tko: cannot resolve padding tile")
5162
padding = padding_tv.v
5263

5364
# Load with mask and padding
5465
tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers;
5566
mask=mask,
5667
padding_value=padding,
57-
token=ctx.token)
68+
token=ctx.token,
69+
optimization_hints)
5870
else
5971
# Load without mask
6072
tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers;
61-
token=ctx.token)
73+
token=ctx.token,
74+
optimization_hints)
6275
end
6376
ctx.token = new_token
6477

@@ -71,22 +84,27 @@ end
7184
# cuda_tile.store_ptr_tko
7285
@eval Intrinsics begin
7386
"""
74-
store_ptr_tko(ptrs, values, mask=nothing)
87+
store_ptr_tko(ptrs, values, latency, mask=nothing)
7588
7689
Store values to a tile of pointers.
7790
If mask is provided, masked-out positions are not written.
7891
Compiled to cuda_tile.store_ptr_tko.
92+
93+
Note: TMA (allow_tma) is not applicable for pointer-based stores as they
94+
support irregular access patterns incompatible with TMA requirements.
7995
"""
8096
@noinline function store_ptr_tko(ptrs::Tile{Ptr{T}, S}, values::Tile{T, S},
97+
latency::Union{Int, Nothing},
8198
mask::Union{Tile{Bool, S}, Nothing}=nothing) where {T, S}
82-
donotdelete(ptrs, values, mask)
99+
donotdelete(ptrs, values, latency, mask)
83100
nothing
84101
end
85102
end
86103
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args)
87104
cb = ctx.cb
88105
tt = ctx.tt
89106

107+
# args: (ptrs, values, latency, mask?)
90108
# Get pointer tile (arg 1)
91109
ptrs_tv = emit_value!(ctx, args[1])
92110
ptrs_tv === nothing && error("store_ptr_tko: cannot resolve pointer tile")
@@ -99,23 +117,31 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args)
99117

100118
token_type = Token(tt)
101119

102-
# Check if mask is provided (arg 3 is not nothing)
103-
has_mask = length(args) >= 3 && get_constant(ctx, args[3]) !== nothing
120+
# Extract latency hint (args[3])
121+
latency = get_constant(ctx, args[3])
122+
123+
# Create optimization hints if provided
124+
optimization_hints = create_optimization_hints(ctx, latency)
125+
126+
# Check if mask is provided (arg 4 is not nothing)
127+
has_mask = length(args) >= 4 && get_constant(ctx, args[4]) !== nothing
104128

105129
if has_mask
106-
# Get mask tile (arg 3)
107-
mask_tv = emit_value!(ctx, args[3])
130+
# Get mask tile (arg 4)
131+
mask_tv = emit_value!(ctx, args[4])
108132
mask_tv === nothing && error("store_ptr_tko: cannot resolve mask tile")
109133
mask = mask_tv.v
110134

111135
# Store with mask
112136
new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values;
113137
mask=mask,
114-
token=ctx.token)
138+
token=ctx.token,
139+
optimization_hints)
115140
else
116141
# Store without mask
117142
new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values;
118-
token=ctx.token)
143+
token=ctx.token,
144+
optimization_hints)
119145
end
120146
ctx.token = new_token
121147

0 commit comments

Comments
 (0)