Skip to content

Commit d43375f

Browse files
authored
ggml : fix RWKV ops thread assignment (ggml-org#21226)
1 parent 2b86e5c commit d43375f

2 files changed

Lines changed: 14 additions & 22 deletions

File tree

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
23502350
case GGML_OP_FLASH_ATTN_BACK:
23512351
case GGML_OP_SSM_CONV:
23522352
case GGML_OP_SSM_SCAN:
2353+
{
2354+
n_tasks = n_threads;
2355+
} break;
23532356
case GGML_OP_RWKV_WKV6:
23542357
case GGML_OP_GATED_LINEAR_ATTN:
23552358
case GGML_OP_RWKV_WKV7:
23562359
{
2357-
n_tasks = n_threads;
2360+
const int64_t n_heads = node->src[1]->ne[1];
2361+
n_tasks = MIN(n_threads, n_heads);
23582362
} break;
23592363
case GGML_OP_WIN_PART:
23602364
case GGML_OP_WIN_UNPART:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
99539953
const int ith = params->ith;
99549954
const int nth = params->nth;
99559955

9956-
if (ith >= HEADS) {
9957-
return;
9958-
}
9959-
9960-
const int h_start = (HEADS * ith) / nth;
9961-
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9962-
(HEADS * (ith + 1)) / nth : HEADS;
9956+
const int h_start = (HEADS * (ith )) / nth;
9957+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9958+
(HEADS * (ith + 1)) / nth : HEADS;
99639959

99649960
float * k = (float *) dst->src[0]->data;
99659961
float * v = (float *) dst->src[1]->data;
@@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32(
1017010166
const int ith = params->ith;
1017110167
const int nth = params->nth;
1017210168

10173-
if (ith >= HEADS) {
10174-
return;
10175-
}
10176-
10177-
const int h_start = (HEADS * ith) / nth;
10178-
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10179-
(HEADS * (ith + 1)) / nth : HEADS;
10169+
const int h_start = (HEADS * (ith )) / nth;
10170+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10171+
(HEADS * (ith + 1)) / nth : HEADS;
1018010172

1018110173
float * k = (float *) dst->src[0]->data;
1018210174
float * v = (float *) dst->src[1]->data;
@@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
1063310625
const int ith = params->ith;
1063410626
const int nth = params->nth;
1063510627

10636-
if (ith >= HEADS) {
10637-
return;
10638-
}
10639-
10640-
const int h_start = (HEADS * ith) / nth;
10641-
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10642-
(HEADS * (ith + 1)) / nth : HEADS;
10628+
const int h_start = (HEADS * (ith )) / nth;
10629+
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10630+
(HEADS * (ith + 1)) / nth : HEADS;
1064310631

1064410632
float * r = (float *) dst->src[0]->data;
1064510633
float * w = (float *) dst->src[1]->data;

0 commit comments

Comments
 (0)