Skip to content

Commit 54b4872

Browse files
maleadt and claude committed
Unify scalar and tile-indexed atomic intrinsics
Replace 6 intrinsics (3 scalar + 3 tile) with 3 unified ones that take (ptr_tile, val, mask, ...) matching Python cuTile's design. Both paths now compute pointers via Intrinsics.offset in the language layer, with mask=nothing for scalar indices (no mask in bytecode) and Tile{Bool} for tile-indexed (bounds mask passed through). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d4f322c commit 54b4872

3 files changed

Lines changed: 199 additions & 282 deletions

File tree

src/compiler/intrinsics/atomics.jl

Lines changed: 103 additions & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -30,114 +30,119 @@ function memory_scope_to_scope(scope::Int)
3030
end
3131
end
3232

33+
"""
    atomic_tfunc(𝕃, ptrs, args...) -> Union{Type, Nothing}

Shared return-type function (tfunc) for the atomic intrinsics (add, xchg, cas).

`ptrs` is the inferred lattice element of the pointer-tile argument; `𝕃` and any
trailing `args` are accepted but not inspected. After widening `ptrs` with
`CC.widenconst`:

- return `nothing` if the result is not a concrete `Tile` type, or if its element
  type is not a `Ptr`;
- return the raw element type `T` for a 0-D pointer tile `Tile{Ptr{T}, Tuple{}}`;
- return `Tile{T, S}` for a pointer tile `Tile{Ptr{T}, S}` of any other shape `S`.
"""
function atomic_tfunc(𝕃, @nospecialize(ptrs), @nospecialize args...)
    ptrs_type = CC.widenconst(ptrs)
    # The shape is read from `parameters[2]` below, which needs a `DataType`.
    (ptrs_type isa DataType && ptrs_type <: Tile) || return nothing
    ptr_type = eltype(ptrs_type)
    ptr_type <: Ptr || return nothing
    T = eltype(ptr_type)
    S = ptrs_type.parameters[2]
    # 0-D pointer tiles yield the bare element type, everything else a Tile.
    return S === Tuple{} ? T : Tile{T, S}
end
49+
3350
# cuda_tile.atomic_cas_tko
34-
@intrinsic atomic_cas(array, index, expected, desired,
35-
memory_order, memory_scope)
36-
tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
51+
@intrinsic atomic_cas(ptr_tile, expected, desired, mask, memory_order, memory_scope)

# The result type is derived from the pointer tile by the shared `atomic_tfunc`.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(ptrs), @nospecialize args...)
    return atomic_tfunc(𝕃, ptrs, args...)
end

# Mark the call as never effect-free, leaving the other inferred effects untouched.
function efunc(::typeof(Intrinsics.atomic_cas), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end
3957
# Code generation for `Intrinsics.atomic_cas`.
#
# args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
#
# Emits one `encode_AtomicCASPtrOp!` on the pointer tile, threads the resulting token
# through `ctx.token`, and returns a `CGVal` holding the old value(s): Julia type `T`
# when the pointer tile has an empty shape, `Tile{T, S}` otherwise.
#
# Throws `IRError` when the pointer, expected, desired or (if present) mask operand
# cannot be resolved, or when `memory_order` / `memory_scope` are not constants.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
    cb = ctx.cb
    tt = ctx.tt

    ptr_tv = emit_value!(ctx, args[1])
    ptr_tv === nothing && throw(IRError("atomic CAS requires ptr_tile"))
    expected_tv = emit_value!(ctx, args[2])
    expected_tv === nothing && throw(IRError("atomic CAS requires expected value"))
    desired_tv = emit_value!(ctx, args[3])
    desired_tv === nothing && throw(IRError("atomic CAS requires desired value"))

    # Check if mask is provided (ghost Nothing = no mask)
    # NOTE(review): this relies on what `get_constant` returns for a runtime mask
    # tile versus a ghost `nothing`; `get_constant` is not visible here — confirm.
    has_mask = get_constant(ctx, args[4]) !== nothing

    memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic CAS requires constant memory_order"))
    memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic CAS requires constant memory_scope"))

    shape = ptr_tv.shape

    # Get element type from pointer tile: Tile{Ptr{T}, S} -> T
    ptrs_type = CC.widenconst(ptr_tv.jltype)
    ptr_type = eltype(ptrs_type)
    elem_type = eltype(ptr_type)

    # The result tile has the same shape as the pointer tile.
    dtype = julia_to_tile_dtype!(tt, elem_type)
    result_tile_type = tile_type!(tt, dtype, collect(shape))
    token_type = Token(tt)

    mem_ordering = memory_order_to_semantics(memory_order)
    mem_scope = memory_scope_to_scope(memory_scope)

    # Only pass `mask=` when a mask exists, so the mask-less call stays exactly the
    # same as before; this lets both cases share one call site.
    mask_kw = if has_mask
        mask_tv = emit_value!(ctx, args[4])
        mask_tv === nothing && throw(IRError("atomic CAS: cannot resolve mask"))
        (; mask=mask_tv.v)
    else
        NamedTuple()
    end

    # Emit atomic CAS
    old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
                                                ptr_tv.v, expected_tv.v, desired_tv.v;
                                                mask_kw...,
                                                token=ctx.token,
                                                memory_ordering=mem_ordering,
                                                memory_scope=mem_scope)
    ctx.token = new_token

    # Return type depends on shape: raw T for 0D, Tile{T, S} for N-D
    if isempty(shape)
        return CGVal(old_val, result_tile_type, elem_type, Int[])
    end
    return CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
end
97115

98116
# cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations)
99117
function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
100118
cb = ctx.cb
101119
tt = ctx.tt
102120

103-
# args: (array, index, val, memory_order, memory_scope)
104-
array_arg = args[1]
105-
106-
# Get array info
107-
arg_idx = extract_argument_index(array_arg)
108-
is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx)
121+
# args: (ptr_tile, val, mask, memory_order, memory_scope)
122+
ptr_tv = emit_value!(ctx, args[1])
123+
ptr_tv === nothing && throw(IRError("atomic RMW requires ptr_tile"))
124+
val_tv = emit_value!(ctx, args[2])
125+
val_tv === nothing && throw(IRError("atomic RMW requires value"))
109126

110-
if !is_tilearray
111-
throw(IRError("atomic operations require a TileArray argument"))
112-
end
127+
# Check if mask is provided (ghost Nothing = no mask)
128+
has_mask = get_constant(ctx, args[3]) !== nothing
113129

114-
ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr)
115-
isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument"))
116-
array_val = ptr_vals[1]
117-
tilearray_type = get_arg_type(ctx, arg_idx)
118-
elem_type = eltype(tilearray_type)
130+
# Get memory order and scope from args
131+
memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic RMW requires constant memory_order"))
132+
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic RMW requires constant memory_scope"))
119133

120-
# Get update value
121-
val_tv = emit_value!(ctx, args[3])
122-
val_tv === nothing && throw(IRError("atomic operation requires value"))
134+
shape = ptr_tv.shape
123135

124-
# Get memory order and scope from args
125-
memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic operation requires constant memory_order"))
126-
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic operation requires constant memory_scope"))
136+
# Get element type from pointer tile: Tile{Ptr{T}, S} -> T
137+
ptrs_type = CC.widenconst(ptr_tv.jltype)
138+
ptr_type = eltype(ptrs_type)
139+
elem_type = eltype(ptr_type)
127140

128-
# Create result type (0D tile of element type)
141+
# Create result type
129142
dtype = julia_to_tile_dtype!(tt, elem_type)
130-
result_tile_type = tile_type!(tt, dtype, Int[])
143+
result_tile_type = tile_type!(tt, dtype, collect(shape))
131144
token_type = Token(tt)
132145

133-
# Get index and create pointer type
134-
index_tv = emit_value!(ctx, args[2])
135-
ptr_type = pointer_type!(tt, dtype)
136-
ptr_tile_type = tile_type!(tt, ptr_type, Int[])
137-
138-
# Compute pointer using OffsetOp (handles any integer index type)
139-
pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v)
140-
141146
# Use float add mode for floating point types
142147
actual_mode = mode
143148
if mode == AtomicADD && elem_type <: AbstractFloat
@@ -148,148 +153,46 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
148153
mem_ordering = memory_order_to_semantics(memory_order)
149154
mem_scope = memory_scope_to_scope(memory_scope)
150155

151-
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, pointers,
152-
val_tv.v, actual_mode;
153-
token=ctx.token,
154-
memory_ordering=mem_ordering,
155-
memory_scope=mem_scope)
156+
if has_mask
157+
mask_tv = emit_value!(ctx, args[3])
158+
mask_tv === nothing && throw(IRError("atomic RMW: cannot resolve mask"))
159+
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
160+
ptr_tv.v, val_tv.v, actual_mode;
161+
mask=mask_tv.v,
162+
token=ctx.token,
163+
memory_ordering=mem_ordering,
164+
memory_scope=mem_scope)
165+
else
166+
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
167+
ptr_tv.v, val_tv.v, actual_mode;
168+
token=ctx.token,
169+
memory_ordering=mem_ordering,
170+
memory_scope=mem_scope)
171+
end
156172
ctx.token = new_token
157173

158-
# Return scalar type (not Tile) to match the intrinsic signature
159-
CGVal(old_val, result_tile_type, elem_type, Int[])
174+
# Return type depends on shape: raw T for 0D, Tile{T, S} for N-D
175+
if isempty(shape)
176+
CGVal(old_val, result_tile_type, elem_type, Int[])
177+
else
178+
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
179+
end
160180
end
161181

162182
# cuda_tile.atomic_rmw_tko with XCHG
163-
@intrinsic atomic_xchg(array, index, val, memory_order, memory_scope)
164-
tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
183+
@intrinsic atomic_xchg(ptr_tile, val, mask, memory_order, memory_scope)

# The result type is derived from the pointer tile by the shared `atomic_tfunc`.
# The pointer argument is named explicitly, matching the `atomic_cas` tfunc.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize(ptrs), @nospecialize args...)
    return atomic_tfunc(𝕃, ptrs, args...)
end

# Mark the call as never effect-free, leaving the other inferred effects untouched.
function efunc(::typeof(Intrinsics.atomic_xchg), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end

# Delegate to the shared atomic RMW emitter in exchange mode.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg), args)
    return emit_atomic_rmw!(ctx, args, AtomicXCHG)
end
170190

171191
# cuda_tile.atomic_rmw_tko with ADD
172-
@intrinsic atomic_add(array, index, val,
173-
memory_order, memory_scope)
174-
tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
192+
@intrinsic atomic_add(ptr_tile, val, mask, memory_order, memory_scope)

# The result type is derived from the pointer tile by the shared `atomic_tfunc`.
# The pointer argument is named explicitly, matching the `atomic_cas` tfunc.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize(ptrs), @nospecialize args...)
    return atomic_tfunc(𝕃, ptrs, args...)
end

# Mark the call as never effect-free, leaving the other inferred effects untouched.
function efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end

# Delegate to the shared atomic RMW emitter in add mode.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args)
    return emit_atomic_rmw!(ctx, args, AtomicADD)
end
180-
181-
# ============================================================================
182-
# Tile-indexed atomic operations
183-
# These take pre-computed pointer tiles, value tiles, and masks.
184-
# Used by the public API for tile-indexed atomic operations.
185-
# ============================================================================
186-
187-
# Shared codegen helper for tile-indexed atomic RMW operations
188-
function emit_atomic_rmw_tile!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
189-
cb = ctx.cb
190-
tt = ctx.tt
191-
192-
# args: (ptr_tile, val, mask, memory_order, memory_scope)
193-
ptr_tv = emit_value!(ctx, args[1])
194-
ptr_tv === nothing && throw(IRError("tile-indexed atomic RMW requires ptr_tile"))
195-
val_tv = emit_value!(ctx, args[2])
196-
val_tv === nothing && throw(IRError("tile-indexed atomic RMW requires value"))
197-
mask_tv = emit_value!(ctx, args[3])
198-
mask_tv === nothing && throw(IRError("tile-indexed atomic RMW requires mask"))
199-
200-
memory_order = @something get_constant(ctx, args[4]) throw(IRError("tile-indexed atomic RMW requires constant memory_order"))
201-
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic RMW requires constant memory_scope"))
202-
203-
shape = val_tv.shape
204-
elem_type = eltype(val_tv.jltype)
205-
206-
dtype = julia_to_tile_dtype!(tt, elem_type)
207-
result_tile_type = tile_type!(tt, dtype, collect(shape))
208-
token_type = Token(tt)
209-
210-
# Auto-promote integer ADD to float ADD for floating-point types
211-
actual_mode = mode
212-
if mode == AtomicADD && elem_type <: AbstractFloat
213-
actual_mode = AtomicADDF
214-
end
215-
216-
mem_ordering = memory_order_to_semantics(memory_order)
217-
mem_scope = memory_scope_to_scope(memory_scope)
218-
219-
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
220-
ptr_tv.v, val_tv.v, actual_mode;
221-
mask=mask_tv.v,
222-
token=ctx.token,
223-
memory_ordering=mem_ordering,
224-
memory_scope=mem_scope)
225-
ctx.token = new_token
226-
227-
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
228-
end
229-
230-
# Tile-indexed atomic exchange
231-
@intrinsic atomic_xchg_tile(ptr_tile, val, mask, memory_order, memory_scope)
232-
function tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...)
233-
CC.widenconst(val)
234-
end
235-
efunc(::typeof(Intrinsics.atomic_xchg_tile), effects::CC.Effects) =
236-
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
237-
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg_tile), args)
238-
emit_atomic_rmw_tile!(ctx, args, AtomicXCHG)
239-
end
240-
241-
# Tile-indexed atomic addition
242-
@intrinsic atomic_add_tile(ptr_tile, val, mask, memory_order, memory_scope)
243-
function tfunc(𝕃, ::typeof(Intrinsics.atomic_add_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...)
244-
CC.widenconst(val)
245-
end
246-
efunc(::typeof(Intrinsics.atomic_add_tile), effects::CC.Effects) =
247-
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
248-
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add_tile), args)
249-
emit_atomic_rmw_tile!(ctx, args, AtomicADD)
250-
end
251-
252-
# Tile-indexed atomic compare-and-swap
253-
@intrinsic atomic_cas_tile(ptr_tile, expected, desired, mask, memory_order, memory_scope)
254-
function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas_tile), @nospecialize(ptrs), @nospecialize(expected), @nospecialize args...)
255-
CC.widenconst(expected)
256-
end
257-
efunc(::typeof(Intrinsics.atomic_cas_tile), effects::CC.Effects) =
258-
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
259-
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas_tile), args)
260-
cb = ctx.cb
261-
tt = ctx.tt
262-
263-
# args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
264-
ptr_tv = emit_value!(ctx, args[1])
265-
ptr_tv === nothing && throw(IRError("tile-indexed atomic CAS requires ptr_tile"))
266-
expected_tv = emit_value!(ctx, args[2])
267-
expected_tv === nothing && throw(IRError("tile-indexed atomic CAS requires expected value"))
268-
desired_tv = emit_value!(ctx, args[3])
269-
desired_tv === nothing && throw(IRError("tile-indexed atomic CAS requires desired value"))
270-
mask_tv = emit_value!(ctx, args[4])
271-
mask_tv === nothing && throw(IRError("tile-indexed atomic CAS requires mask"))
272-
273-
memory_order = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic CAS requires constant memory_order"))
274-
memory_scope = @something get_constant(ctx, args[6]) throw(IRError("tile-indexed atomic CAS requires constant memory_scope"))
275-
276-
shape = expected_tv.shape
277-
elem_type = eltype(expected_tv.jltype)
278-
279-
dtype = julia_to_tile_dtype!(tt, elem_type)
280-
result_tile_type = tile_type!(tt, dtype, collect(shape))
281-
token_type = Token(tt)
282-
283-
mem_ordering = memory_order_to_semantics(memory_order)
284-
mem_scope = memory_scope_to_scope(memory_scope)
285-
286-
old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
287-
ptr_tv.v, expected_tv.v, desired_tv.v;
288-
mask=mask_tv.v,
289-
token=ctx.token,
290-
memory_ordering=mem_ordering,
291-
memory_scope=mem_scope)
292-
ctx.token = new_token
293-
294-
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
295-
end

0 commit comments

Comments
 (0)