Skip to content

Commit d82828d

Browse files
committed
Add workaround for folding constant arithmetic.
1 parent 11335a5 commit d82828d

2 files changed

Lines changed: 33 additions & 1 deletion

File tree

src/language/operations.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,21 +111,24 @@ end
111111
end
112112

113113
"""
114-
store(arr::TileArray, index, tile::Tile) -> Nothing
114+
store(arr::TileArray, index, tile::Tile) -> Tile
115115
116116
Store a tile to a TileArray at the given index. Index is 1-indexed.
117+
Returns the stored tile (enables chaining and helps constant folding).
117118
"""
118119
# Regular N-D tiles (N >= 1)
119120
@inline function store(arr::TileArray{T}, index, tile::Tile{T, Shape}) where {T, Shape}
120121
tv = Intrinsics.make_tensor_view(arr)
121122
pv = Intrinsics.make_partition_view(tv, Val(Shape), PaddingMode.Undetermined)
122123
Intrinsics.store_partition_view(pv, tile, (promote(index...) .- One())...)
124+
return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
123125
end
124126

125127
@inline function store(arr::TileArray{T}, index::Integer, tile::Tile{T, Shape}) where {T, Shape}
126128
tv = Intrinsics.make_tensor_view(arr)
127129
pv = Intrinsics.make_partition_view(tv, Val(Shape), PaddingMode.Undetermined)
128130
Intrinsics.store_partition_view(pv, tile, index - One())
131+
return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
129132
end
130133

131134
# Special case for 0D (scalar) tiles - reshape to 1D for partition view
@@ -135,13 +138,15 @@ end
135138
tile_1d = Intrinsics.reshape(tile, Val((1,)))
136139
pv = Intrinsics.make_partition_view(tv, Val((1,)), PaddingMode.Undetermined)
137140
Intrinsics.store_partition_view(pv, tile_1d, (promote(index...) .- One())...)
141+
return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
138142
end
139143

140144
@inline function store(arr::TileArray{T}, index::Integer, tile::Tile{T, ()}) where {T}
141145
tv = Intrinsics.make_tensor_view(arr)
142146
tile_1d = Intrinsics.reshape(tile, Val((1,)))
143147
pv = Intrinsics.make_partition_view(tv, Val((1,)), PaddingMode.Undetermined)
144148
Intrinsics.store_partition_view(pv, tile_1d, index - One())
149+
return tile # XXX: enables constant folding; remove when possible (see "constant folding" test)
145150
end
146151

147152
# Keyword argument version - dispatch to positional version

test/codegen.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,4 +1702,31 @@ end
17021702
end
17031703
end
17041704
end
1705+
1706+
#=========================================================================
1707+
Constant Folding
1708+
=========================================================================#
1709+
@testset "constant folding" begin
1710+
spec = ct.ArraySpec{1}(16, true)
1711+
1712+
# XXX: This test verifies that store() returns the tile to enable constant
1713+
# folding. If this test fails after removing `return tile` from store(),
1714+
# Julia's optimizer will emit subi operations for constant index math.
1715+
# See operations.jl store() for the workaround.
1716+
@testset "store with constant index folds subtraction" begin
1717+
@test @filecheck begin
1718+
@check_label "entry"
1719+
@check "load_view_tko"
1720+
# Verify no subi appears between load and store - constant 1-1 should fold to 0
1721+
@check_not "subi"
1722+
@check "store_view_tko"
1723+
code_tiled(Tuple{ct.TileArray{Float32,1,spec}}) do a
1724+
idx = Int32(1)
1725+
tile = ct.load(a, idx, (16,))
1726+
ct.store(a, idx, tile)
1727+
return
1728+
end
1729+
end
1730+
end
1731+
end
17051732
end

0 commit comments

Comments
 (0)