Skip to content

Commit ad4dea5

Browse files
kasper0406claude
andcommitted
dot/schedule_bench: fix A_stride_m for transpose_a kernels
Dot kernels that set dot_flag::transpose_a (e.g. the SME/SME2 ones) consume A packed as [k1/tile_k, k3, k2, {i, tile_k}]. They advance A_k1 per inner-k step by `A_stride_m * dot_factor` bytes, but the semantically correct value for that slot is the stride ALONG the k1 dimension of the packed tensor — which is `a_k_strides[0]`, not the (i, tile_k) intra-row stride that schedule_bench was passing as `a_stride_m`. subgraph/dot.cc already does this swap in `call_kernel`; without it in schedule_bench, a K step advances A_k1 by one element (4 bytes for fp32) instead of one full packed row (M×sizeof(TA) bytes). The kernel still runs and the built-in correctness check passes because A and B are filled with 1s, but the measured bandwidth and cache behaviour don't match the production path — benchmark numbers reported against this bench were artificially hot because successive k-steps re-read near-identical addresses. Mirror the transposed_a-aware swap from subgraph/dot.cc so schedule_bench measures the same work the production path executes. After the fix, `dot_fp32_sme2` on the new (auto:8 MiB) schedule reports realistic numbers that track the subgraph production bench. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ef69e72 commit ad4dea5

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

ynnpack/kernels/dot/schedule_bench.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,12 @@ double run_benchmark(TA, TB, TC, const kernel_info& kernel, size_t m, size_t n,
152152
size_t a_stride_m, span<const size_t> a_k_strides, const void* b_ptr,
153153
span<const size_t> b_k_strides, size_t init_c_stride_m,
154154
const void* init_c, void* c_ptr) {
155-
kernel.kernel(m, n, k[2], k[1], k[0], a_stride_m,
155+
// For dot_flag::transpose_a kernels, the 6th kernel arg is the
156+
// stride of the k1/tile_k dimension of the packed A (see dot.h),
157+
// not the m stride. subgraph/dot.cc does the same swap — mirror
158+
// it here.
159+
kernel.kernel(m, n, k[2], k[1], k[0],
160+
pack_a ? a_k_strides[0] : a_stride_m,
156161
a_k_strides[2], a_k_strides[1], a_ptr, b_k_strides[2],
157162
b_k_strides[1], b_k_strides[0], b_ptr, init_c_stride_m,
158163
init_c, c.stride(0) * sizeof(TC), c_ptr);

0 commit comments

Comments
 (0)