Skip to content

Commit 560445b

Browse files
CrispStrobeggerganov
authored andcommitted
metal : tighten input-position loop in kernel_conv_transpose_1d (ggml/1477)
For a given output position j on the time axis, only input positions i such that i*s0 <= j < i*s0 + K contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] intersected with [0, IL-1]. That's at most ceil(K/s0) values (typically 2 for stride==K/2 transposed convs). The current kernel iterates the full IL range and filters with an `if`, amplifying per-thread work by IL/ceil(K/s0) (~160x for IL=320, K=10, s0=5 -- a representative codec-decoder shape). On Apple M1 the wasted work trips the macOS GPU watchdog (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) on long graphs. Compute i_min, i_max analytically before the inner loop and iterate only [i_min, i_max]. Output is bit-identical (same multiplies and adds in the same order); loop bound shrinks by IL/ceil(K/s0). Tested on M1 with a downstream consumer running a TTS codec at full T_codec; end-to-end codec decode ~3-4x faster, zero watchdog hits across long synthesis runs vs ~30% pre-patch.
1 parent 2eb3e6b commit 560445b

1 file changed

Lines changed: 24 additions & 7 deletions

File tree

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4881,15 +4881,32 @@ kernel void kernel_conv_transpose_1d(
48814881
uint3 tgpig[[threadgroup_position_in_grid]],
48824882
uint3 tgpg[[threadgroups_per_grid]]) {
48834883

4884-
float v = 0.0f;
4884+
// For output position j on the time axis, only input positions
4885+
// i such that i*s0 <= j < i*s0 + K
4886+
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
4887+
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
4888+
// (typically 2 for stride==K/2 transposed convs).
4889+
const int32_t j = tgpig[0];
4890+
const int32_t s0 = args.s0;
4891+
const int32_t K = args.K;
4892+
const int32_t IL = args.IL;
4893+
4894+
int32_t i_min;
4895+
{
4896+
int32_t a = j - K + 1;
4897+
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
4898+
}
4899+
int32_t i_max = j / s0;
4900+
if (i_max > IL - 1) i_max = IL - 1;
48854901

4886-
for (int64_t c = 0; c < args.IC; c++) {
4887-
const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
4888-
const int32_t input_offset = c * args.IL;
4902+
float v = 0.0f;
4903+
if (i_min <= i_max) {
4904+
for (int64_t c = 0; c < args.IC; c++) {
4905+
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
4906+
const int32_t input_offset = c * IL;
48894907

4890-
for (int64_t i = 0; i < args.IL; i++) {
4891-
if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
4892-
v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
4908+
for (int32_t i = i_min; i <= i_max; i++) {
4909+
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
48934910
}
48944911
}
48954912
}

0 commit comments

Comments
 (0)