@@ -30,114 +30,113 @@ function memory_scope_to_scope(scope::Int)
3030 end
3131end
3232
33+ """
34+ atomic_tfunc(ptrs) -> Type
35+
36+ Shared tfunc for atomic operations (add, xchg, cas).
37+ Always returns Tile{T, S}, even for 0D (S = Tuple{}).
38+ """
39+ function atomic_tfunc (𝕃, @nospecialize (ptrs), @nospecialize args... )
40+ ptrs_type = CC. widenconst (ptrs)
41+ ptrs_type isa DataType && ptrs_type <: Tile || return nothing
42+ ptr_type = eltype (ptrs_type)
43+ ptr_type <: Ptr || return nothing
44+ T = eltype (ptr_type)
45+ S = ptrs_type. parameters[2 ]
46+ return Tile{T, S}
47+ end
48+
3349# cuda_tile.atomic_cas_tko
34- @intrinsic atomic_cas (array, index, expected, desired,
35- memory_order, memory_scope)
36- tfunc (𝕃, :: typeof (Intrinsics. atomic_cas), @nospecialize (array), @nospecialize args... ) = eltype (CC. widenconst (array))
50+ @intrinsic atomic_cas (ptr_tile, expected, desired, mask, memory_order, memory_scope)
51+ function tfunc (𝕃, :: typeof (Intrinsics. atomic_cas), @nospecialize (ptrs), @nospecialize args... )
52+ atomic_tfunc (𝕃, ptrs, args... )
53+ end
3754efunc (:: typeof (Intrinsics. atomic_cas), effects:: CC.Effects ) =
3855 CC. Effects (effects; effect_free= CC. ALWAYS_FALSE)
3956function emit_intrinsic! (ctx:: CGCtx , :: typeof (Intrinsics. atomic_cas), args)
4057 cb = ctx. cb
4158 tt = ctx. tt
4259
43- # args: (array, index, expected, desired, memory_order, memory_scope)
44- array_arg = args[1 ]
60+ # args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
61+ ptr_tv = emit_value! (ctx, args[1 ])
62+ ptr_tv === nothing && throw (IRError (" atomic CAS requires ptr_tile" ))
63+ expected_tv = emit_value! (ctx, args[2 ])
64+ expected_tv === nothing && throw (IRError (" atomic CAS requires expected value" ))
65+ desired_tv = emit_value! (ctx, args[3 ])
66+ desired_tv === nothing && throw (IRError (" atomic CAS requires desired value" ))
4567
46- # Get array info
47- arg_idx = extract_argument_index (array_arg)
48- is_tilearray = arg_idx != = nothing && is_destructured_arg (ctx, arg_idx)
68+ # Check if mask is provided (ghost Nothing = no mask)
69+ has_mask = get_constant (ctx, args[4 ]) != = nothing
4970
50- if ! is_tilearray
51- throw (IRError (" atomic_cas requires a TileArray argument" ))
52- end
71+ memory_order = @something get_constant (ctx, args[5 ]) throw (IRError (" atomic CAS requires constant memory_order" ))
72+ memory_scope = @something get_constant (ctx, args[6 ]) throw (IRError (" atomic CAS requires constant memory_scope" ))
5373
54- ptr_vals = get_arg_flat_values (ctx, arg_idx, :ptr )
55- isempty (ptr_vals) && throw (IRError (" Cannot get ptr from TileArray argument" ))
56- array_val = ptr_vals[1 ]
57- tilearray_type = get_arg_type (ctx, arg_idx)
58- elem_type = eltype (tilearray_type)
74+ shape = ptr_tv. shape
5975
60- # Get expected and desired values
61- expected_tv = emit_value! (ctx, args[3 ])
62- expected_tv === nothing && throw (IRError (" atomic_cas requires expected value" ))
63- desired_tv = emit_value! (ctx, args[4 ])
64- desired_tv === nothing && throw (IRError (" atomic_cas requires desired value" ))
76+ # Get element type from pointer tile: Tile{Ptr{T}, S} -> T
77+ ptrs_type = CC. widenconst (ptr_tv. jltype)
78+ ptr_type = eltype (ptrs_type)
79+ elem_type = eltype (ptr_type)
6580
66- # Get memory order and scope from args
67- memory_order = @something get_constant (ctx, args[5 ]) throw (IRError (" atomic_cas requires constant memory_order" ))
68- memory_scope = @something get_constant (ctx, args[6 ]) throw (IRError (" atomic_cas requires constant memory_scope" ))
69-
70- # Create result type (0D tile of element type)
7181 dtype = julia_to_tile_dtype! (tt, elem_type)
72- result_tile_type = tile_type! (tt, dtype, Int[] )
82+ result_tile_type = tile_type! (tt, dtype, collect (shape) )
7383 token_type = Token (tt)
7484
75- # Get index and create pointer type
76- index_tv = emit_value! (ctx, args[2 ])
77- ptr_type = pointer_type! (tt, dtype)
78- ptr_tile_type = tile_type! (tt, ptr_type, Int[])
79-
80- # Compute pointer using OffsetOp (handles any integer index type)
81- pointers = encode_OffsetOp! (cb, ptr_tile_type, array_val, index_tv. v)
82-
8385 # Emit atomic CAS
8486 mem_ordering = memory_order_to_semantics (memory_order)
8587 mem_scope = memory_scope_to_scope (memory_scope)
8688
87- old_val, new_token = encode_AtomicCASPtrOp! (cb, result_tile_type, token_type, pointers,
88- expected_tv. v, desired_tv. v;
89- token= ctx. token,
90- memory_ordering= mem_ordering,
91- memory_scope= mem_scope)
89+ old_val, new_token = if has_mask
90+ mask_tv = emit_value! (ctx, args[4 ])
91+ mask_tv === nothing && throw (IRError (" atomic CAS: cannot resolve mask" ))
92+ encode_AtomicCASPtrOp! (cb, result_tile_type, token_type,
93+ ptr_tv. v, expected_tv. v, desired_tv. v;
94+ mask= mask_tv. v,
95+ token= ctx. token,
96+ memory_ordering= mem_ordering,
97+ memory_scope= mem_scope)
98+ else
99+ encode_AtomicCASPtrOp! (cb, result_tile_type, token_type,
100+ ptr_tv. v, expected_tv. v, desired_tv. v;
101+ token= ctx. token,
102+ memory_ordering= mem_ordering,
103+ memory_scope= mem_scope)
104+ end
92105 ctx. token = new_token
93106
94- # Return scalar type (not Tile) to match the intrinsic signature
95- CGVal (old_val, result_tile_type, elem_type, Int[])
107+ CGVal (old_val, result_tile_type, Tile{elem_type, Tuple{shape... }}, collect (shape))
96108end
97109
98110# cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations)
99111function emit_atomic_rmw! (ctx:: CGCtx , args:: AbstractVector , mode:: AtomicRMWMode )
100112 cb = ctx. cb
101113 tt = ctx. tt
102114
103- # args: (array, index, val, memory_order, memory_scope)
104- array_arg = args[1 ]
115+ # args: (ptr_tile, val, mask, memory_order, memory_scope)
116+ ptr_tv = emit_value! (ctx, args[1 ])
117+ ptr_tv === nothing && throw (IRError (" atomic RMW requires ptr_tile" ))
118+ val_tv = emit_value! (ctx, args[2 ])
119+ val_tv === nothing && throw (IRError (" atomic RMW requires value" ))
105120
106- # Get array info
107- arg_idx = extract_argument_index (array_arg)
108- is_tilearray = arg_idx != = nothing && is_destructured_arg (ctx, arg_idx)
121+ # Check if mask is provided (ghost Nothing = no mask)
122+ has_mask = get_constant (ctx, args[3 ]) != = nothing
109123
110- if ! is_tilearray
111- throw (IRError (" atomic operations require a TileArray argument " ))
112- end
124+ # Get memory order and scope from args
125+ memory_order = @something get_constant (ctx, args[ 4 ]) throw (IRError (" atomic RMW requires constant memory_order " ))
126+ memory_scope = @something get_constant (ctx, args[ 5 ]) throw ( IRError ( " atomic RMW requires constant memory_scope " ))
113127
114- ptr_vals = get_arg_flat_values (ctx, arg_idx, :ptr )
115- isempty (ptr_vals) && throw (IRError (" Cannot get ptr from TileArray argument" ))
116- array_val = ptr_vals[1 ]
117- tilearray_type = get_arg_type (ctx, arg_idx)
118- elem_type = eltype (tilearray_type)
128+ shape = ptr_tv. shape
119129
120- # Get update value
121- val_tv = emit_value! (ctx, args[3 ])
122- val_tv === nothing && throw (IRError (" atomic operation requires value" ))
130+ # Get element type from pointer tile: Tile{Ptr{T}, S} -> T
131+ ptrs_type = CC. widenconst (ptr_tv. jltype)
132+ ptr_type = eltype (ptrs_type)
133+ elem_type = eltype (ptr_type)
123134
124- # Get memory order and scope from args
125- memory_order = @something get_constant (ctx, args[4 ]) throw (IRError (" atomic operation requires constant memory_order" ))
126- memory_scope = @something get_constant (ctx, args[5 ]) throw (IRError (" atomic operation requires constant memory_scope" ))
127-
128- # Create result type (0D tile of element type)
135+ # Create result type
129136 dtype = julia_to_tile_dtype! (tt, elem_type)
130- result_tile_type = tile_type! (tt, dtype, Int[] )
137+ result_tile_type = tile_type! (tt, dtype, collect (shape) )
131138 token_type = Token (tt)
132139
133- # Get index and create pointer type
134- index_tv = emit_value! (ctx, args[2 ])
135- ptr_type = pointer_type! (tt, dtype)
136- ptr_tile_type = tile_type! (tt, ptr_type, Int[])
137-
138- # Compute pointer using OffsetOp (handles any integer index type)
139- pointers = encode_OffsetOp! (cb, ptr_tile_type, array_val, index_tv. v)
140-
141140 # Use float add mode for floating point types
142141 actual_mode = mode
143142 if mode == AtomicADD && elem_type <: AbstractFloat
@@ -148,30 +147,39 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
148147 mem_ordering = memory_order_to_semantics (memory_order)
149148 mem_scope = memory_scope_to_scope (memory_scope)
150149
151- old_val, new_token = encode_AtomicRMWPtrOp! (cb, result_tile_type, token_type, pointers,
152- val_tv. v, actual_mode;
153- token= ctx. token,
154- memory_ordering= mem_ordering,
155- memory_scope= mem_scope)
150+ old_val, new_token = if has_mask
151+ mask_tv = emit_value! (ctx, args[3 ])
152+ mask_tv === nothing && throw (IRError (" atomic RMW: cannot resolve mask" ))
153+ encode_AtomicRMWPtrOp! (cb, result_tile_type, token_type,
154+ ptr_tv. v, val_tv. v, actual_mode;
155+ mask= mask_tv. v,
156+ token= ctx. token,
157+ memory_ordering= mem_ordering,
158+ memory_scope= mem_scope)
159+ else
160+ encode_AtomicRMWPtrOp! (cb, result_tile_type, token_type,
161+ ptr_tv. v, val_tv. v, actual_mode;
162+ token= ctx. token,
163+ memory_ordering= mem_ordering,
164+ memory_scope= mem_scope)
165+ end
156166 ctx. token = new_token
157167
158- # Return scalar type (not Tile) to match the intrinsic signature
159- CGVal (old_val, result_tile_type, elem_type, Int[])
168+ CGVal (old_val, result_tile_type, Tile{elem_type, Tuple{shape... }}, collect (shape))
160169end
161170
162171# cuda_tile.atomic_rmw_tko with XCHG
163- @intrinsic atomic_xchg (array, index, val , memory_order, memory_scope)
164- tfunc (𝕃, :: typeof (Intrinsics. atomic_xchg), @nospecialize (array), @nospecialize args... ) = eltype (CC . widenconst (array) )
172+ @intrinsic atomic_xchg (ptr_tile, val, mask , memory_order, memory_scope)
173+ tfunc (𝕃, :: typeof (Intrinsics. atomic_xchg), @nospecialize args... ) = atomic_tfunc (𝕃, args ... )
165174efunc (:: typeof (Intrinsics. atomic_xchg), effects:: CC.Effects ) =
166175 CC. Effects (effects; effect_free= CC. ALWAYS_FALSE)
167176function emit_intrinsic! (ctx:: CGCtx , :: typeof (Intrinsics. atomic_xchg), args)
168177 emit_atomic_rmw! (ctx, args, AtomicXCHG)
169178end
170179
171180# cuda_tile.atomic_rmw_tko with ADD
172- @intrinsic atomic_add (array, index, val,
173- memory_order, memory_scope)
174- tfunc (𝕃, :: typeof (Intrinsics. atomic_add), @nospecialize (array), @nospecialize args... ) = eltype (CC. widenconst (array))
181+ @intrinsic atomic_add (ptr_tile, val, mask, memory_order, memory_scope)
182+ tfunc (𝕃, :: typeof (Intrinsics. atomic_add), @nospecialize args... ) = atomic_tfunc (𝕃, args... )
175183efunc (:: typeof (Intrinsics. atomic_add), effects:: CC.Effects ) =
176184 CC. Effects (effects; effect_free= CC. ALWAYS_FALSE)
177185function emit_intrinsic! (ctx:: CGCtx , :: typeof (Intrinsics. atomic_add), args)
0 commit comments