@@ -82,54 +82,48 @@ old_val = ct.atomic_add(counters, idx, Int32(1))
8282end
8383
8484# ============================================================================
85- # Tile-indexed atomic operations (scatter-gather style indexing)
85+ # Tile-indexed atomic operations
8686# These accept Tile indices to perform atomic operations on multiple elements.
8787# ============================================================================
8888
89- # --- Pointer/mask helpers (same pattern as gather/scatter in operations.jl ) ---
89+ # --- Pointer/mask helper (N-dimensional ) ---
9090
91- @inline function _atomic_ptrs_mask (array:: TileArray{T, 1} , indices:: Tile{I} ) where {T, I <: Integer }
92- indices_0 = indices .- one (I)
93- indices_i32 = convert (Tile{Int32}, indices_0)
94- ptr_tile = Intrinsics. offset (array. ptr, indices_i32)
95- zero_0d = Tile (Int32 (0 ))
96- size_0d = Tile (size (array, 1 ))
97- mask = (indices_i32 .>= zero_0d) .& (indices_i32 .< size_0d)
98- (ptr_tile, mask, size (indices))
99- end
100-
101- @inline function _atomic_ptrs_mask (array:: TileArray{T, 2} ,
102- indices:: Tuple{Tile{I0}, Tile{I1}} ) where {T, I0 <: Integer , I1 <: Integer }
103- idx0_0 = indices[1 ] .- one (I0)
104- idx1_0 = indices[2 ] .- one (I1)
91+ @inline function _atomic_ptrs_mask (array:: TileArray{T, N} ,
92+ indices:: NTuple{N, Tile{<:Integer}} ) where {T, N}
93+ # Convert each index to 0-indexed
94+ indices_0 = ntuple (Val (N)) do d
95+ indices[d] .- one (eltype (indices[d]))
96+ end
10597
106- S = broadcast_shape (size (indices[1 ]), size (indices[2 ]))
107- idx0_bc = broadcast_to (idx0_0, S)
108- idx1_bc = broadcast_to (idx1_0, S)
98+ # Broadcast all index tiles to a common shape
99+ S = reduce (broadcast_shape, ntuple (d -> size (indices[d]), Val (N)))
109100
110- idx0_i32 = convert (Tile{Int32}, idx0_bc)
111- idx1_i32 = convert (Tile{Int32}, idx1_bc)
101+ # Broadcast and convert to Int32
102+ indices_i32 = ntuple (Val (N)) do d
103+ convert (Tile{Int32}, broadcast_to (indices_0[d], S))
104+ end
112105
113- stride0_0d = Tile (array . strides[ 1 ])
114- stride1_0d = Tile (array . strides[ 2 ])
115- stride0 = broadcast_to (stride0_0d , S)
116- stride1 = broadcast_to (stride1_0d, S )
106+ # Linear index: sum(idx[d] * stride[d ])
107+ linear_idx = reduce ( .+ , ntuple ( Val (N)) do d
108+ indices_i32[d] .* broadcast_to (Tile (array . strides[d]) , S)
109+ end )
117110
118- linear_idx = idx0_i32 .* stride0 + idx1_i32 .* stride1
119111 ptr_tile = Intrinsics. offset (array. ptr, linear_idx)
120112
121- zero_0d = Tile (Int32 (0 ))
122- zero_bc = broadcast_to (zero_0d, S)
123- size0_bc = broadcast_to (Tile (size (array, 1 )), S)
124- size1_bc = broadcast_to (Tile (size (array, 2 )), S)
125-
126- mask0 = (idx0_i32 .>= zero_bc) .& (idx0_i32 .< size0_bc)
127- mask1 = (idx1_i32 .>= zero_bc) .& (idx1_i32 .< size1_bc)
128- mask = mask0 .& mask1
113+ # Bounds mask: 0 <= idx[d] < size[d] for all d
114+ zero_bc = broadcast_to (Tile (Int32 (0 )), S)
115+ mask = reduce (.& , ntuple (Val (N)) do d
116+ (indices_i32[d] .>= zero_bc) .& (indices_i32[d] .< broadcast_to (Tile (size (array, d)), S))
117+ end )
129118
130119 (ptr_tile, mask, S)
131120end
132121
122+ # 1D convenience: single Tile -> 1-tuple
123+ @inline function _atomic_ptrs_mask (array:: TileArray{T, 1} , indices:: Tile{<:Integer} ) where {T}
124+ _atomic_ptrs_mask (array, (indices,))
125+ end
126+
133127# --- RMW operations (atomic_add, atomic_xchg) ---
134128
135129const _ATOMIC_RMW_OPS = (
@@ -140,90 +134,79 @@ const _ATOMIC_RMW_OPS = (
140134for (op, intrinsic) in _ATOMIC_RMW_OPS
141135 fname = Symbol (:atomic_ , op)
142136
143- # 1D with scalar value
144- @eval @inline function $fname (array:: TileArray{T, 1} , indices:: Tile{I} , val:: T ;
137+ # N-D with scalar value
138+ @eval @inline function $fname (array:: TileArray{T, N} ,
139+ indices:: NTuple{N, Tile{<:Integer}} , val:: T ;
145140 memory_order:: Int = MemoryOrder. AcqRel,
146- memory_scope:: Int = MemScope. Device) where {T, I <: Integer }
141+ memory_scope:: Int = MemScope. Device) where {T, N }
147142 ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
148143 val_tile = broadcast_to (Tile (val), S)
149144 Intrinsics.$ intrinsic (ptr_tile, val_tile, mask, memory_order, memory_scope)
150145 end
151146
152- # 1D with tile value
153- @eval @inline function $fname (array:: TileArray{T, 1} , indices:: Tile{I} , val:: Tile{T} ;
147+ # N-D with tile value
148+ @eval @inline function $fname (array:: TileArray{T, N} ,
149+ indices:: NTuple{N, Tile{<:Integer}} , val:: Tile{T} ;
154150 memory_order:: Int = MemoryOrder. AcqRel,
155- memory_scope:: Int = MemScope. Device) where {T, I <: Integer }
156- ptr_tile, mask, _ = _atomic_ptrs_mask (array, indices)
157- Intrinsics.$ intrinsic (ptr_tile, val, mask, memory_order, memory_scope)
151+ memory_scope:: Int = MemScope. Device) where {T, N}
152+ ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
153+ val_bc = broadcast_to (val, S)
154+ Intrinsics.$ intrinsic (ptr_tile, val_bc, mask, memory_order, memory_scope)
158155 end
159156
160- # 2D with scalar value
161- @eval @inline function $fname (array:: TileArray{T, 2} ,
162- indices:: Tuple{Tile{I0}, Tile{I1}} , val:: T ;
157+ # 1D convenience: single Tile index
158+ @eval @inline function $fname (array:: TileArray{T, 1} , indices:: Tile{<:Integer} , val:: T ;
163159 memory_order:: Int = MemoryOrder. AcqRel,
164- memory_scope:: Int = MemScope. Device) where {T, I0 <: Integer , I1 <: Integer }
165- ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
166- val_tile = broadcast_to (Tile (val), S)
167- Intrinsics.$ intrinsic (ptr_tile, val_tile, mask, memory_order, memory_scope)
160+ memory_scope:: Int = MemScope. Device) where {T}
161+ $ fname (array, (indices,), val; memory_order, memory_scope)
168162 end
169163
170- # 2D with tile value
171- @eval @inline function $fname (array:: TileArray{T, 2} ,
172- indices:: Tuple{Tile{I0}, Tile{I1}} , val:: Tile{T} ;
164+ @eval @inline function $fname (array:: TileArray{T, 1} , indices:: Tile{<:Integer} , val:: Tile{T} ;
173165 memory_order:: Int = MemoryOrder. AcqRel,
174- memory_scope:: Int = MemScope. Device) where {T, I0 <: Integer , I1 <: Integer }
175- ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
176- val_bc = broadcast_to (val, S)
177- Intrinsics.$ intrinsic (ptr_tile, val_bc, mask, memory_order, memory_scope)
166+ memory_scope:: Int = MemScope. Device) where {T}
167+ $ fname (array, (indices,), val; memory_order, memory_scope)
178168 end
179169end
180170
181171# --- CAS operations (separate due to different signature) ---
182172
183- # 1D with scalar expected/desired
184- @inline function atomic_cas (array:: TileArray{T, 1} , indices:: Tile{I} ,
173+ # N-D with scalar expected/desired
174+ @inline function atomic_cas (array:: TileArray{T, N} ,
175+ indices:: NTuple{N, Tile{<:Integer}} ,
185176 expected:: T , desired:: T ;
186177 memory_order:: Int = MemoryOrder. AcqRel,
187- memory_scope:: Int = MemScope. Device) where {T, I <: Integer }
178+ memory_scope:: Int = MemScope. Device) where {T, N }
188179 ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
189180 expected_tile = broadcast_to (Tile (expected), S)
190181 desired_tile = broadcast_to (Tile (desired), S)
191182 Intrinsics. atomic_cas_tile (ptr_tile, expected_tile, desired_tile, mask,
192183 memory_order, memory_scope)
193184end
194185
195- # 1D with tile expected/desired
196- @inline function atomic_cas (array:: TileArray{T, 1} , indices:: Tile{I} ,
186+ # N-D with tile expected/desired
187+ @inline function atomic_cas (array:: TileArray{T, N} ,
188+ indices:: NTuple{N, Tile{<:Integer}} ,
197189 expected:: Tile{T} , desired:: Tile{T} ;
198190 memory_order:: Int = MemoryOrder. AcqRel,
199- memory_scope:: Int = MemScope. Device) where {T, I <: Integer }
200- ptr_tile, mask, _ = _atomic_ptrs_mask (array, indices)
201- Intrinsics. atomic_cas_tile (ptr_tile, expected, desired, mask,
191+ memory_scope:: Int = MemScope. Device) where {T, N}
192+ ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
193+ expected_bc = broadcast_to (expected, S)
194+ desired_bc = broadcast_to (desired, S)
195+ Intrinsics. atomic_cas_tile (ptr_tile, expected_bc, desired_bc, mask,
202196 memory_order, memory_scope)
203197end
204198
205- # 2D with scalar expected/desired
206- @inline function atomic_cas (array:: TileArray{T, 2} ,
207- indices:: Tuple{Tile{I0}, Tile{I1}} ,
199+ # 1D convenience: single Tile index
200+ @inline function atomic_cas (array:: TileArray{T, 1} , indices:: Tile{<:Integer} ,
208201 expected:: T , desired:: T ;
209202 memory_order:: Int = MemoryOrder. AcqRel,
210- memory_scope:: Int = MemScope. Device) where {T, I0 <: Integer , I1 <: Integer }
211- ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
212- expected_tile = broadcast_to (Tile (expected), S)
213- desired_tile = broadcast_to (Tile (desired), S)
214- Intrinsics. atomic_cas_tile (ptr_tile, expected_tile, desired_tile, mask,
215- memory_order, memory_scope)
203+ memory_scope:: Int = MemScope. Device) where {T}
204+ atomic_cas (array, (indices,), expected, desired; memory_order, memory_scope)
216205end
217206
218- # 2D with tile expected/desired
219- @inline function atomic_cas (array:: TileArray{T, 2} ,
220- indices:: Tuple{Tile{I0}, Tile{I1}} ,
207+ @inline function atomic_cas (array:: TileArray{T, 1} , indices:: Tile{<:Integer} ,
221208 expected:: Tile{T} , desired:: Tile{T} ;
222209 memory_order:: Int = MemoryOrder. AcqRel,
223- memory_scope:: Int = MemScope. Device) where {T, I0 <: Integer , I1 <: Integer }
224- ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
225- expected_bc = broadcast_to (expected, S)
226- desired_bc = broadcast_to (desired, S)
227- Intrinsics. atomic_cas_tile (ptr_tile, expected_bc, desired_bc, mask,
228- memory_order, memory_scope)
210+ memory_scope:: Int = MemScope. Device) where {T}
211+ atomic_cas (array, (indices,), expected, desired; memory_order, memory_scope)
229212end
0 commit comments