[Codegen][CPU] Flatten contiguous trailing dims of transfers before unrolling. #24517
Draft
bjacob wants to merge 1 commit into
…nrolling.

`VectorTransferLoweringPass` runs the MLIR transfer-lowering patterns with `maxTransferRank=1` plus full-unroll, fully unrolling any rank-N>1 transfer to one rank-1 transfer per outer index. For a packed tile whose trailing dim is a tiny contiguous chunk that turns a single wide load into many narrow ones plus a shuffle chain to rebuild the wide register. Concretely, a bf16xbf16->f32 inner_tiled matmul (N=16, K_inner=2) loads each `<16x2xbf16>` RHS K-step as 16 separate `<2xbf16>` loads + a `vpermt2d`/`vpermt2q` chain -- ~3 cycles of extra work per K-step on top of the 29 dpbf16ps.

Apply `populateFlattenVectorTransferPatterns` *before* rank reduction, gated on the target's natural word size (the pointer size, via `DataLayout`): flatten only when the trailing dim is *sub-word*. Sub-word loads in bulk are pathological; word-and-up trailing dims (`<2xf32>` ... `<16xf32>`) are already good standalone loads, and flattening *them* fuses register-sized rows into an oversized 1-D transfer + a `vector.shape_cast` re-split, regressing whole-model .vmfb size. (Not `native_vector_size`: that is the *widest* useful vector, not the smallest non-pathological load.)

Measured: bf16 4096x4096 inner_tiled matmul on Zen 4, 80.8 -> 67.1 ms per fragment; combined with the m_bcst-fold broadcast routing in a sibling commit, the full matmul reaches ukernel parity (~50 ms). The `sdxl/clip_compstat_cpu` size guard is unchanged at 583k bytes / 2130 dispatches (golden 650k / 2130).

Test fallout: `transpose_mask` in vector_lowering now writes a constant `vector<4x2xi1>` mask as a single flat `vector<8xi1>` store; updated the CHECK lines.

Progress towards #24515.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
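The sub-word gate described in this commit message can be modeled with a few lines of arithmetic. The sketch below is only an illustration, not the pass's actual code: the function names are hypothetical, the 64-bit word size assumes an x86-64 target (pointer size), and the real pattern presumably also checks memory contiguity and other preconditions that are ignored here.

```python
def trailing_dim_bits(shape, elem_bits):
    """Bit width of one row along the innermost (trailing) dimension."""
    return shape[-1] * elem_bits


def should_flatten(shape, elem_bits, word_bits=64):
    """Hypothetical model of the gate: flatten only multi-dim transfers
    whose trailing dim is narrower than the target word.

    word_bits=64 is an assumption (pointer size on x86-64).
    """
    return len(shape) > 1 and trailing_dim_bits(shape, elem_bits) < word_bits


if __name__ == "__main__":
    # <16x2xbf16>: trailing dim is 2 * 16 = 32 bits, sub-word, so flatten.
    print(should_flatten((16, 2), 16))
    # <8x2xf32>: trailing dim is 2 * 32 = 64 bits, word-sized, so leave alone.
    print(should_flatten((8, 2), 32))
    # <4x2xi1>: trailing dim is 2 bits, sub-word, so flatten (the mask case).
    print(should_flatten((4, 2), 1))
```

Under these assumptions the `<16x2xbf16>` RHS tile and the `vector<4x2xi1>` mask pass the gate, while tiles with `<2xf32>` or wider rows do not, which matches the cases listed above.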
`VectorTransferLoweringPass` applies the MLIR transfer-lowering patterns with `maxTransferRank=1` plus full-unroll, which fully unrolls any rank-N>1 `vector.transfer_read`/`transfer_write` to multiple rank-1 transfers (one per index of the outer dim). For multi-dim tiles whose trailing dims are contiguous in memory, this unrolls a single wide load into many narrow ones, which then have to be reassembled into a wide vector via a chain of `shufflevector`s in the hot inner loop.

Example surfacing the cost: a 4096x4096 dynamic-shape bf16xbf16->f32 matmul with `--iree-llvmcpu-enable-inner-tiled` on Zen 4 lowered to inner_tiled with N=16, K_inner=2. The RHS for one K-step is a `vector<16x2xbf16>` from a contiguous 64-byte slice. Unrolling to 16 separate `<2 x bfloat>` loads forced a sequence of `vpermt2d`/`vpermt2q` per K-iteration in the inner loop to rebuild the wide RHS register, accounting for ~3 cycles of extra work per K-step on top of the 29 dpbf16ps doing the real work.

Apply `populateFlattenVectorTransferPatterns` before the rank-reduction patterns. It rewrites a multi-dim transfer with contiguous trailing dims into a transfer on a `memref.collapse_shape` view + a `vector.shape_cast`, so the read ends up as a single 1-D transfer over the collapsed view and lowers to one wide `vector.load`. Per-fragment effect on the matmul benchmark above: 80.8 ms -> 67.1 ms (1.20x). Combined with the m_bcst-fold broadcast routing in a sibling commit, end-to-end gets to 53.4 ms (within 5% of the precompiled mmt4d ukernel at 50.9 ms).

Test fallout: two pipelines now lower a per-row pack-tile load into a single wide load over a collapsed-memref view rather than one load per row (`aligned_unpack_generic` in pipeline_pack_unpack_tests) / write a constant `vector<4x2xi1>` mask as a single flat `vector<8xi1>` store (`transpose_mask` in vector_lowering). The new IR is strictly fewer ops in both cases; updated the CHECK lines to match.

Progress towards #24515.
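To make the load-count arithmetic in the description concrete, here is a small back-of-the-envelope sketch. It is not IREE or MLIR code; the helper names are hypothetical and it only models shapes and byte counts, assuming the whole tile is contiguous in memory.

```python
from math import prod


def unrolled_loads(shape):
    """Number of rank-1 transfers after full unrolling: one per outer index."""
    return prod(shape[:-1])


def bytes_per_unrolled_load(shape, elem_bits):
    """Bytes moved by each rank-1 transfer (one trailing-dim row)."""
    return shape[-1] * elem_bits // 8


def flattened_shape(shape):
    """Shape of the single 1-D transfer over the collapsed view."""
    return (prod(shape),)


if __name__ == "__main__":
    shape, elem_bits = (16, 2), 16  # vector<16x2xbf16>
    n = unrolled_loads(shape)
    b = bytes_per_unrolled_load(shape, elem_bits)
    # 16 loads of 4 bytes each, covering a 64-byte slice.
    print(n, b, n * b)
    # After flattening: one load of vector<32xbf16>, still 64 bytes.
    print(flattened_shape(shape))
```

For the `vector<16x2xbf16>` RHS tile this gives 16 four-byte loads when unrolled versus a single 64-byte load when flattened, and for the `vector<4x2xi1>` mask it gives the flat 8-element shape mentioned in the test fallout.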