Skip to content

Commit 0b2d7c5

Browse files
AntonOresten authored and maleadt committed
Add tile-indexed methods for existing atomic operations
1 parent 0f51b81 commit 0b2d7c5

4 files changed

Lines changed: 469 additions & 0 deletions

File tree

src/compiler/intrinsics/atomics.jl

Lines changed: 116 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,3 +177,119 @@ efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects) =
177177
# Lower `Intrinsics.atomic_add` by delegating to the shared atomic RMW emitter
# with the `AtomicADD` mode.
emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args) =
    emit_atomic_rmw!(ctx, args, AtomicADD)
180+
181+
# ============================================================================
182+
# Tile-indexed atomic operations
183+
# These take pre-computed pointer tiles, value tiles, and masks.
184+
# Used by the public API for tile-indexed atomic operations.
185+
# ============================================================================
186+
187+
# Shared codegen helper for tile-indexed atomic RMW operations.
#
# `args` is read positionally as (ptr_tile, val, mask, memory_order, memory_scope).
# The first three must lower to values and the last two must resolve to constants;
# otherwise an `IRError` is thrown. The result tile takes its shape and element type
# from the value operand, and `ctx.token` is replaced by the token returned from the
# encoded operation. Returns a `CGVal` wrapping the old value.
function emit_atomic_rmw_tile!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
    builder = ctx.cb
    types = ctx.tt

    ptrs = emit_value!(ctx, args[1])
    if ptrs === nothing
        throw(IRError("tile-indexed atomic RMW requires ptr_tile"))
    end
    operand = emit_value!(ctx, args[2])
    if operand === nothing
        throw(IRError("tile-indexed atomic RMW requires value"))
    end
    guard = emit_value!(ctx, args[3])
    if guard === nothing
        throw(IRError("tile-indexed atomic RMW requires mask"))
    end

    order_const = @something(get_constant(ctx, args[4]),
                             throw(IRError("tile-indexed atomic RMW requires constant memory_order")))
    scope_const = @something(get_constant(ctx, args[5]),
                             throw(IRError("tile-indexed atomic RMW requires constant memory_scope")))

    # The value operand determines both the shape and the element type of the result.
    dims = operand.shape
    T = eltype(operand.jltype)

    tile_dtype = julia_to_tile_dtype!(types, T)
    result_type = tile_type!(types, tile_dtype, collect(dims))
    tok_type = Token(types)

    # A generic ADD request on a floating-point element type is switched to the
    # dedicated float-add mode; every other mode is passed through unchanged.
    rmw_mode = (mode == AtomicADD && T <: AbstractFloat) ? AtomicADDF : mode

    ordering = memory_order_to_semantics(order_const)
    scope = memory_scope_to_scope(scope_const)

    previous, next_token = encode_AtomicRMWPtrOp!(
        builder, result_type, tok_type, ptrs.v, operand.v, rmw_mode;
        mask=guard.v,
        token=ctx.token,
        memory_ordering=ordering,
        memory_scope=scope,
    )
    ctx.token = next_token

    return CGVal(previous, result_type, Tile{T, Tuple{dims...}}, collect(dims))
end
229+
230+
# Tile-indexed atomic exchange
@intrinsic atomic_xchg_tile(ptr_tile, val, mask, memory_order, memory_scope)

# Inferred result type: the widened type of the value operand.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...)
    return CC.widenconst(val)
end

# Copy the incoming effects but force `effect_free` to `ALWAYS_FALSE`
# (presumably so a call whose result is unused is not removed).
function efunc(::typeof(Intrinsics.atomic_xchg_tile), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end

# Codegen: delegate to the shared tile-indexed RMW emitter in exchange mode.
emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg_tile), args) =
    emit_atomic_rmw_tile!(ctx, args, AtomicXCHG)
240+
241+
# Tile-indexed atomic addition
@intrinsic atomic_add_tile(ptr_tile, val, mask, memory_order, memory_scope)

# Inferred result type: the widened type of the value operand.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_add_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...)
    return CC.widenconst(val)
end

# Copy the incoming effects but force `effect_free` to `ALWAYS_FALSE`
# (presumably so a call whose result is unused is not removed).
function efunc(::typeof(Intrinsics.atomic_add_tile), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end

# Codegen: delegate to the shared tile-indexed RMW emitter in add mode.
emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add_tile), args) =
    emit_atomic_rmw_tile!(ctx, args, AtomicADD)
251+
252+
# Tile-indexed atomic compare-and-swap
@intrinsic atomic_cas_tile(ptr_tile, expected, desired, mask, memory_order, memory_scope)

# Inferred result type: the widened type of the `expected` operand.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas_tile), @nospecialize(ptrs), @nospecialize(expected), @nospecialize args...)
    return CC.widenconst(expected)
end

# Copy the incoming effects but force `effect_free` to `ALWAYS_FALSE`
# (presumably so a call whose result is unused is not removed).
function efunc(::typeof(Intrinsics.atomic_cas_tile), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end
259+
# Codegen for the tile-indexed atomic compare-and-swap intrinsic.
#
# `args` is read positionally as
# (ptr_tile, expected, desired, mask, memory_order, memory_scope).
# The first four must lower to values and the last two must resolve to constants;
# otherwise an `IRError` is thrown. The result tile takes its shape and element type
# from the `expected` operand, and `ctx.token` is replaced by the token returned from
# the encoded operation. Returns a `CGVal` wrapping the old value.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas_tile), args)
    builder = ctx.cb
    types = ctx.tt

    ptrs = emit_value!(ctx, args[1])
    if ptrs === nothing
        throw(IRError("tile-indexed atomic CAS requires ptr_tile"))
    end
    cmp = emit_value!(ctx, args[2])
    if cmp === nothing
        throw(IRError("tile-indexed atomic CAS requires expected value"))
    end
    replacement = emit_value!(ctx, args[3])
    if replacement === nothing
        throw(IRError("tile-indexed atomic CAS requires desired value"))
    end
    guard = emit_value!(ctx, args[4])
    if guard === nothing
        throw(IRError("tile-indexed atomic CAS requires mask"))
    end

    order_const = @something(get_constant(ctx, args[5]),
                             throw(IRError("tile-indexed atomic CAS requires constant memory_order")))
    scope_const = @something(get_constant(ctx, args[6]),
                             throw(IRError("tile-indexed atomic CAS requires constant memory_scope")))

    # The `expected` operand determines both the shape and the element type of the result.
    dims = cmp.shape
    T = eltype(cmp.jltype)

    tile_dtype = julia_to_tile_dtype!(types, T)
    result_type = tile_type!(types, tile_dtype, collect(dims))
    tok_type = Token(types)

    ordering = memory_order_to_semantics(order_const)
    scope = memory_scope_to_scope(scope_const)

    previous, next_token = encode_AtomicCASPtrOp!(
        builder, result_type, tok_type, ptrs.v, cmp.v, replacement.v;
        mask=guard.v,
        token=ctx.token,
        memory_ordering=ordering,
        memory_scope=scope,
    )
    ctx.token = next_token

    return CGVal(previous, result_type, Tile{T, Tuple{dims...}}, collect(dims))
end

src/language/atomics.jl

Lines changed: 147 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,3 +80,150 @@ old_val = ct.atomic_add(counters, idx, Int32(1))
8080
memory_scope::Int=MemScope.Device) where {T}
8181
Intrinsics.atomic_add(array, index - One(), val, memory_order, memory_scope)
8282
end
83+
84+
# ============================================================================
85+
# Tile-indexed atomic operations (scatter-gather style indexing)
86+
# These accept Tile indices to perform atomic operations on multiple elements.
87+
# ============================================================================
88+
89+
# --- Pointer/mask helpers (same pattern as gather/scatter in operations.jl) ---
90+
91+
# Build the pointer tile and bounds mask for 1-based `indices` into a 1D array.
#
# Returns `(ptr_tile, mask, shape)`: `ptr_tile` is `array.ptr` offset by the
# zero-based Int32 indices, `mask` is true where the zero-based index lies in
# `[0, size(array, 1))`, and `shape` is `size(indices)`.
@inline function _atomic_ptrs_mask(array::TileArray{T, 1}, indices::Tile{I}) where {T, I <: Integer}
    zero_based = indices .- one(I)
    offsets = convert(Tile{Int32}, zero_based)
    ptrs = Intrinsics.offset(array.ptr, offsets)

    lower = Tile(Int32(0))
    upper = Tile(size(array, 1))
    in_bounds = (offsets .>= lower) .& (offsets .< upper)

    return (ptrs, in_bounds, size(indices))
end
100+
101+
# Build the pointer tile and bounds mask for a pair of 1-based index tiles into
# a 2D array.
#
# Both index tiles are broadcast to their common shape `S`. The linear offset is
# `idx0 * strides[1] + idx1 * strides[2]` on zero-based Int32 indices, and the mask
# is true only where both zero-based indices lie inside the matching array dimension.
# Returns `(ptr_tile, mask, S)`.
@inline function _atomic_ptrs_mask(array::TileArray{T, 2},
                                   indices::Tuple{Tile{I0}, Tile{I1}}) where {T, I0 <: Integer, I1 <: Integer}
    # Convert from 1-based to 0-based indexing.
    first_zb = indices[1] .- one(I0)
    second_zb = indices[2] .- one(I1)

    S = broadcast_shape(size(indices[1]), size(indices[2]))
    first_bc = broadcast_to(first_zb, S)
    second_bc = broadcast_to(second_zb, S)

    first_off = convert(Tile{Int32}, first_bc)
    second_off = convert(Tile{Int32}, second_bc)

    # Per-dimension strides, expanded to the common shape.
    s0_scalar = Tile(array.strides[1])
    s1_scalar = Tile(array.strides[2])
    s0 = broadcast_to(s0_scalar, S)
    s1 = broadcast_to(s1_scalar, S)

    flat = first_off .* s0 + second_off .* s1
    ptrs = Intrinsics.offset(array.ptr, flat)

    # Bounds for the mask, expanded to the common shape.
    lower_scalar = Tile(Int32(0))
    lower = broadcast_to(lower_scalar, S)
    upper0 = broadcast_to(Tile(size(array, 1)), S)
    upper1 = broadcast_to(Tile(size(array, 2)), S)

    ok0 = (first_off .>= lower) .& (first_off .< upper0)
    ok1 = (second_off .>= lower) .& (second_off .< upper1)
    in_bounds = ok0 .& ok1

    return (ptrs, in_bounds, S)
end
132+
133+
# --- RMW operations (atomic_add, atomic_xchg) ---
134+
135+
# Pairs of (public operation suffix, intrinsic name) used to generate the
# tile-indexed `atomic_<op>` methods below.
const _ATOMIC_RMW_OPS = (
    (:add, :atomic_add_tile),
    (:xchg, :atomic_xchg_tile),
)

# Generate four methods per operation: 1D/2D arrays crossed with scalar/tile values.
# Every method computes the pointer tile and bounds mask from the indices, broadcasts
# the value to the index shape `S`, and calls the matching intrinsic. The value is
# broadcast in all four methods because the code generator takes the result shape
# from the value tile, so it has to match the pointer/mask shape.
for (op, intrinsic) in _ATOMIC_RMW_OPS
    fname = Symbol(:atomic_, op)

    # 1D with scalar value
    @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{I}, val::T;
                                  memory_order::Int=MemoryOrder.AcqRel,
                                  memory_scope::Int=MemScope.Device) where {T, I <: Integer}
        ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
        val_tile = broadcast_to(Tile(val), S)
        Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope)
    end

    # 1D with tile value
    @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{I}, val::Tile{T};
                                  memory_order::Int=MemoryOrder.AcqRel,
                                  memory_scope::Int=MemScope.Device) where {T, I <: Integer}
        ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
        # Broadcast to the index shape, matching the 2D tile-value method.
        val_bc = broadcast_to(val, S)
        Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
    end

    # 2D with scalar value
    @eval @inline function $fname(array::TileArray{T, 2},
                                  indices::Tuple{Tile{I0}, Tile{I1}}, val::T;
                                  memory_order::Int=MemoryOrder.AcqRel,
                                  memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer}
        ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
        val_tile = broadcast_to(Tile(val), S)
        Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope)
    end

    # 2D with tile value
    @eval @inline function $fname(array::TileArray{T, 2},
                                  indices::Tuple{Tile{I0}, Tile{I1}}, val::Tile{T};
                                  memory_order::Int=MemoryOrder.AcqRel,
                                  memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer}
        ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
        val_bc = broadcast_to(val, S)
        Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
    end
end
180+
181+
# --- CAS operations (separate due to different signature) ---
182+
183+
# 1D with scalar expected/desired
184+
# Tile-indexed compare-and-swap on a 1D array with scalar `expected`/`desired`.
# Both scalars are wrapped in a tile and broadcast to the shape of `indices`
# before being handed to the intrinsic together with the pointer tile and mask.
@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{I},
                            expected::T, desired::T;
                            memory_order::Int=MemoryOrder.AcqRel,
                            memory_scope::Int=MemScope.Device) where {T, I <: Integer}
    ptrs, in_bounds, shape = _atomic_ptrs_mask(array, indices)
    cmp_tile = broadcast_to(Tile(expected), shape)
    new_tile = broadcast_to(Tile(desired), shape)
    return Intrinsics.atomic_cas_tile(
        ptrs, cmp_tile, new_tile, in_bounds, memory_order, memory_scope,
    )
end
194+
195+
# 1D with tile expected/desired
196+
# Tile-indexed compare-and-swap on a 1D array with tile `expected`/`desired`.
#
# Both operand tiles are broadcast to the shape of `indices`, matching the 2D
# tile method. The code generator takes the result shape from `expected`, so it
# has to match the pointer/mask shape.
@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{I},
                            expected::Tile{T}, desired::Tile{T};
                            memory_order::Int=MemoryOrder.AcqRel,
                            memory_scope::Int=MemScope.Device) where {T, I <: Integer}
    ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
    expected_bc = broadcast_to(expected, S)
    desired_bc = broadcast_to(desired, S)
    return Intrinsics.atomic_cas_tile(ptr_tile, expected_bc, desired_bc, mask,
                                      memory_order, memory_scope)
end
204+
205+
# 2D with scalar expected/desired
206+
# Tile-indexed compare-and-swap on a 2D array with scalar `expected`/`desired`.
# Both scalars are wrapped in a tile and broadcast to the common shape of the
# two index tiles before being handed to the intrinsic.
@inline function atomic_cas(array::TileArray{T, 2},
                            indices::Tuple{Tile{I0}, Tile{I1}},
                            expected::T, desired::T;
                            memory_order::Int=MemoryOrder.AcqRel,
                            memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer}
    ptrs, in_bounds, shape = _atomic_ptrs_mask(array, indices)
    cmp_tile = broadcast_to(Tile(expected), shape)
    new_tile = broadcast_to(Tile(desired), shape)
    return Intrinsics.atomic_cas_tile(
        ptrs, cmp_tile, new_tile, in_bounds, memory_order, memory_scope,
    )
end
217+
218+
# 2D with tile expected/desired
219+
# Tile-indexed compare-and-swap on a 2D array with tile `expected`/`desired`.
# Both operand tiles are broadcast to the common shape of the two index tiles
# before being handed to the intrinsic.
@inline function atomic_cas(array::TileArray{T, 2},
                            indices::Tuple{Tile{I0}, Tile{I1}},
                            expected::Tile{T}, desired::Tile{T};
                            memory_order::Int=MemoryOrder.AcqRel,
                            memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer}
    ptrs, in_bounds, shape = _atomic_ptrs_mask(array, indices)
    cmp_tile = broadcast_to(expected, shape)
    new_tile = broadcast_to(desired, shape)
    return Intrinsics.atomic_cas_tile(
        ptrs, cmp_tile, new_tile, in_bounds, memory_order, memory_scope,
    )
end

test/codegen/operations.jl

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1418,6 +1418,64 @@
14181418
end
14191419
end
14201420
end
1421+
1422+
# Codegen check: a tile-indexed `ct.atomic_cas` on a 1D Int32 array should emit,
# in order, an `iota`, an `offset`, and an `atomic_cas_tko`.
@testset "tile-indexed atomic_cas_tko" begin
    spec_i32 = ct.ArraySpec{1}(16, true)
    @test @filecheck begin
        @check_label "entry"
        code_tiled(Tuple{ct.TileArray{Int32,1,spec_i32}}) do dest
            @check "iota"
            idx = ct.arange((16,), Int)
            @check "offset"
            @check "atomic_cas_tko"
            ct.atomic_cas(dest, idx, Int32(0), Int32(1))
            return
        end
    end
end
1436+
1437+
# Codegen checks: tile-indexed `ct.atomic_xchg` and `ct.atomic_add` (Int32 and
# Float32) should each emit, in order, an `iota`, an `offset`, and an
# `atomic_rmw_tko`.
@testset "tile-indexed atomic_rmw_tko" begin
    spec_i32 = ct.ArraySpec{1}(16, true)

    # Exchange on an Int32 array.
    @test @filecheck begin
        @check_label "entry"
        code_tiled(Tuple{ct.TileArray{Int32,1,spec_i32}}) do dest
            @check "iota"
            idx = ct.arange((16,), Int)
            @check "offset"
            @check "atomic_rmw_tko"
            ct.atomic_xchg(dest, idx, Int32(42))
            return
        end
    end

    # Integer addition.
    @test @filecheck begin
        @check_label "entry"
        code_tiled(Tuple{ct.TileArray{Int32,1,spec_i32}}) do dest
            @check "iota"
            idx = ct.arange((16,), Int)
            @check "offset"
            @check "atomic_rmw_tko"
            ct.atomic_add(dest, idx, Int32(1))
            return
        end
    end

    # Floating-point addition.
    spec_float = ct.ArraySpec{1}(16, true)
    @test @filecheck begin
        @check_label "entry"
        code_tiled(Tuple{ct.TileArray{Float32,1,spec_float}}) do dest
            @check "iota"
            idx = ct.arange((16,), Int)
            @check "offset"
            @check "atomic_rmw_tko"
            ct.atomic_add(dest, idx, 1.5f0)
            return
        end
    end
end
14211479
end
14221480

14231481
#=========================================================================

0 commit comments

Comments
 (0)