11# Broadcasting Infrastructure for Tiles
22#
33# Defines the broadcast style and shape computation for Tile types.
4- # All broadcasted operations are materialized via copy → map .
4+ # All broadcasted operations are materialized via copy.
55
66import Base. Broadcast: BroadcastStyle, Broadcasted, broadcastable, broadcast_shape
77
@@ -24,26 +24,38 @@ Base.Broadcast.broadcastable(t::Tile) = t
2424
2525
2626#= ============================================================================
27- Broadcast materialization via copy + map
27+ Ghost wrapper for Type values in broadcasting
28+ =============================================================================#
29+
30+ # Replaces Julia's RefValue{Type{T}} wrapping which the cuTile compiler can't construct.
31+ # The value is encoded in the type parameter — no runtime representation needed.
32+ struct TypeRef{T} end
33+
34+ Base. Broadcast. BroadcastStyle (:: Type{<:TypeRef} ) = Base. Broadcast. DefaultArrayStyle {0} ()
35+ Base. Broadcast. broadcastable (a:: TypeRef ) = a
36+
37+
38+ #= ============================================================================
39+ Broadcast materialization via copy
2840=============================================================================#
2941
3042# Tile is a ghost type with no storage, so axes/size are meaningless.
3143# Skip instantiate (which calls axes) by returning the Broadcasted as-is.
3244@inline Base. Broadcast. instantiate (bc:: Broadcasted{TileStyle} ) = bc
3345
3446# Recursively materialize nested Broadcasted nodes,
35- # promote scalars to Tiles, broadcast to a common shape, then apply via map .
47+ # promote scalars to Tiles, broadcast to a common shape, then apply f .
3648# This handles all element-wise operations: scalar @overlay methods provide
3749# the implementation for overlaid ops, while Julia's native scalar functions
3850# (compiled to Core intrinsics) handle the rest. Mixed-type and type-changing
3951# operations (comparisons, ifelse) are supported by the mixed-type map methods
4052# in operations.jl.
4153@inline function Base. copy (bc:: Broadcasted{TileStyle} )
4254 args = _materialize_args (bc. args)
43- tiles = _promote_to_tiles (args... )
44- S = _broadcast_shapes (tiles ... )
45- broadcasted = _broadcast_all (S, tiles ... )
46- map (bc. f, broadcasted... )
55+ promoted = _promote_to_tiles (args... )
56+ S = _broadcast_shapes (promoted ... )
57+ broadcasted = _broadcast_all (S, promoted ... )
58+ _apply_broadcast (bc. f, broadcasted... )
4759end
4860
4961# Recursively materialize nested Broadcasted nodes into concrete Tiles.
6375# using its own type (e.g., 0.0f0 → Tile(Float32(0.0))), preserving the
6476# type that Julia's broadcast promotion chose. This avoids the pitfall of
6577# using the first Tile's eltype (which could be Bool for ifelse conditions).
78+ # TypeRef arguments pass through unchanged — they carry no tile shape.
6679@inline _promote_to_tiles () = ()
6780@inline _promote_to_tiles (a:: Tile , rest... ) = (a, _promote_to_tiles (rest... )... )
6881@inline _promote_to_tiles (a:: T , rest... ) where {T <: Number } =
6982 (Tile (a), _promote_to_tiles (rest... )... )
83+ @inline _promote_to_tiles (a:: TypeRef , rest... ) = (a, _promote_to_tiles (rest... )... )
7084
7185# Compute combined broadcast shape across all Tile arguments via tuple peeling.
7286# Shape is always a tuple TYPE (e.g., Tuple{16, 32}). Convert to value for broadcast_shape.
87+ # TypeRef arguments are skipped — they have no shape.
7388@inline _tile_shape (t:: Tile ) = size (t)
7489@inline _broadcast_shapes (t:: Tile ) = _tile_shape (t)
75- @inline _broadcast_shapes (t:: Tile , rest:: Tile ... ) =
90+ @inline _broadcast_shapes (t:: Tile , rest... ) =
7691 broadcast_shape (_tile_shape (t), _broadcast_shapes (rest... ))
92+ @inline _broadcast_shapes (:: TypeRef , rest... ) = _broadcast_shapes (rest... )
93+ @inline _broadcast_shapes (:: TypeRef ) = ()
7794
7895# Broadcast all tiles to shape S via tuple peeling.
96+ # TypeRef arguments pass through unchanged.
7997@inline _broadcast_all (S:: Tuple ) = ()
80- @inline _broadcast_all (S:: Tuple , a:: Tile , rest:: Tile ... ) =
98+ @inline _broadcast_all (S:: Tuple , a:: Tile , rest... ) =
8199 (broadcast_to (a, S), _broadcast_all (S, rest... )... )
100+ @inline _broadcast_all (S:: Tuple , a:: TypeRef , rest... ) =
101+ (a, _broadcast_all (S, rest... )... )
102+
103+ # Convert args to scalars, apply f, wrap result back into a Tile.
104+ @inline function _apply_broadcast (f, args... )
105+ scalar_args, S = _to_scalars (args... )
106+ Intrinsics. from_scalar (f (scalar_args... ), S)
107+ end
108+
109+ # Reinterpret Tile arguments as scalars for broadcast application.
110+ # Skip and extract TypeRef arguments.
111+ # Returns (scalar_args_tuple, S) where S is the shape from the first Tile.
112+ @inline _to_scalars (t:: Tile{<:Any,S} ) where S = ((Intrinsics. to_scalar (t),), S)
113+ @inline function _to_scalars (t:: Tile{<:Any,S} , rest... ) where S
114+ rest_scalars, _ = _to_scalars (rest... )
115+ ((Intrinsics. to_scalar (t), rest_scalars... ), S)
116+ end
117+ @inline function _to_scalars (:: TypeRef{T} , rest... ) where T
118+ rest_scalars, S = _to_scalars (rest... )
119+ ((T, rest_scalars... ), S)
120+ end
0 commit comments