@@ -30,114 +30,119 @@ function memory_scope_to_scope(scope::Int)
3030 end
3131end
3232
"""
    atomic_tfunc(𝕃, ptrs, args...) -> Union{Type, Nothing}

Return-type function shared by the atomic intrinsics (add, xchg, cas).

`ptrs` is widened with `CC.widenconst`. When the widened type is a `DataType`
that is a subtype of `Tile` and whose `eltype` is a `Ptr`, the result is built
from the pointee type `T` and the tile's second type parameter `S`: raw `T`
when `S === Tuple{}` (0D pointer tile), `Tile{T, S}` otherwise (N-D).
In every other case `nothing` is returned.

The lattice `𝕃` and the trailing `args` are accepted but not inspected.
"""
function atomic_tfunc(𝕃, @nospecialize(ptrs), @nospecialize args...)
    tile_type = CC.widenconst(ptrs)
    if !(tile_type isa DataType && tile_type <: Tile)
        return nothing
    end

    pointer_type = eltype(tile_type)
    if !(pointer_type <: Ptr)
        return nothing
    end

    pointee = eltype(pointer_type)
    tile_shape = tile_type.parameters[2]
    return tile_shape === Tuple{} ? pointee : Tile{pointee, tile_shape}
end
3350# cuda_tile.atomic_cas_tko
@intrinsic atomic_cas(ptr_tile, expected, desired, mask, memory_order, memory_scope)

# The result type of atomic_cas is computed from the pointer tile by the shared
# atomic_tfunc helper; the remaining arguments are forwarded unchanged.
tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(ptrs), @nospecialize args...) =
    atomic_tfunc(𝕃, ptrs, args...)
# Effects override for atomic_cas: keep the incoming effects but force
# `effect_free` to `CC.ALWAYS_FALSE`.
function efunc(::typeof(Intrinsics.atomic_cas), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end
# Code generation for Intrinsics.atomic_cas.
#
# args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
#
# Emits the pointer tile, expected and desired operands (in that order), then a
# single AtomicCASPtrOp whose `mask` keyword is passed only when a mask is
# present. `ctx.token` is replaced by the token produced by the op.
#
# Throws IRError when an operand (or a present mask) cannot be emitted, or when
# memory_order / memory_scope are not constants.
#
# Return type depends on shape: raw T for 0D, Tile{T, S} for N-D.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
    cb = ctx.cb
    tt = ctx.tt

    ptrs = emit_value!(ctx, args[1])
    if ptrs === nothing
        throw(IRError("atomic CAS requires ptr_tile"))
    end
    expected = emit_value!(ctx, args[2])
    if expected === nothing
        throw(IRError("atomic CAS requires expected value"))
    end
    desired = emit_value!(ctx, args[3])
    if desired === nothing
        throw(IRError("atomic CAS requires desired value"))
    end

    # Check if mask is provided (ghost Nothing = no mask)
    masked = !isnothing(get_constant(ctx, args[4]))

    order = @something(
        get_constant(ctx, args[5]),
        throw(IRError("atomic CAS requires constant memory_order")),
    )
    scope = @something(
        get_constant(ctx, args[6]),
        throw(IRError("atomic CAS requires constant memory_scope")),
    )

    shape = ptrs.shape

    # Element type comes from the pointer tile: Tile{Ptr{T}, S} -> T
    elem_type = eltype(eltype(CC.widenconst(ptrs.jltype)))

    dtype = julia_to_tile_dtype!(tt, elem_type)
    result_tile_type = tile_type!(tt, dtype, collect(shape))
    token_type = Token(tt)

    semantics = memory_order_to_semantics(order)
    scope_kind = memory_scope_to_scope(scope)

    # The mask operand is emitted (and validated) only when one was supplied;
    # otherwise the `mask` keyword is omitted from the call entirely.
    mask_kw = if masked
        mask = emit_value!(ctx, args[4])
        if mask === nothing
            throw(IRError("atomic CAS: cannot resolve mask"))
        end
        (; mask=mask.v)
    else
        NamedTuple()
    end

    # ctx.token is read here, after any mask emission, as in the two-branch form.
    old_val, new_token = encode_AtomicCASPtrOp!(
        cb, result_tile_type, token_type, ptrs.v, expected.v, desired.v;
        mask_kw...,
        token=ctx.token,
        memory_ordering=semantics,
        memory_scope=scope_kind,
    )
    ctx.token = new_token

    if isempty(shape)
        return CGVal(old_val, result_tile_type, elem_type, Int[])
    end
    return CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
end
97115
98116# cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations)
99117function emit_atomic_rmw! (ctx:: CGCtx , args:: AbstractVector , mode:: AtomicRMWMode )
100118 cb = ctx. cb
101119 tt = ctx. tt
102120
103- # args: (array, index, val, memory_order, memory_scope)
104- array_arg = args[1 ]
105-
106- # Get array info
107- arg_idx = extract_argument_index (array_arg)
108- is_tilearray = arg_idx != = nothing && is_destructured_arg (ctx, arg_idx)
121+ # args: (ptr_tile, val, mask, memory_order, memory_scope)
122+ ptr_tv = emit_value! (ctx, args[1 ])
123+ ptr_tv === nothing && throw (IRError (" atomic RMW requires ptr_tile" ))
124+ val_tv = emit_value! (ctx, args[2 ])
125+ val_tv === nothing && throw (IRError (" atomic RMW requires value" ))
109126
110- if ! is_tilearray
111- throw (IRError (" atomic operations require a TileArray argument" ))
112- end
127+ # Check if mask is provided (ghost Nothing = no mask)
128+ has_mask = get_constant (ctx, args[3 ]) != = nothing
113129
114- ptr_vals = get_arg_flat_values (ctx, arg_idx, :ptr )
115- isempty (ptr_vals) && throw (IRError (" Cannot get ptr from TileArray argument" ))
116- array_val = ptr_vals[1 ]
117- tilearray_type = get_arg_type (ctx, arg_idx)
118- elem_type = eltype (tilearray_type)
130+ # Get memory order and scope from args
131+ memory_order = @something get_constant (ctx, args[4 ]) throw (IRError (" atomic RMW requires constant memory_order" ))
132+ memory_scope = @something get_constant (ctx, args[5 ]) throw (IRError (" atomic RMW requires constant memory_scope" ))
119133
120- # Get update value
121- val_tv = emit_value! (ctx, args[3 ])
122- val_tv === nothing && throw (IRError (" atomic operation requires value" ))
134+ shape = ptr_tv. shape
123135
124- # Get memory order and scope from args
125- memory_order = @something get_constant (ctx, args[4 ]) throw (IRError (" atomic operation requires constant memory_order" ))
126- memory_scope = @something get_constant (ctx, args[5 ]) throw (IRError (" atomic operation requires constant memory_scope" ))
136+ # Get element type from pointer tile: Tile{Ptr{T}, S} -> T
137+ ptrs_type = CC. widenconst (ptr_tv. jltype)
138+ ptr_type = eltype (ptrs_type)
139+ elem_type = eltype (ptr_type)
127140
128- # Create result type (0D tile of element type)
141+ # Create result type
129142 dtype = julia_to_tile_dtype! (tt, elem_type)
130- result_tile_type = tile_type! (tt, dtype, Int[] )
143+ result_tile_type = tile_type! (tt, dtype, collect (shape) )
131144 token_type = Token (tt)
132145
133- # Get index and create pointer type
134- index_tv = emit_value! (ctx, args[2 ])
135- ptr_type = pointer_type! (tt, dtype)
136- ptr_tile_type = tile_type! (tt, ptr_type, Int[])
137-
138- # Compute pointer using OffsetOp (handles any integer index type)
139- pointers = encode_OffsetOp! (cb, ptr_tile_type, array_val, index_tv. v)
140-
141146 # Use float add mode for floating point types
142147 actual_mode = mode
143148 if mode == AtomicADD && elem_type <: AbstractFloat
@@ -148,148 +153,46 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
148153 mem_ordering = memory_order_to_semantics (memory_order)
149154 mem_scope = memory_scope_to_scope (memory_scope)
150155
151- old_val, new_token = encode_AtomicRMWPtrOp! (cb, result_tile_type, token_type, pointers,
152- val_tv. v, actual_mode;
153- token= ctx. token,
154- memory_ordering= mem_ordering,
155- memory_scope= mem_scope)
156+ if has_mask
157+ mask_tv = emit_value! (ctx, args[3 ])
158+ mask_tv === nothing && throw (IRError (" atomic RMW: cannot resolve mask" ))
159+ old_val, new_token = encode_AtomicRMWPtrOp! (cb, result_tile_type, token_type,
160+ ptr_tv. v, val_tv. v, actual_mode;
161+ mask= mask_tv. v,
162+ token= ctx. token,
163+ memory_ordering= mem_ordering,
164+ memory_scope= mem_scope)
165+ else
166+ old_val, new_token = encode_AtomicRMWPtrOp! (cb, result_tile_type, token_type,
167+ ptr_tv. v, val_tv. v, actual_mode;
168+ token= ctx. token,
169+ memory_ordering= mem_ordering,
170+ memory_scope= mem_scope)
171+ end
156172 ctx. token = new_token
157173
158- # Return scalar type (not Tile) to match the intrinsic signature
159- CGVal (old_val, result_tile_type, elem_type, Int[])
174+ # Return type depends on shape: raw T for 0D, Tile{T, S} for N-D
175+ if isempty (shape)
176+ CGVal (old_val, result_tile_type, elem_type, Int[])
177+ else
178+ CGVal (old_val, result_tile_type, Tile{elem_type, Tuple{shape... }}, collect (shape))
179+ end
160180end
161181
162182# cuda_tile.atomic_rmw_tko with XCHG
@intrinsic atomic_xchg(ptr_tile, val, mask, memory_order, memory_scope)

# The result type of atomic_xchg is computed by the shared atomic_tfunc helper;
# every argument after the lattice is forwarded unchanged.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize args...)
    return atomic_tfunc(𝕃, args...)
end
# Effects override for atomic_xchg: keep the incoming effects but force
# `effect_free` to `CC.ALWAYS_FALSE`.
function efunc(::typeof(Intrinsics.atomic_xchg), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end
# Code generation for atomic_xchg: delegate to the shared RMW emitter in
# AtomicXCHG mode.
emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg), args) =
    emit_atomic_rmw!(ctx, args, AtomicXCHG)
170190
171191# cuda_tile.atomic_rmw_tko with ADD
@intrinsic atomic_add(ptr_tile, val, mask, memory_order, memory_scope)

# The result type of atomic_add is computed by the shared atomic_tfunc helper;
# every argument after the lattice is forwarded unchanged.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize args...)
    return atomic_tfunc(𝕃, args...)
end
# Effects override for atomic_add: keep the incoming effects but force
# `effect_free` to `CC.ALWAYS_FALSE`.
function efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects)
    return CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
end
# Code generation for atomic_add: delegate to the shared RMW emitter in
# AtomicADD mode.
emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args) =
    emit_atomic_rmw!(ctx, args, AtomicADD)
180-
181- # ============================================================================
182- # Tile-indexed atomic operations
183- # These take pre-computed pointer tiles, value tiles, and masks.
184- # Used by the public API for tile-indexed atomic operations.
185- # ============================================================================
186-
187- # Shared codegen helper for tile-indexed atomic RMW operations
188- function emit_atomic_rmw_tile! (ctx:: CGCtx , args:: AbstractVector , mode:: AtomicRMWMode )
189- cb = ctx. cb
190- tt = ctx. tt
191-
192- # args: (ptr_tile, val, mask, memory_order, memory_scope)
193- ptr_tv = emit_value! (ctx, args[1 ])
194- ptr_tv === nothing && throw (IRError (" tile-indexed atomic RMW requires ptr_tile" ))
195- val_tv = emit_value! (ctx, args[2 ])
196- val_tv === nothing && throw (IRError (" tile-indexed atomic RMW requires value" ))
197- mask_tv = emit_value! (ctx, args[3 ])
198- mask_tv === nothing && throw (IRError (" tile-indexed atomic RMW requires mask" ))
199-
200- memory_order = @something get_constant (ctx, args[4 ]) throw (IRError (" tile-indexed atomic RMW requires constant memory_order" ))
201- memory_scope = @something get_constant (ctx, args[5 ]) throw (IRError (" tile-indexed atomic RMW requires constant memory_scope" ))
202-
203- shape = val_tv. shape
204- elem_type = eltype (val_tv. jltype)
205-
206- dtype = julia_to_tile_dtype! (tt, elem_type)
207- result_tile_type = tile_type! (tt, dtype, collect (shape))
208- token_type = Token (tt)
209-
210- # Auto-promote integer ADD to float ADD for floating-point types
211- actual_mode = mode
212- if mode == AtomicADD && elem_type <: AbstractFloat
213- actual_mode = AtomicADDF
214- end
215-
216- mem_ordering = memory_order_to_semantics (memory_order)
217- mem_scope = memory_scope_to_scope (memory_scope)
218-
219- old_val, new_token = encode_AtomicRMWPtrOp! (cb, result_tile_type, token_type,
220- ptr_tv. v, val_tv. v, actual_mode;
221- mask= mask_tv. v,
222- token= ctx. token,
223- memory_ordering= mem_ordering,
224- memory_scope= mem_scope)
225- ctx. token = new_token
226-
227- CGVal (old_val, result_tile_type, Tile{elem_type, Tuple{shape... }}, collect (shape))
228- end
229-
230- # Tile-indexed atomic exchange
231- @intrinsic atomic_xchg_tile (ptr_tile, val, mask, memory_order, memory_scope)
232- function tfunc (𝕃, :: typeof (Intrinsics. atomic_xchg_tile), @nospecialize (ptrs), @nospecialize (val), @nospecialize args... )
233- CC. widenconst (val)
234- end
235- efunc (:: typeof (Intrinsics. atomic_xchg_tile), effects:: CC.Effects ) =
236- CC. Effects (effects; effect_free= CC. ALWAYS_FALSE)
237- function emit_intrinsic! (ctx:: CGCtx , :: typeof (Intrinsics. atomic_xchg_tile), args)
238- emit_atomic_rmw_tile! (ctx, args, AtomicXCHG)
239- end
240-
241- # Tile-indexed atomic addition
242- @intrinsic atomic_add_tile (ptr_tile, val, mask, memory_order, memory_scope)
243- function tfunc (𝕃, :: typeof (Intrinsics. atomic_add_tile), @nospecialize (ptrs), @nospecialize (val), @nospecialize args... )
244- CC. widenconst (val)
245- end
246- efunc (:: typeof (Intrinsics. atomic_add_tile), effects:: CC.Effects ) =
247- CC. Effects (effects; effect_free= CC. ALWAYS_FALSE)
248- function emit_intrinsic! (ctx:: CGCtx , :: typeof (Intrinsics. atomic_add_tile), args)
249- emit_atomic_rmw_tile! (ctx, args, AtomicADD)
250- end
251-
252- # Tile-indexed atomic compare-and-swap
253- @intrinsic atomic_cas_tile (ptr_tile, expected, desired, mask, memory_order, memory_scope)
254- function tfunc (𝕃, :: typeof (Intrinsics. atomic_cas_tile), @nospecialize (ptrs), @nospecialize (expected), @nospecialize args... )
255- CC. widenconst (expected)
256- end
257- efunc (:: typeof (Intrinsics. atomic_cas_tile), effects:: CC.Effects ) =
258- CC. Effects (effects; effect_free= CC. ALWAYS_FALSE)
259- function emit_intrinsic! (ctx:: CGCtx , :: typeof (Intrinsics. atomic_cas_tile), args)
260- cb = ctx. cb
261- tt = ctx. tt
262-
263- # args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
264- ptr_tv = emit_value! (ctx, args[1 ])
265- ptr_tv === nothing && throw (IRError (" tile-indexed atomic CAS requires ptr_tile" ))
266- expected_tv = emit_value! (ctx, args[2 ])
267- expected_tv === nothing && throw (IRError (" tile-indexed atomic CAS requires expected value" ))
268- desired_tv = emit_value! (ctx, args[3 ])
269- desired_tv === nothing && throw (IRError (" tile-indexed atomic CAS requires desired value" ))
270- mask_tv = emit_value! (ctx, args[4 ])
271- mask_tv === nothing && throw (IRError (" tile-indexed atomic CAS requires mask" ))
272-
273- memory_order = @something get_constant (ctx, args[5 ]) throw (IRError (" tile-indexed atomic CAS requires constant memory_order" ))
274- memory_scope = @something get_constant (ctx, args[6 ]) throw (IRError (" tile-indexed atomic CAS requires constant memory_scope" ))
275-
276- shape = expected_tv. shape
277- elem_type = eltype (expected_tv. jltype)
278-
279- dtype = julia_to_tile_dtype! (tt, elem_type)
280- result_tile_type = tile_type! (tt, dtype, collect (shape))
281- token_type = Token (tt)
282-
283- mem_ordering = memory_order_to_semantics (memory_order)
284- mem_scope = memory_scope_to_scope (memory_scope)
285-
286- old_val, new_token = encode_AtomicCASPtrOp! (cb, result_tile_type, token_type,
287- ptr_tv. v, expected_tv. v, desired_tv. v;
288- mask= mask_tv. v,
289- token= ctx. token,
290- memory_ordering= mem_ordering,
291- memory_scope= mem_scope)
292- ctx. token = new_token
293-
294- CGVal (old_val, result_tile_type, Tile{elem_type, Tuple{shape... }}, collect (shape))
295- end
0 commit comments