Commit cdc74bd

bjacob and claude committed
[Codegen][CPU] Flatten contiguous trailing dims of transfers before unrolling.

`VectorTransferLoweringPass` runs the MLIR transfer-lowering patterns with `maxTransferRank=1` plus full-unroll, fully unrolling any rank-N>1 transfer to one rank-1 transfer per outer index. For a packed tile whose trailing dim is a tiny contiguous chunk that turns a single wide load into many narrow ones plus a shuffle chain to rebuild the wide register. Concretely, a bf16xbf16->f32 inner_tiled matmul (N=16, K_inner=2) loads each `<16x2xbf16>` RHS K-step as 16 separate `<2xbf16>` loads + a `vpermt2d`/`vpermt2q` chain -- ~3 cycles of extra work per K-step on top of the 29 dpbf16ps.

Apply `populateFlattenVectorTransferPatterns` *before* rank reduction, gated on the target's natural word size (the pointer size, via `DataLayout`): flatten only when the trailing dim is *sub-word*. Sub-word loads in bulk are pathological; word-and-up trailing dims (`<2xf32>` ... `<16xf32>`) are already good standalone loads, and flattening *them* fuses register-sized rows into an oversized 1-D transfer + a `vector.shape_cast` re-split, regressing whole-model .vmfb size. (Not `native_vector_size`: that is the *widest* useful vector, not the smallest non-pathological load.)

Measured: bf16 4096x4096 inner_tiled matmul on Zen 4, 80.8 -> 67.1 ms per fragment; combined with the m_bcst-fold broadcast routing in a sibling commit, the full matmul reaches ukernel parity (~50 ms). The `sdxl/clip_compstat_cpu` size guard is unchanged at 583k bytes / 2130 dispatches (golden 650k / 2130).

Test fallout: `transpose_mask` in vector_lowering now writes a constant `vector<4x2xi1>` mask as a single flat `vector<8xi1>` store; updated the CHECK lines.

Progress towards #24515.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>

1 parent 3936fb5 commit cdc74bd

5 files changed

Lines changed: 36 additions & 1 deletion

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -234,6 +234,7 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Dialect/Util/Transforms",
         "//compiler/src/iree/compiler/Utils",
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
+        "@llvm-project//llvm:Core",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AMDGPUDialect",
         "@llvm-project//mlir:AMDGPUTransforms",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ iree_cc_library(
     ::PassHeaders
     ::PassesIncGen
     IREELinalgTransformDialect
+    LLVMCore
     LLVMSupport
     MLIRAMDGPUDialect
     MLIRAMDGPUTransforms

compiler/src/iree/compiler/Codegen/Common/VectorTransferLowering.cpp

Lines changed: 28 additions & 0 deletions
@@ -5,6 +5,8 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "llvm/IR/DataLayout.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -38,6 +40,32 @@ void VectorTransferLoweringPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
   mlir::FunctionOpInterface funcOp = getOperation();
 
+  // Flatten contiguous trailing dims of multi-dim transfers when the trailing
+  // dim is narrower than the target's natural word (the pointer size), so a
+  // packed `<16x2xbf16>` (32-bit innermost) lowers to one wide load instead
+  // of 16 narrow loads the rank reduction below would reassemble with a
+  // chain of shuffles. Sub-word loads in bulk are uniformly pathological;
+  // word-and-up loads (`<2xf32>` ... `<16xf32>`) are already fine and
+  // flattening *them* fuses register-sized rows into an oversized 1-D
+  // transfer + a `vector.shape_cast` re-split (extracts), regressing whole-
+  // model .vmfb size for no benefit. This is *not* `native_vector_size`:
+  // that is the *widest* useful vector, not the smallest non-pathological
+  // load.
+  unsigned pointerBits = 64;
+  if (auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp)) {
+    if (auto attr =
+            targetAttr.getConfiguration().getAs<StringAttr>("data_layout")) {
+      if (!attr.getValue().empty()) {
+        pointerBits = llvm::DataLayout(attr.getValue()).getPointerSizeInBits();
+      }
+    }
+  }
+  {
+    RewritePatternSet patterns(ctx);
+    vector::populateFlattenVectorTransferPatterns(patterns, pointerBits);
+    (void)applyPatternsGreedily(funcOp, std::move(patterns));
+  }
+
   RewritePatternSet patterns(ctx);
   // Explicitly materialize the mask on transfer_read/transfer_write.
   // Assume we don't have 4 GB vectors.
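
The sub-word gate described in the comment above can be illustrated with a small standalone sketch. It does not use MLIR: `VecType`, `trailingDimBits`, `shouldFlatten` and `numRank1Transfers` are hypothetical names introduced only for this illustration. It assumes the flatten patterns skip any transfer whose trailing-dim bit width is at least the threshold passed in, and it models a flattened transfer as a single rank-1 transfer, which ignores contiguity checks and any later splitting of oversized vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type: shape plus element bit width.
struct VecType {
  std::vector<int64_t> shape;
  unsigned elemBits;
};

// Bit width of the innermost (trailing) dimension, e.g. 2 x 16 = 32 for
// <16x2xbf16>.
inline uint64_t trailingDimBits(const VecType &t) {
  assert(!t.shape.empty() && "expected a ranked vector type");
  return static_cast<uint64_t>(t.shape.back()) * t.elemBits;
}

// Assumed gate: flatten a multi-dim transfer only when its trailing dim is
// strictly narrower than the word size (the pointer size in the commit).
inline bool shouldFlatten(const VecType &t, unsigned wordBits) {
  return t.shape.size() > 1 && trailingDimBits(t) < wordBits;
}

// Simplified count of rank-1 transfers after lowering: one if the transfer
// was flattened, otherwise one per outer index (product of the outer dims).
inline int64_t numRank1Transfers(const VecType &t, unsigned wordBits) {
  if (shouldFlatten(t, wordBits))
    return 1;
  int64_t count = 1;
  for (std::size_t i = 0; i + 1 < t.shape.size(); ++i)
    count *= t.shape[i];
  return count;
}
```

Under these assumptions and a 64-bit word, `<16x2xbf16>` (32-bit trailing dim) is flattened into one transfer, while `<16x16xf32>` (512-bit trailing dim) is left alone and unrolls into 16 row loads, which is what the `aligned_unpack_generic` CHECK lines expect. A `<4x2xi1>` mask has a 2-bit trailing dim, so it is flattened as well, consistent with the `transpose_mask` test change.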

compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_pack_unpack_tests.mlir

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,9 @@ module {
 // CHECK-LABEL: func.func @aligned_unpack_generic
 // CHECK: %[[SRC:.+]] = hal.interface.binding.subspan {{.*}} : memref<24x32x16x16xf32, #hal.descriptor_type<storage_buffer>>
 // CHECK: %[[ASSUMED_SRC:.+]] = memref.assume_alignment %[[SRC]], 64
+// The unpack source tile is `vector<16x16xf32>`: its trailing dim is a full
+// 512-bit `vector<16xf32>`, so transfer flattening leaves it alone and plain
+// rank reduction lowers it to one `vector<16xf32>` load per row.
 // CHECK-COUNT-15: vector.load %[[ASSUMED_SRC]]
 // CHECK: %[[LAST_LOAD:.+]] = vector.load %[[ASSUMED_SRC]]
 // CHECK: %[[IN_0:.+]] = vector.broadcast %{{.+}} : vector<16xf32> to vector<16x16xf32>

compiler/src/iree/compiler/Codegen/LLVMCPU/test/vector_lowering.mlir

Lines changed: 3 additions & 1 deletion
@@ -155,7 +155,9 @@ func.func @transpose_mask() {
 // CHECK-NOT: vector.shuffle
 // CHECK-DAG: %[[MASK:.+]] = arith.constant dense<true>
 // CHECK-DAG: %[[OUTPUT:.+]] = hal.interface.binding.subspan
-// CHECK: vector.store %[[MASK]], %[[OUTPUT]]
+// VectorTransferLoweringPass flattens the contiguous 4x2 trailing dims of
+// the store into a single `vector<8xi1>` store over the collapsed memref.
+// CHECK: vector.store %[[MASK]], %{{.+}}
 
 // -----
 