Skip to content

Commit 0353a1a

Browse files
maleadt and claude committed
Unify scalar and tile-indexed atomic intrinsics
Replace 6 intrinsics (3 scalar + 3 tile) with 3 unified ones that take (ptr_tile, val, mask, ...) matching Python cuTile's design.

Both paths now compute pointers via Intrinsics.offset in the language layer, with mask=nothing for scalar indices (no mask in bytecode) and Tile{Bool} for tile-indexed (bounds mask passed through).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 732bf63 commit 0353a1a

3 files changed

Lines changed: 199 additions & 286 deletions

File tree

src/compiler/intrinsics/atomics.jl

Lines changed: 103 additions & 204 deletions
Original file line number | Diff line number | Diff line change
@@ -30,251 +30,135 @@ function memory_scope_to_scope(scope::Int)
3030
end
3131
end
3232

33-
# cuda_tile.atomic_cas_tko
34-
@intrinsic atomic_cas(array, index, expected, desired,
35-
memory_order, memory_scope)
36-
tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
37-
efunc(::typeof(Intrinsics.atomic_cas), effects::CC.Effects) =
38-
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
39-
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
40-
cb = ctx.cb
41-
tt = ctx.tt
42-
43-
# args: (array, index, expected, desired, memory_order, memory_scope)
44-
array_arg = args[1]
45-
46-
# Get array info
47-
arg_idx = extract_argument_index(array_arg)
48-
is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx)
49-
50-
if !is_tilearray
51-
throw(IRError("atomic_cas requires a TileArray argument"))
52-
end
53-
54-
ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr)
55-
isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument"))
56-
array_val = ptr_vals[1]
57-
tilearray_type = get_arg_type(ctx, arg_idx)
58-
elem_type = eltype(tilearray_type)
59-
60-
# Get expected and desired values
61-
expected_tv = emit_value!(ctx, args[3])
62-
expected_tv === nothing && throw(IRError("atomic_cas requires expected value"))
63-
desired_tv = emit_value!(ctx, args[4])
64-
desired_tv === nothing && throw(IRError("atomic_cas requires desired value"))
65-
66-
# Get memory order and scope from args
67-
memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic_cas requires constant memory_order"))
68-
memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic_cas requires constant memory_scope"))
69-
70-
# Create result type (0D tile of element type)
71-
dtype = julia_to_tile_dtype!(tt, elem_type)
72-
result_tile_type = tile_type!(tt, dtype, Int[])
73-
token_type = Token(tt)
74-
75-
# Get index and create pointer type
76-
index_tv = emit_value!(ctx, args[2])
77-
ptr_type = pointer_type!(tt, dtype)
78-
ptr_tile_type = tile_type!(tt, ptr_type, Int[])
79-
80-
# Compute pointer using OffsetOp (handles any integer index type)
81-
pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v)
82-
83-
# Emit atomic CAS
84-
mem_ordering = memory_order_to_semantics(memory_order)
85-
mem_scope = memory_scope_to_scope(memory_scope)
86-
87-
old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, pointers,
88-
expected_tv.v, desired_tv.v;
89-
token=ctx.token,
90-
memory_ordering=mem_ordering,
91-
memory_scope=mem_scope)
92-
ctx.token = new_token
33+
"""
34+
atomic_rmw_tfunc(ptrs) -> Type
9335
94-
# Return scalar type (not Tile) to match the intrinsic signature
95-
CGVal(old_val, result_tile_type, elem_type, Int[])
36+
Shared tfunc for atomic operations: RMW (add, xchg) and CAS.
37+
Returns raw T for 0D pointer tiles, Tile{T, S} for N-D.
38+
"""
39+
function atomic_rmw_tfunc(𝕃, @nospecialize(ptrs), @nospecialize args...)
40+
ptrs_type = CC.widenconst(ptrs)
41+
ptrs_type isa DataType && ptrs_type <: Tile || return nothing
42+
ptr_type = eltype(ptrs_type)
43+
ptr_type <: Ptr || return nothing
44+
T = eltype(ptr_type)
45+
S = ptrs_type.parameters[2]
46+
S === Tuple{} && return T
47+
return Tile{T, S}
9648
end
9749

98-
# cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations)
50+
# Shared codegen helper for atomic RMW operations (add, xchg)
9951
function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
10052
cb = ctx.cb
10153
tt = ctx.tt
10254

103-
# args: (array, index, val, memory_order, memory_scope)
104-
array_arg = args[1]
105-
106-
# Get array info
107-
arg_idx = extract_argument_index(array_arg)
108-
is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx)
55+
# args: (ptr_tile, val, mask, memory_order, memory_scope)
56+
ptr_tv = emit_value!(ctx, args[1])
57+
ptr_tv === nothing && throw(IRError("atomic RMW requires ptr_tile"))
58+
val_tv = emit_value!(ctx, args[2])
59+
val_tv === nothing && throw(IRError("atomic RMW requires value"))
10960

110-
if !is_tilearray
111-
throw(IRError("atomic operations require a TileArray argument"))
112-
end
61+
# Check if mask is provided (ghost Nothing = no mask)
62+
has_mask = get_constant(ctx, args[3]) !== nothing
11363

114-
ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr)
115-
isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument"))
116-
array_val = ptr_vals[1]
117-
tilearray_type = get_arg_type(ctx, arg_idx)
118-
elem_type = eltype(tilearray_type)
64+
memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic RMW requires constant memory_order"))
65+
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic RMW requires constant memory_scope"))
11966

120-
# Get update value
121-
val_tv = emit_value!(ctx, args[3])
122-
val_tv === nothing && throw(IRError("atomic operation requires value"))
67+
shape = ptr_tv.shape
12368

124-
# Get memory order and scope from args
125-
memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic operation requires constant memory_order"))
126-
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic operation requires constant memory_scope"))
69+
# Get element type from pointer tile: Tile{Ptr{T}, S} -> T
70+
ptrs_type = CC.widenconst(ptr_tv.jltype)
71+
ptr_type = eltype(ptrs_type)
72+
elem_type = eltype(ptr_type)
12773

128-
# Create result type (0D tile of element type)
12974
dtype = julia_to_tile_dtype!(tt, elem_type)
130-
result_tile_type = tile_type!(tt, dtype, Int[])
75+
result_tile_type = tile_type!(tt, dtype, collect(shape))
13176
token_type = Token(tt)
13277

133-
# Get index and create pointer type
134-
index_tv = emit_value!(ctx, args[2])
135-
ptr_type = pointer_type!(tt, dtype)
136-
ptr_tile_type = tile_type!(tt, ptr_type, Int[])
137-
138-
# Compute pointer using OffsetOp (handles any integer index type)
139-
pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v)
140-
141-
# Use float add mode for floating point types
78+
# Auto-promote integer ADD to float ADD for floating-point types
14279
actual_mode = mode
14380
if mode == AtomicADD && elem_type <: AbstractFloat
14481
actual_mode = AtomicADDF
14582
end
14683

147-
# Emit atomic RMW
14884
mem_ordering = memory_order_to_semantics(memory_order)
14985
mem_scope = memory_scope_to_scope(memory_scope)
15086

151-
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, pointers,
152-
val_tv.v, actual_mode;
153-
token=ctx.token,
154-
memory_ordering=mem_ordering,
155-
memory_scope=mem_scope)
87+
if has_mask
88+
mask_tv = emit_value!(ctx, args[3])
89+
mask_tv === nothing && throw(IRError("atomic RMW: cannot resolve mask"))
90+
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
91+
ptr_tv.v, val_tv.v, actual_mode;
92+
mask=mask_tv.v,
93+
token=ctx.token,
94+
memory_ordering=mem_ordering,
95+
memory_scope=mem_scope)
96+
else
97+
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
98+
ptr_tv.v, val_tv.v, actual_mode;
99+
token=ctx.token,
100+
memory_ordering=mem_ordering,
101+
memory_scope=mem_scope)
102+
end
156103
ctx.token = new_token
157104

158-
# Return scalar type (not Tile) to match the intrinsic signature
159-
CGVal(old_val, result_tile_type, elem_type, Int[])
160-
end
161-
162-
# cuda_tile.atomic_rmw_tko with XCHG
163-
@intrinsic atomic_xchg(array, index, val, memory_order, memory_scope)
164-
tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
165-
efunc(::typeof(Intrinsics.atomic_xchg), effects::CC.Effects) =
166-
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
167-
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg), args)
168-
emit_atomic_rmw!(ctx, args, AtomicXCHG)
105+
# Return type depends on shape: raw T for 0D, Tile{T, S} for N-D
106+
if isempty(shape)
107+
CGVal(old_val, result_tile_type, elem_type, Int[])
108+
else
109+
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
110+
end
169111
end
170112

171113
# cuda_tile.atomic_rmw_tko with ADD
172-
@intrinsic atomic_add(array, index, val,
173-
memory_order, memory_scope)
174-
tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
114+
@intrinsic atomic_add(ptr_tile, val, mask, memory_order, memory_scope)
115+
tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize args...) = atomic_rmw_tfunc(𝕃, args...)
175116
efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects) =
176117
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
177118
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args)
178119
emit_atomic_rmw!(ctx, args, AtomicADD)
179120
end
180121

181-
# ============================================================================
182-
# Tile-indexed atomic operations
183-
# These take pre-computed pointer tiles, value tiles, and masks.
184-
# Used by the public API for tile-indexed atomic operations.
185-
# ============================================================================
186-
187-
# Shared codegen helper for tile-indexed atomic RMW operations
188-
function emit_atomic_rmw_tile!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
189-
cb = ctx.cb
190-
tt = ctx.tt
191-
192-
# args: (ptr_tile, val, mask, memory_order, memory_scope)
193-
ptr_tv = emit_value!(ctx, args[1])
194-
ptr_tv === nothing && throw(IRError("tile-indexed atomic RMW requires ptr_tile"))
195-
val_tv = emit_value!(ctx, args[2])
196-
val_tv === nothing && throw(IRError("tile-indexed atomic RMW requires value"))
197-
mask_tv = emit_value!(ctx, args[3])
198-
mask_tv === nothing && throw(IRError("tile-indexed atomic RMW requires mask"))
199-
200-
memory_order = @something get_constant(ctx, args[4]) throw(IRError("tile-indexed atomic RMW requires constant memory_order"))
201-
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic RMW requires constant memory_scope"))
202-
203-
shape = val_tv.shape
204-
elem_type = eltype(val_tv.jltype)
205-
206-
dtype = julia_to_tile_dtype!(tt, elem_type)
207-
result_tile_type = tile_type!(tt, dtype, collect(shape))
208-
token_type = Token(tt)
209-
210-
# Auto-promote integer ADD to float ADD for floating-point types
211-
actual_mode = mode
212-
if mode == AtomicADD && elem_type <: AbstractFloat
213-
actual_mode = AtomicADDF
214-
end
215-
216-
mem_ordering = memory_order_to_semantics(memory_order)
217-
mem_scope = memory_scope_to_scope(memory_scope)
218-
219-
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
220-
ptr_tv.v, val_tv.v, actual_mode;
221-
mask=mask_tv.v,
222-
token=ctx.token,
223-
memory_ordering=mem_ordering,
224-
memory_scope=mem_scope)
225-
ctx.token = new_token
226-
227-
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
228-
end
229-
230-
# Tile-indexed atomic exchange
231-
@intrinsic atomic_xchg_tile(ptr_tile, val, mask, memory_order, memory_scope)
232-
function tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...)
233-
CC.widenconst(val)
234-
end
235-
efunc(::typeof(Intrinsics.atomic_xchg_tile), effects::CC.Effects) =
236-
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
237-
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg_tile), args)
238-
emit_atomic_rmw_tile!(ctx, args, AtomicXCHG)
239-
end
240-
241-
# Tile-indexed atomic addition
242-
@intrinsic atomic_add_tile(ptr_tile, val, mask, memory_order, memory_scope)
243-
function tfunc(𝕃, ::typeof(Intrinsics.atomic_add_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...)
244-
CC.widenconst(val)
245-
end
246-
efunc(::typeof(Intrinsics.atomic_add_tile), effects::CC.Effects) =
122+
# cuda_tile.atomic_rmw_tko with XCHG
123+
@intrinsic atomic_xchg(ptr_tile, val, mask, memory_order, memory_scope)
124+
tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize args...) = atomic_rmw_tfunc(𝕃, args...)
125+
efunc(::typeof(Intrinsics.atomic_xchg), effects::CC.Effects) =
247126
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
248-
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add_tile), args)
249-
emit_atomic_rmw_tile!(ctx, args, AtomicADD)
127+
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg), args)
128+
emit_atomic_rmw!(ctx, args, AtomicXCHG)
250129
end
251130

252-
# Tile-indexed atomic compare-and-swap
253-
@intrinsic atomic_cas_tile(ptr_tile, expected, desired, mask, memory_order, memory_scope)
254-
function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas_tile), @nospecialize(ptrs), @nospecialize(expected), @nospecialize args...)
255-
CC.widenconst(expected)
131+
# cuda_tile.atomic_cas_tko
132+
@intrinsic atomic_cas(ptr_tile, expected, desired, mask, memory_order, memory_scope)
133+
function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(ptrs), @nospecialize args...)
134+
atomic_rmw_tfunc(𝕃, ptrs, args...)
256135
end
257-
efunc(::typeof(Intrinsics.atomic_cas_tile), effects::CC.Effects) =
136+
efunc(::typeof(Intrinsics.atomic_cas), effects::CC.Effects) =
258137
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
259-
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas_tile), args)
138+
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
260139
cb = ctx.cb
261140
tt = ctx.tt
262141

263142
# args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
264143
ptr_tv = emit_value!(ctx, args[1])
265-
ptr_tv === nothing && throw(IRError("tile-indexed atomic CAS requires ptr_tile"))
144+
ptr_tv === nothing && throw(IRError("atomic CAS requires ptr_tile"))
266145
expected_tv = emit_value!(ctx, args[2])
267-
expected_tv === nothing && throw(IRError("tile-indexed atomic CAS requires expected value"))
146+
expected_tv === nothing && throw(IRError("atomic CAS requires expected value"))
268147
desired_tv = emit_value!(ctx, args[3])
269-
desired_tv === nothing && throw(IRError("tile-indexed atomic CAS requires desired value"))
270-
mask_tv = emit_value!(ctx, args[4])
271-
mask_tv === nothing && throw(IRError("tile-indexed atomic CAS requires mask"))
148+
desired_tv === nothing && throw(IRError("atomic CAS requires desired value"))
149+
150+
# Check if mask is provided (ghost Nothing = no mask)
151+
has_mask = get_constant(ctx, args[4]) !== nothing
152+
153+
memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic CAS requires constant memory_order"))
154+
memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic CAS requires constant memory_scope"))
272155

273-
memory_order = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic CAS requires constant memory_order"))
274-
memory_scope = @something get_constant(ctx, args[6]) throw(IRError("tile-indexed atomic CAS requires constant memory_scope"))
156+
shape = ptr_tv.shape
275157

276-
shape = expected_tv.shape
277-
elem_type = eltype(expected_tv.jltype)
158+
# Get element type from pointer tile: Tile{Ptr{T}, S} -> T
159+
ptrs_type = CC.widenconst(ptr_tv.jltype)
160+
ptr_type = eltype(ptrs_type)
161+
elem_type = eltype(ptr_type)
278162

279163
dtype = julia_to_tile_dtype!(tt, elem_type)
280164
result_tile_type = tile_type!(tt, dtype, collect(shape))
@@ -283,13 +167,28 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas_tile), args)
283167
mem_ordering = memory_order_to_semantics(memory_order)
284168
mem_scope = memory_scope_to_scope(memory_scope)
285169

286-
old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
287-
ptr_tv.v, expected_tv.v, desired_tv.v;
288-
mask=mask_tv.v,
289-
token=ctx.token,
290-
memory_ordering=mem_ordering,
291-
memory_scope=mem_scope)
170+
if has_mask
171+
mask_tv = emit_value!(ctx, args[4])
172+
mask_tv === nothing && throw(IRError("atomic CAS: cannot resolve mask"))
173+
old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
174+
ptr_tv.v, expected_tv.v, desired_tv.v;
175+
mask=mask_tv.v,
176+
token=ctx.token,
177+
memory_ordering=mem_ordering,
178+
memory_scope=mem_scope)
179+
else
180+
old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
181+
ptr_tv.v, expected_tv.v, desired_tv.v;
182+
token=ctx.token,
183+
memory_ordering=mem_ordering,
184+
memory_scope=mem_scope)
185+
end
292186
ctx.token = new_token
293187

294-
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
188+
# Return type depends on shape: raw T for 0D, Tile{T, S} for N-D
189+
if isempty(shape)
190+
CGVal(old_val, result_tile_type, elem_type, Int[])
191+
else
192+
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
193+
end
295194
end

0 commit comments

Comments
 (0)