@@ -80,3 +80,150 @@ old_val = ct.atomic_add(counters, idx, Int32(1))
8080 memory_scope:: Int = MemScope. Device) where {T}
8181 Intrinsics. atomic_add (array, index - One (), val, memory_order, memory_scope)
8282end
83+
84+ # ============================================================================
85+ # Tile-indexed atomic operations (scatter-gather style indexing)
86+ # These accept Tile indices to perform atomic operations on multiple elements.
87+ # ============================================================================
88+
89+ # --- Pointer/mask helpers (same pattern as gather/scatter in operations.jl) ---
90+
91+ @inline function _atomic_ptrs_mask (array:: TileArray{T, 1} , indices:: Tile{I} ) where {T, I <: Integer }
92+ indices_0 = indices .- one (I)
93+ indices_i32 = convert (Tile{Int32}, indices_0)
94+ ptr_tile = Intrinsics. offset (array. ptr, indices_i32)
95+ zero_0d = Tile (Int32 (0 ))
96+ size_0d = Tile (size (array, 1 ))
97+ mask = (indices_i32 .>= zero_0d) .& (indices_i32 .< size_0d)
98+ (ptr_tile, mask, size (indices))
99+ end
100+
101+ @inline function _atomic_ptrs_mask (array:: TileArray{T, 2} ,
102+ indices:: Tuple{Tile{I0}, Tile{I1}} ) where {T, I0 <: Integer , I1 <: Integer }
103+ idx0_0 = indices[1 ] .- one (I0)
104+ idx1_0 = indices[2 ] .- one (I1)
105+
106+ S = broadcast_shape (size (indices[1 ]), size (indices[2 ]))
107+ idx0_bc = broadcast_to (idx0_0, S)
108+ idx1_bc = broadcast_to (idx1_0, S)
109+
110+ idx0_i32 = convert (Tile{Int32}, idx0_bc)
111+ idx1_i32 = convert (Tile{Int32}, idx1_bc)
112+
113+ stride0_0d = Tile (array. strides[1 ])
114+ stride1_0d = Tile (array. strides[2 ])
115+ stride0 = broadcast_to (stride0_0d, S)
116+ stride1 = broadcast_to (stride1_0d, S)
117+
118+ linear_idx = idx0_i32 .* stride0 + idx1_i32 .* stride1
119+ ptr_tile = Intrinsics. offset (array. ptr, linear_idx)
120+
121+ zero_0d = Tile (Int32 (0 ))
122+ zero_bc = broadcast_to (zero_0d, S)
123+ size0_bc = broadcast_to (Tile (size (array, 1 )), S)
124+ size1_bc = broadcast_to (Tile (size (array, 2 )), S)
125+
126+ mask0 = (idx0_i32 .>= zero_bc) .& (idx0_i32 .< size0_bc)
127+ mask1 = (idx1_i32 .>= zero_bc) .& (idx1_i32 .< size1_bc)
128+ mask = mask0 .& mask1
129+
130+ (ptr_tile, mask, S)
131+ end
132+
133+ # --- RMW operations (atomic_add, atomic_xchg) ---
134+
135+ const _ATOMIC_RMW_OPS = (
136+ (:add , :atomic_add_tile ),
137+ (:xchg , :atomic_xchg_tile ),
138+ )
139+
140+ for (op, intrinsic) in _ATOMIC_RMW_OPS
141+ fname = Symbol (:atomic_ , op)
142+
143+ # 1D with scalar value
144+ @eval @inline function $fname (array:: TileArray{T, 1} , indices:: Tile{I} , val:: T ;
145+ memory_order:: Int = MemoryOrder. AcqRel,
146+ memory_scope:: Int = MemScope. Device) where {T, I <: Integer }
147+ ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
148+ val_tile = broadcast_to (Tile (val), S)
149+ Intrinsics.$ intrinsic (ptr_tile, val_tile, mask, memory_order, memory_scope)
150+ end
151+
152+ # 1D with tile value
153+ @eval @inline function $fname (array:: TileArray{T, 1} , indices:: Tile{I} , val:: Tile{T} ;
154+ memory_order:: Int = MemoryOrder. AcqRel,
155+ memory_scope:: Int = MemScope. Device) where {T, I <: Integer }
156+ ptr_tile, mask, _ = _atomic_ptrs_mask (array, indices)
157+ Intrinsics.$ intrinsic (ptr_tile, val, mask, memory_order, memory_scope)
158+ end
159+
160+ # 2D with scalar value
161+ @eval @inline function $fname (array:: TileArray{T, 2} ,
162+ indices:: Tuple{Tile{I0}, Tile{I1}} , val:: T ;
163+ memory_order:: Int = MemoryOrder. AcqRel,
164+ memory_scope:: Int = MemScope. Device) where {T, I0 <: Integer , I1 <: Integer }
165+ ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
166+ val_tile = broadcast_to (Tile (val), S)
167+ Intrinsics.$ intrinsic (ptr_tile, val_tile, mask, memory_order, memory_scope)
168+ end
169+
170+ # 2D with tile value
171+ @eval @inline function $fname (array:: TileArray{T, 2} ,
172+ indices:: Tuple{Tile{I0}, Tile{I1}} , val:: Tile{T} ;
173+ memory_order:: Int = MemoryOrder. AcqRel,
174+ memory_scope:: Int = MemScope. Device) where {T, I0 <: Integer , I1 <: Integer }
175+ ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
176+ val_bc = broadcast_to (val, S)
177+ Intrinsics.$ intrinsic (ptr_tile, val_bc, mask, memory_order, memory_scope)
178+ end
179+ end
180+
181+ # --- CAS operations (separate due to different signature) ---
182+
183+ # 1D with scalar expected/desired
184+ @inline function atomic_cas (array:: TileArray{T, 1} , indices:: Tile{I} ,
185+ expected:: T , desired:: T ;
186+ memory_order:: Int = MemoryOrder. AcqRel,
187+ memory_scope:: Int = MemScope. Device) where {T, I <: Integer }
188+ ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
189+ expected_tile = broadcast_to (Tile (expected), S)
190+ desired_tile = broadcast_to (Tile (desired), S)
191+ Intrinsics. atomic_cas_tile (ptr_tile, expected_tile, desired_tile, mask,
192+ memory_order, memory_scope)
193+ end
194+
195+ # 1D with tile expected/desired
196+ @inline function atomic_cas (array:: TileArray{T, 1} , indices:: Tile{I} ,
197+ expected:: Tile{T} , desired:: Tile{T} ;
198+ memory_order:: Int = MemoryOrder. AcqRel,
199+ memory_scope:: Int = MemScope. Device) where {T, I <: Integer }
200+ ptr_tile, mask, _ = _atomic_ptrs_mask (array, indices)
201+ Intrinsics. atomic_cas_tile (ptr_tile, expected, desired, mask,
202+ memory_order, memory_scope)
203+ end
204+
205+ # 2D with scalar expected/desired
206+ @inline function atomic_cas (array:: TileArray{T, 2} ,
207+ indices:: Tuple{Tile{I0}, Tile{I1}} ,
208+ expected:: T , desired:: T ;
209+ memory_order:: Int = MemoryOrder. AcqRel,
210+ memory_scope:: Int = MemScope. Device) where {T, I0 <: Integer , I1 <: Integer }
211+ ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
212+ expected_tile = broadcast_to (Tile (expected), S)
213+ desired_tile = broadcast_to (Tile (desired), S)
214+ Intrinsics. atomic_cas_tile (ptr_tile, expected_tile, desired_tile, mask,
215+ memory_order, memory_scope)
216+ end
217+
218+ # 2D with tile expected/desired
219+ @inline function atomic_cas (array:: TileArray{T, 2} ,
220+ indices:: Tuple{Tile{I0}, Tile{I1}} ,
221+ expected:: Tile{T} , desired:: Tile{T} ;
222+ memory_order:: Int = MemoryOrder. AcqRel,
223+ memory_scope:: Int = MemScope. Device) where {T, I0 <: Integer , I1 <: Integer }
224+ ptr_tile, mask, S = _atomic_ptrs_mask (array, indices)
225+ expected_bc = broadcast_to (expected, S)
226+ desired_bc = broadcast_to (desired, S)
227+ Intrinsics. atomic_cas_tile (ptr_tile, expected_bc, desired_bc, mask,
228+ memory_order, memory_scope)
229+ end
0 commit comments