Skip to content

Commit d4f322c

Browse files
AntonOrestenmaleadt
authored andcommitted
generalize to N dimensions
1 parent 93e180c commit d4f322c

3 files changed

Lines changed: 100 additions & 82 deletions

File tree

src/language/atomics.jl

Lines changed: 65 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -82,54 +82,48 @@ old_val = ct.atomic_add(counters, idx, Int32(1))
8282
end
8383

8484
# ============================================================================
85-
# Tile-indexed atomic operations (scatter-gather style indexing)
85+
# Tile-indexed atomic operations
8686
# These accept Tile indices to perform atomic operations on multiple elements.
8787
# ============================================================================
8888

89-
# --- Pointer/mask helpers (same pattern as gather/scatter in operations.jl) ---
89+
# --- Pointer/mask helper (N-dimensional) ---
9090

91-
@inline function _atomic_ptrs_mask(array::TileArray{T, 1}, indices::Tile{I}) where {T, I <: Integer}
92-
indices_0 = indices .- one(I)
93-
indices_i32 = convert(Tile{Int32}, indices_0)
94-
ptr_tile = Intrinsics.offset(array.ptr, indices_i32)
95-
zero_0d = Tile(Int32(0))
96-
size_0d = Tile(size(array, 1))
97-
mask = (indices_i32 .>= zero_0d) .& (indices_i32 .< size_0d)
98-
(ptr_tile, mask, size(indices))
99-
end
100-
101-
@inline function _atomic_ptrs_mask(array::TileArray{T, 2},
102-
indices::Tuple{Tile{I0}, Tile{I1}}) where {T, I0 <: Integer, I1 <: Integer}
103-
idx0_0 = indices[1] .- one(I0)
104-
idx1_0 = indices[2] .- one(I1)
91+
@inline function _atomic_ptrs_mask(array::TileArray{T, N},
92+
indices::NTuple{N, Tile{<:Integer}}) where {T, N}
93+
# Convert each index to 0-indexed
94+
indices_0 = ntuple(Val(N)) do d
95+
indices[d] .- one(eltype(indices[d]))
96+
end
10597

106-
S = broadcast_shape(size(indices[1]), size(indices[2]))
107-
idx0_bc = broadcast_to(idx0_0, S)
108-
idx1_bc = broadcast_to(idx1_0, S)
98+
# Broadcast all index tiles to a common shape
99+
S = reduce(broadcast_shape, ntuple(d -> size(indices[d]), Val(N)))
109100

110-
idx0_i32 = convert(Tile{Int32}, idx0_bc)
111-
idx1_i32 = convert(Tile{Int32}, idx1_bc)
101+
# Broadcast and convert to Int32
102+
indices_i32 = ntuple(Val(N)) do d
103+
convert(Tile{Int32}, broadcast_to(indices_0[d], S))
104+
end
112105

113-
stride0_0d = Tile(array.strides[1])
114-
stride1_0d = Tile(array.strides[2])
115-
stride0 = broadcast_to(stride0_0d, S)
116-
stride1 = broadcast_to(stride1_0d, S)
106+
# Linear index: sum(idx[d] * stride[d])
107+
linear_idx = reduce(.+, ntuple(Val(N)) do d
108+
indices_i32[d] .* broadcast_to(Tile(array.strides[d]), S)
109+
end)
117110

118-
linear_idx = idx0_i32 .* stride0 + idx1_i32 .* stride1
119111
ptr_tile = Intrinsics.offset(array.ptr, linear_idx)
120112

121-
zero_0d = Tile(Int32(0))
122-
zero_bc = broadcast_to(zero_0d, S)
123-
size0_bc = broadcast_to(Tile(size(array, 1)), S)
124-
size1_bc = broadcast_to(Tile(size(array, 2)), S)
125-
126-
mask0 = (idx0_i32 .>= zero_bc) .& (idx0_i32 .< size0_bc)
127-
mask1 = (idx1_i32 .>= zero_bc) .& (idx1_i32 .< size1_bc)
128-
mask = mask0 .& mask1
113+
# Bounds mask: 0 <= idx[d] < size[d] for all d
114+
zero_bc = broadcast_to(Tile(Int32(0)), S)
115+
mask = reduce(.&, ntuple(Val(N)) do d
116+
(indices_i32[d] .>= zero_bc) .& (indices_i32[d] .< broadcast_to(Tile(size(array, d)), S))
117+
end)
129118

130119
(ptr_tile, mask, S)
131120
end
132121

122+
# 1D convenience: single Tile -> 1-tuple
123+
@inline function _atomic_ptrs_mask(array::TileArray{T, 1}, indices::Tile{<:Integer}) where {T}
124+
_atomic_ptrs_mask(array, (indices,))
125+
end
126+
133127
# --- RMW operations (atomic_add, atomic_xchg) ---
134128

135129
const _ATOMIC_RMW_OPS = (
@@ -140,90 +134,79 @@ const _ATOMIC_RMW_OPS = (
140134
for (op, intrinsic) in _ATOMIC_RMW_OPS
141135
fname = Symbol(:atomic_, op)
142136

143-
# 1D with scalar value
144-
@eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{I}, val::T;
137+
# N-D with scalar value
138+
@eval @inline function $fname(array::TileArray{T, N},
139+
indices::NTuple{N, Tile{<:Integer}}, val::T;
145140
memory_order::Int=MemoryOrder.AcqRel,
146-
memory_scope::Int=MemScope.Device) where {T, I <: Integer}
141+
memory_scope::Int=MemScope.Device) where {T, N}
147142
ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
148143
val_tile = broadcast_to(Tile(val), S)
149144
Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope)
150145
end
151146

152-
# 1D with tile value
153-
@eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{I}, val::Tile{T};
147+
# N-D with tile value
148+
@eval @inline function $fname(array::TileArray{T, N},
149+
indices::NTuple{N, Tile{<:Integer}}, val::Tile{T};
154150
memory_order::Int=MemoryOrder.AcqRel,
155-
memory_scope::Int=MemScope.Device) where {T, I <: Integer}
156-
ptr_tile, mask, _ = _atomic_ptrs_mask(array, indices)
157-
Intrinsics.$intrinsic(ptr_tile, val, mask, memory_order, memory_scope)
151+
memory_scope::Int=MemScope.Device) where {T, N}
152+
ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
153+
val_bc = broadcast_to(val, S)
154+
Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
158155
end
159156

160-
# 2D with scalar value
161-
@eval @inline function $fname(array::TileArray{T, 2},
162-
indices::Tuple{Tile{I0}, Tile{I1}}, val::T;
157+
# 1D convenience: single Tile index
158+
@eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{<:Integer}, val::T;
163159
memory_order::Int=MemoryOrder.AcqRel,
164-
memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer}
165-
ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
166-
val_tile = broadcast_to(Tile(val), S)
167-
Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope)
160+
memory_scope::Int=MemScope.Device) where {T}
161+
$fname(array, (indices,), val; memory_order, memory_scope)
168162
end
169163

170-
# 2D with tile value
171-
@eval @inline function $fname(array::TileArray{T, 2},
172-
indices::Tuple{Tile{I0}, Tile{I1}}, val::Tile{T};
164+
@eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{<:Integer}, val::Tile{T};
173165
memory_order::Int=MemoryOrder.AcqRel,
174-
memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer}
175-
ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
176-
val_bc = broadcast_to(val, S)
177-
Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
166+
memory_scope::Int=MemScope.Device) where {T}
167+
$fname(array, (indices,), val; memory_order, memory_scope)
178168
end
179169
end
180170

181171
# --- CAS operations (separate due to different signature) ---
182172

183-
# 1D with scalar expected/desired
184-
@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{I},
173+
# N-D with scalar expected/desired
174+
@inline function atomic_cas(array::TileArray{T, N},
175+
indices::NTuple{N, Tile{<:Integer}},
185176
expected::T, desired::T;
186177
memory_order::Int=MemoryOrder.AcqRel,
187-
memory_scope::Int=MemScope.Device) where {T, I <: Integer}
178+
memory_scope::Int=MemScope.Device) where {T, N}
188179
ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
189180
expected_tile = broadcast_to(Tile(expected), S)
190181
desired_tile = broadcast_to(Tile(desired), S)
191182
Intrinsics.atomic_cas_tile(ptr_tile, expected_tile, desired_tile, mask,
192183
memory_order, memory_scope)
193184
end
194185

195-
# 1D with tile expected/desired
196-
@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{I},
186+
# N-D with tile expected/desired
187+
@inline function atomic_cas(array::TileArray{T, N},
188+
indices::NTuple{N, Tile{<:Integer}},
197189
expected::Tile{T}, desired::Tile{T};
198190
memory_order::Int=MemoryOrder.AcqRel,
199-
memory_scope::Int=MemScope.Device) where {T, I <: Integer}
200-
ptr_tile, mask, _ = _atomic_ptrs_mask(array, indices)
201-
Intrinsics.atomic_cas_tile(ptr_tile, expected, desired, mask,
191+
memory_scope::Int=MemScope.Device) where {T, N}
192+
ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
193+
expected_bc = broadcast_to(expected, S)
194+
desired_bc = broadcast_to(desired, S)
195+
Intrinsics.atomic_cas_tile(ptr_tile, expected_bc, desired_bc, mask,
202196
memory_order, memory_scope)
203197
end
204198

205-
# 2D with scalar expected/desired
206-
@inline function atomic_cas(array::TileArray{T, 2},
207-
indices::Tuple{Tile{I0}, Tile{I1}},
199+
# 1D convenience: single Tile index
200+
@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
208201
expected::T, desired::T;
209202
memory_order::Int=MemoryOrder.AcqRel,
210-
memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer}
211-
ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
212-
expected_tile = broadcast_to(Tile(expected), S)
213-
desired_tile = broadcast_to(Tile(desired), S)
214-
Intrinsics.atomic_cas_tile(ptr_tile, expected_tile, desired_tile, mask,
215-
memory_order, memory_scope)
203+
memory_scope::Int=MemScope.Device) where {T}
204+
atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)
216205
end
217206

218-
# 2D with tile expected/desired
219-
@inline function atomic_cas(array::TileArray{T, 2},
220-
indices::Tuple{Tile{I0}, Tile{I1}},
207+
@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
221208
expected::Tile{T}, desired::Tile{T};
222209
memory_order::Int=MemoryOrder.AcqRel,
223-
memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer}
224-
ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
225-
expected_bc = broadcast_to(expected, S)
226-
desired_bc = broadcast_to(desired, S)
227-
Intrinsics.atomic_cas_tile(ptr_tile, expected_bc, desired_bc, mask,
228-
memory_order, memory_scope)
210+
memory_scope::Int=MemScope.Device) where {T}
211+
atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)
229212
end

test/codegen/operations.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,23 @@
14341434
end
14351435
end
14361436

1437+
@testset "tile-indexed 3D atomic_add" begin
1438+
spec3d = ct.ArraySpec{3}(16, true)
1439+
@test @filecheck begin
1440+
@check_label "entry"
1441+
code_tiled(Tuple{ct.TileArray{Int32,3,spec3d}}) do arr
1442+
@check "iota"
1443+
i = ct.arange((4,), Int)
1444+
j = ct.arange((4,), Int)
1445+
k = ct.arange((4,), Int)
1446+
@check "offset"
1447+
@check "atomic_rmw_tko"
1448+
ct.atomic_add(arr, (i, j, k), Int32(1))
1449+
return
1450+
end
1451+
end
1452+
end
1453+
14371454
@testset "tile-indexed atomic_rmw_tko" begin
14381455
spec = ct.ArraySpec{1}(16, true)
14391456
# xchg

test/execution/atomics.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,24 @@ end
314314
@test all(Array(arr) .== 1)
315315
end
316316

317+
@testset "atomic_add tile-indexed 3D" begin
318+
function atomic_add_3d_kernel(arr::ct.TileArray{Int,3})
319+
# 3D index tiles — each is length 4, will broadcast to (4,4,4) = 64 elements
320+
i = ct.reshape(ct.arange((4,), Int), (4, 1, 1))
321+
j = ct.reshape(ct.arange((4,), Int), (1, 4, 1))
322+
k = ct.reshape(ct.arange((4,), Int), (1, 1, 4))
323+
ct.atomic_add(arr, (i, j, k), 1;
324+
memory_order=ct.MemoryOrder.AcqRel)
325+
return
326+
end
327+
328+
arr = CUDA.zeros(Int, 4, 4, 4)
329+
330+
ct.launch(atomic_add_3d_kernel, 1, arr)
331+
332+
@test all(Array(arr) .== 1)
333+
end
334+
317335
@testset "1D gather - simple" begin
318336
# Simple 1D gather: copy first 16 elements using gather
319337
function gather_simple_kernel(src::ct.TileArray{Float32,1}, dst::ct.TileArray{Float32,1})

0 commit comments

Comments
 (0)