Skip to content

Commit 3abbcc1

Browse files
AntonOresten, maleadt, and claude
authored
Support tile-indexed atomic operations (#96)
Co-authored-by: Tim Besard <tim.besard@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 19e3eb3 commit 3abbcc1

5 files changed

Lines changed: 437 additions & 107 deletions

File tree

src/compiler/intrinsics/atomics.jl

Lines changed: 92 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -30,114 +30,113 @@ function memory_scope_to_scope(scope::Int)
3030
end
3131
end
3232

33+
"""
34+
atomic_tfunc(𝕃, ptrs, args...) -> Union{Type, Nothing}
35+
36+
Shared tfunc for atomic operations (add, xchg, cas).
37+
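For a pointer tile `Tile{Ptr{T}, S}` returns `Tile{T, S}`, even for 0D (`S = Tuple{}`); returns `nothing` if `ptrs` does not widen to a `Tile` whose element type is a `Ptr`.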
38+
"""
39+
function atomic_tfunc(𝕃, @nospecialize(ptrs), @nospecialize args...)
40+
ptrs_type = CC.widenconst(ptrs)
41+
ptrs_type isa DataType && ptrs_type <: Tile || return nothing
42+
ptr_type = eltype(ptrs_type)
43+
ptr_type <: Ptr || return nothing
44+
T = eltype(ptr_type)
45+
S = ptrs_type.parameters[2]
46+
return Tile{T, S}
47+
end
48+
3349
# cuda_tile.atomic_cas_tko
34-
@intrinsic atomic_cas(array, index, expected, desired,
35-
memory_order, memory_scope)
36-
tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
50+
@intrinsic atomic_cas(ptr_tile, expected, desired, mask, memory_order, memory_scope)
51+
function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(ptrs), @nospecialize args...)
52+
atomic_tfunc(𝕃, ptrs, args...)
53+
end
3754
efunc(::typeof(Intrinsics.atomic_cas), effects::CC.Effects) =
3855
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
3956
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
4057
cb = ctx.cb
4158
tt = ctx.tt
4259

43-
# args: (array, index, expected, desired, memory_order, memory_scope)
44-
array_arg = args[1]
60+
# args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
61+
ptr_tv = emit_value!(ctx, args[1])
62+
ptr_tv === nothing && throw(IRError("atomic CAS requires ptr_tile"))
63+
expected_tv = emit_value!(ctx, args[2])
64+
expected_tv === nothing && throw(IRError("atomic CAS requires expected value"))
65+
desired_tv = emit_value!(ctx, args[3])
66+
desired_tv === nothing && throw(IRError("atomic CAS requires desired value"))
4567

46-
# Get array info
47-
arg_idx = extract_argument_index(array_arg)
48-
is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx)
68+
# Check if mask is provided (ghost Nothing = no mask)
69+
has_mask = get_constant(ctx, args[4]) !== nothing
4970

50-
if !is_tilearray
51-
throw(IRError("atomic_cas requires a TileArray argument"))
52-
end
71+
memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic CAS requires constant memory_order"))
72+
memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic CAS requires constant memory_scope"))
5373

54-
ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr)
55-
isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument"))
56-
array_val = ptr_vals[1]
57-
tilearray_type = get_arg_type(ctx, arg_idx)
58-
elem_type = eltype(tilearray_type)
74+
shape = ptr_tv.shape
5975

60-
# Get expected and desired values
61-
expected_tv = emit_value!(ctx, args[3])
62-
expected_tv === nothing && throw(IRError("atomic_cas requires expected value"))
63-
desired_tv = emit_value!(ctx, args[4])
64-
desired_tv === nothing && throw(IRError("atomic_cas requires desired value"))
76+
# Get element type from pointer tile: Tile{Ptr{T}, S} -> T
77+
ptrs_type = CC.widenconst(ptr_tv.jltype)
78+
ptr_type = eltype(ptrs_type)
79+
elem_type = eltype(ptr_type)
6580

66-
# Get memory order and scope from args
67-
memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic_cas requires constant memory_order"))
68-
memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic_cas requires constant memory_scope"))
69-
70-
# Create result type (0D tile of element type)
7181
dtype = julia_to_tile_dtype!(tt, elem_type)
72-
result_tile_type = tile_type!(tt, dtype, Int[])
82+
result_tile_type = tile_type!(tt, dtype, collect(shape))
7383
token_type = Token(tt)
7484

75-
# Get index and create pointer type
76-
index_tv = emit_value!(ctx, args[2])
77-
ptr_type = pointer_type!(tt, dtype)
78-
ptr_tile_type = tile_type!(tt, ptr_type, Int[])
79-
80-
# Compute pointer using OffsetOp (handles any integer index type)
81-
pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v)
82-
8385
# Emit atomic CAS
8486
mem_ordering = memory_order_to_semantics(memory_order)
8587
mem_scope = memory_scope_to_scope(memory_scope)
8688

87-
old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, pointers,
88-
expected_tv.v, desired_tv.v;
89-
token=ctx.token,
90-
memory_ordering=mem_ordering,
91-
memory_scope=mem_scope)
89+
old_val, new_token = if has_mask
90+
mask_tv = emit_value!(ctx, args[4])
91+
mask_tv === nothing && throw(IRError("atomic CAS: cannot resolve mask"))
92+
encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
93+
ptr_tv.v, expected_tv.v, desired_tv.v;
94+
mask=mask_tv.v,
95+
token=ctx.token,
96+
memory_ordering=mem_ordering,
97+
memory_scope=mem_scope)
98+
else
99+
encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
100+
ptr_tv.v, expected_tv.v, desired_tv.v;
101+
token=ctx.token,
102+
memory_ordering=mem_ordering,
103+
memory_scope=mem_scope)
104+
end
92105
ctx.token = new_token
93106

94-
# Return scalar type (not Tile) to match the intrinsic signature
95-
CGVal(old_val, result_tile_type, elem_type, Int[])
107+
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
96108
end
97109

98110
# cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations)
99111
function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
100112
cb = ctx.cb
101113
tt = ctx.tt
102114

103-
# args: (array, index, val, memory_order, memory_scope)
104-
array_arg = args[1]
115+
# args: (ptr_tile, val, mask, memory_order, memory_scope)
116+
ptr_tv = emit_value!(ctx, args[1])
117+
ptr_tv === nothing && throw(IRError("atomic RMW requires ptr_tile"))
118+
val_tv = emit_value!(ctx, args[2])
119+
val_tv === nothing && throw(IRError("atomic RMW requires value"))
105120

106-
# Get array info
107-
arg_idx = extract_argument_index(array_arg)
108-
is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx)
121+
# Check if mask is provided (ghost Nothing = no mask)
122+
has_mask = get_constant(ctx, args[3]) !== nothing
109123

110-
if !is_tilearray
111-
throw(IRError("atomic operations require a TileArray argument"))
112-
end
124+
# Get memory order and scope from args
125+
memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic RMW requires constant memory_order"))
126+
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic RMW requires constant memory_scope"))
113127

114-
ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr)
115-
isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument"))
116-
array_val = ptr_vals[1]
117-
tilearray_type = get_arg_type(ctx, arg_idx)
118-
elem_type = eltype(tilearray_type)
128+
shape = ptr_tv.shape
119129

120-
# Get update value
121-
val_tv = emit_value!(ctx, args[3])
122-
val_tv === nothing && throw(IRError("atomic operation requires value"))
130+
# Get element type from pointer tile: Tile{Ptr{T}, S} -> T
131+
ptrs_type = CC.widenconst(ptr_tv.jltype)
132+
ptr_type = eltype(ptrs_type)
133+
elem_type = eltype(ptr_type)
123134

124-
# Get memory order and scope from args
125-
memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic operation requires constant memory_order"))
126-
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic operation requires constant memory_scope"))
127-
128-
# Create result type (0D tile of element type)
135+
# Create result type
129136
dtype = julia_to_tile_dtype!(tt, elem_type)
130-
result_tile_type = tile_type!(tt, dtype, Int[])
137+
result_tile_type = tile_type!(tt, dtype, collect(shape))
131138
token_type = Token(tt)
132139

133-
# Get index and create pointer type
134-
index_tv = emit_value!(ctx, args[2])
135-
ptr_type = pointer_type!(tt, dtype)
136-
ptr_tile_type = tile_type!(tt, ptr_type, Int[])
137-
138-
# Compute pointer using OffsetOp (handles any integer index type)
139-
pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v)
140-
141140
# Use float add mode for floating point types
142141
actual_mode = mode
143142
if mode == AtomicADD && elem_type <: AbstractFloat
@@ -148,30 +147,39 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
148147
mem_ordering = memory_order_to_semantics(memory_order)
149148
mem_scope = memory_scope_to_scope(memory_scope)
150149

151-
old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, pointers,
152-
val_tv.v, actual_mode;
153-
token=ctx.token,
154-
memory_ordering=mem_ordering,
155-
memory_scope=mem_scope)
150+
old_val, new_token = if has_mask
151+
mask_tv = emit_value!(ctx, args[3])
152+
mask_tv === nothing && throw(IRError("atomic RMW: cannot resolve mask"))
153+
encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
154+
ptr_tv.v, val_tv.v, actual_mode;
155+
mask=mask_tv.v,
156+
token=ctx.token,
157+
memory_ordering=mem_ordering,
158+
memory_scope=mem_scope)
159+
else
160+
encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
161+
ptr_tv.v, val_tv.v, actual_mode;
162+
token=ctx.token,
163+
memory_ordering=mem_ordering,
164+
memory_scope=mem_scope)
165+
end
156166
ctx.token = new_token
157167

158-
# Return scalar type (not Tile) to match the intrinsic signature
159-
CGVal(old_val, result_tile_type, elem_type, Int[])
168+
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
160169
end
161170

162171
# cuda_tile.atomic_rmw_tko with XCHG
163-
@intrinsic atomic_xchg(array, index, val, memory_order, memory_scope)
164-
tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
172+
@intrinsic atomic_xchg(ptr_tile, val, mask, memory_order, memory_scope)
173+
tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize args...) = atomic_tfunc(𝕃, args...)
165174
efunc(::typeof(Intrinsics.atomic_xchg), effects::CC.Effects) =
166175
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
167176
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg), args)
168177
emit_atomic_rmw!(ctx, args, AtomicXCHG)
169178
end
170179

171180
# cuda_tile.atomic_rmw_tko with ADD
172-
@intrinsic atomic_add(array, index, val,
173-
memory_order, memory_scope)
174-
tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
181+
@intrinsic atomic_add(ptr_tile, val, mask, memory_order, memory_scope)
182+
tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize args...) = atomic_tfunc(𝕃, args...)
175183
efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects) =
176184
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
177185
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args)

src/language/atomics.jl

Lines changed: 89 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,61 @@ module MemScope
2525
const System = 2
2626
end
2727

28+
# ============================================================================
29+
# Pointer/mask helpers
30+
#
31+
# Both scalar and tile-indexed paths compute (ptr_tile, mask, shape) here,
32+
# then pass to a single set of intrinsics.
33+
# ============================================================================
34+
35+
# Scalar index -> 0D pointer tile, no mask
36+
@inline function _atomic_ptr_and_mask(array::TileArray{T}, index::Integer) where {T}
37+
idx_0 = Tile(Int32(index - One()))
38+
ptr_tile = Intrinsics.offset(array.ptr, idx_0)
39+
(ptr_tile, nothing, ())
40+
end
41+
42+
# N-D tile indices -> N-D pointer tile with bounds mask
43+
@inline function _atomic_ptr_and_mask(array::TileArray{T, N},
44+
indices::NTuple{N, Tile{<:Integer}}) where {T, N}
45+
# Convert each index to 0-indexed
46+
indices_0 = ntuple(Val(N)) do d
47+
indices[d] .- one(eltype(indices[d]))
48+
end
49+
50+
# Broadcast all index tiles to a common shape
51+
S = reduce(broadcast_shape, ntuple(d -> size(indices[d]), Val(N)))
52+
53+
# Broadcast and convert to Int32
54+
indices_i32 = ntuple(Val(N)) do d
55+
convert(Tile{Int32}, broadcast_to(indices_0[d], S))
56+
end
57+
58+
# Linear index: sum(idx[d] * stride[d])
59+
linear_idx = reduce(.+, ntuple(Val(N)) do d
60+
indices_i32[d] .* broadcast_to(Tile(array.strides[d]), S)
61+
end)
62+
63+
ptr_tile = Intrinsics.offset(array.ptr, linear_idx)
64+
65+
# Bounds mask: 0 <= idx[d] < size[d] for all d
66+
zero_bc = broadcast_to(Tile(Int32(0)), S)
67+
mask = reduce(.&, ntuple(Val(N)) do d
68+
(indices_i32[d] .>= zero_bc) .& (indices_i32[d] .< broadcast_to(Tile(size(array, d)), S))
69+
end)
70+
71+
(ptr_tile, mask, S)
72+
end
73+
74+
# 1D convenience: single Tile -> 1-tuple
75+
@inline function _atomic_ptr_and_mask(array::TileArray{T, 1}, indices::Tile{<:Integer}) where {T}
76+
_atomic_ptr_and_mask(array, (indices,))
77+
end
78+
79+
# ============================================================================
80+
# Atomic CAS
81+
# ============================================================================
82+
2883
"""
2984
atomic_cas(array::TileArray, index, expected, desired; memory_order, memory_scope) -> T or Tile{T}
3085
@@ -40,43 +95,59 @@ while ct.atomic_cas(locks, idx, Int32(0), Int32(1); memory_order=ct.MemoryOrder.
4095
end
4196
```
4297
"""
43-
@inline function atomic_cas(array::TileArray{T}, index, expected::T, desired::T;
98+
@inline function atomic_cas(array::TileArray{T}, indices,
99+
expected::TileOrScalar{T}, desired::TileOrScalar{T};
44100
memory_order::Int=MemoryOrder.AcqRel,
45101
memory_scope::Int=MemScope.Device) where {T}
46-
Intrinsics.atomic_cas(array, index - One(), expected, desired, memory_order, memory_scope)
102+
ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
103+
expected_bc = S === () ? Tile(expected) : broadcast_to(Tile(expected), S)
104+
desired_bc = S === () ? Tile(desired) : broadcast_to(Tile(desired), S)
105+
result = Intrinsics.atomic_cas(ptr_tile, expected_bc, desired_bc, mask,
106+
memory_order, memory_scope)
107+
S === () ? Intrinsics.to_scalar(result) : result
47108
end
48109

110+
# ============================================================================
111+
# Atomic RMW operations (atomic_add, atomic_xchg)
112+
# ============================================================================
113+
49114
"""
50-
atomic_xchg(array::TileArray, index, val; memory_order, memory_scope) -> T
115+
atomic_add(array::TileArray, index, val; memory_order, memory_scope) -> T or Tile{T}
51116
52-
Atomic exchange. Atomically replaces the value at `index` with `val` and returns
117+
Atomic addition. Atomically adds `val` to the value at `index` and returns
53118
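the original value. `index` is 1-indexed and may be a scalar integer, a tile of indices (1-D arrays), or an `N`-tuple of index tiles (`N`-D arrays); tile-indexed calls are bounds-masked and return a tile of original values.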
54119
55120
# Example
56121
```julia
57-
# Spin-lock release
58-
ct.atomic_xchg(locks, idx, Int32(0); memory_order=ct.MemoryOrder.Release)
122+
old_val = ct.atomic_add(counters, idx, Int32(1))
59123
```
60124
"""
61-
@inline function atomic_xchg(array::TileArray{T}, index, val::T;
62-
memory_order::Int=MemoryOrder.AcqRel,
63-
memory_scope::Int=MemScope.Device) where {T}
64-
Intrinsics.atomic_xchg(array, index - One(), val, memory_order, memory_scope)
65-
end
125+
function atomic_add end
66126

67127
"""
68-
atomic_add(array::TileArray, index, val; memory_order, memory_scope) -> T
128+
atomic_xchg(array::TileArray, index, val; memory_order, memory_scope) -> T or Tile{T}
69129
70-
Atomic addition. Atomically adds `val` to the value at `index` and returns
130+
Atomic exchange. Atomically replaces the value at `index` with `val` and returns
71131
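the original value. `index` is 1-indexed and may be a scalar integer, a tile of indices (1-D arrays), or an `N`-tuple of index tiles (`N`-D arrays); tile-indexed calls are bounds-masked and return a tile of original values.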
72132
73133
# Example
74134
```julia
75-
old_val = ct.atomic_add(counters, idx, Int32(1))
135+
# Spin-lock release
136+
ct.atomic_xchg(locks, idx, Int32(0); memory_order=ct.MemoryOrder.Release)
76137
```
77138
"""
78-
@inline function atomic_add(array::TileArray{T}, index, val::T;
79-
memory_order::Int=MemoryOrder.AcqRel,
80-
memory_scope::Int=MemScope.Device) where {T}
81-
Intrinsics.atomic_add(array, index - One(), val, memory_order, memory_scope)
139+
function atomic_xchg end
140+
141+
for op in (:add, :xchg)
142+
fname = Symbol(:atomic_, op)
143+
intrinsic = Symbol(:atomic_, op)
144+
145+
@eval @inline function $fname(array::TileArray{T}, indices, val::TileOrScalar{T};
146+
memory_order::Int=MemoryOrder.AcqRel,
147+
memory_scope::Int=MemScope.Device) where {T}
148+
ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
149+
val_bc = S === () ? Tile(val) : broadcast_to(Tile(val), S)
150+
result = Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
151+
S === () ? Intrinsics.to_scalar(result) : result
152+
end
82153
end

0 commit comments

Comments
 (0)