Skip to content

Commit 2f2d440

Browse files
committed
Speed up DeepSeek V4 prompt replay
Add a DeepSeek V4 HC weighted-sum ggml op with CPU, Metal, and meta backend support, and use it in the compressed attention path. Batch resumed compressed decode projections, reserve a resumed-prompt DeepSeek V4 graph shape, increase the compressed decode replay cap, and place server checkpoints on SWA-spaced prompt tail positions. On the Apple M3 Max test machine, the retained changes improved synthetic Metal server replay from roughly 127.8/103.4/94.7 tok/s to 165.9/127.1/113.7 tok/s, with generation sanity at about 21.5 tok/s.
1 parent 3ba61fb commit 2f2d440

19 files changed

Lines changed: 317 additions & 30 deletions

ggml/include/ggml-rpc.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ extern "C" {
88

99
#define RPC_PROTO_MAJOR_VERSION 4
1010
#define RPC_PROTO_MINOR_VERSION 0
11-
#define RPC_PROTO_PATCH_VERSION 4
11+
#define RPC_PROTO_PATCH_VERSION 5
1212

1313
#ifdef __cplusplus
14-
static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
14+
static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
1515
#endif
1616

1717
#define GGML_RPC_MAX_SERVERS 16

ggml/include/ggml.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ extern "C" {
562562
GGML_OP_SOLVE_TRI,
563563
GGML_OP_GATED_DELTA_NET,
564564
GGML_OP_DSV4_HC_SPLIT_SINKHORN,
565+
GGML_OP_DSV4_HC_WEIGHTED_SUM,
565566
GGML_OP_DSV4_HC_EXPAND,
566567
GGML_OP_DSV4_FP8_KV_QUANTIZE,
567568
GGML_OP_DSV4_ROPE_TAIL,
@@ -2555,6 +2556,13 @@ extern "C" {
25552556
int sinkhorn_iters,
25562557
float eps);
25572558

2559+
// DeepSeek V4 hyperconnection weighted-sum helper.
2560+
// Computes sum_hc weights[hc, token] * x[embd, hc, token].
//
// The CPU kernel asserts the following layout (all tensors F32):
//   x       : [n_embd, n_hc, n_tokens]
//   weights : [n_hc, n_tokens]
//   result  : [n_embd, n_tokens]
// i.e. result[embd, token] = sum over hc of weights[hc, token] * x[embd, hc, token].
2561+
GGML_API struct ggml_tensor * ggml_dsv4_hc_weighted_sum(
2562+
struct ggml_context * ctx,
2563+
struct ggml_tensor * x,
2564+
struct ggml_tensor * weights);
2565+
25582566
// DeepSeek V4 hyperconnection expand helper.
25592567
// Computes post * block_out + comb^T @ residual for each token.
25602568
GGML_API struct ggml_tensor * ggml_dsv4_hc_expand(

ggml/src/ggml-backend-meta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co
958958
split_state = handle_gated_delta_net(src_ss);
959959
} break;
960960
case GGML_OP_DSV4_HC_SPLIT_SINKHORN:
961+
case GGML_OP_DSV4_HC_WEIGHTED_SUM:
961962
case GGML_OP_DSV4_HC_EXPAND:
962963
case GGML_OP_DSV4_FP8_KV_QUANTIZE:
963964
case GGML_OP_DSV4_ROPE_TAIL:

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,6 +2041,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
20412041
{
20422042
ggml_compute_forward_dsv4_hc_split_sinkhorn(params, tensor);
20432043
} break;
2044+
case GGML_OP_DSV4_HC_WEIGHTED_SUM:
2045+
{
2046+
ggml_compute_forward_dsv4_hc_weighted_sum(params, tensor);
2047+
} break;
20442048
case GGML_OP_DSV4_HC_EXPAND:
20452049
{
20462050
ggml_compute_forward_dsv4_hc_expand(params, tensor);
@@ -2234,6 +2238,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22342238
case GGML_OP_SOLVE_TRI:
22352239
case GGML_OP_GATED_DELTA_NET:
22362240
case GGML_OP_DSV4_HC_SPLIT_SINKHORN:
2241+
case GGML_OP_DSV4_HC_WEIGHTED_SUM:
22372242
case GGML_OP_DSV4_HC_EXPAND:
22382243
case GGML_OP_DSV4_FP8_KV_QUANTIZE:
22392244
case GGML_OP_DSV4_ROPE_TAIL:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11084,6 +11084,54 @@ void ggml_compute_forward_dsv4_hc_split_sinkhorn(
1108411084
}
1108511085
}
1108611086

11087+
// ggml_compute_forward_dsv4_hc_weighted_sum
//
// CPU kernel for GGML_OP_DSV4_HC_WEIGHTED_SUM.
// For every (d, t) it writes dst[d, t] = sum over h of x[d, h, t] * weights[h, t],
// where x = dst->src[0] and weights = dst->src[1].
// The output elements are divided among the params->nth threads; each thread
// (params->ith) writes only its own slice of dst.
11088+
11089+
void ggml_compute_forward_dsv4_hc_weighted_sum(
11090+
const ggml_compute_params * params,
11091+
ggml_tensor * dst) {
11092+
const ggml_tensor * x = dst->src[0];
11093+
const ggml_tensor * weights = dst->src[1];
11094+
11095+
// Only F32 operands and output are handled.
GGML_ASSERT(x->type == GGML_TYPE_F32);
11096+
GGML_ASSERT(weights->type == GGML_TYPE_F32);
11097+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
11098+
// Shape contract: x is [n_embd, n_hc, n_tokens], weights is [n_hc, n_tokens],
// dst is [n_embd, n_tokens]; every remaining dimension must be 1.
GGML_ASSERT(x->ne[0] == dst->ne[0]);
11099+
GGML_ASSERT(x->ne[1] == weights->ne[0]);
11100+
GGML_ASSERT(x->ne[2] == dst->ne[1]);
11101+
GGML_ASSERT(weights->ne[1] == dst->ne[1]);
11102+
GGML_ASSERT(x->ne[3] == 1);
11103+
GGML_ASSERT(weights->ne[2] == 1);
11104+
GGML_ASSERT(weights->ne[3] == 1);
11105+
GGML_ASSERT(dst->ne[2] == 1);
11106+
GGML_ASSERT(dst->ne[3] == 1);
11107+
11108+
const int64_t n_embd = dst->ne[0];
11109+
const int64_t n_hc = x->ne[1];
11110+
const int64_t n_tokens = dst->ne[1];
11111+
// Total number of output elements, used as the flattened work range.
const int64_t n_elem = n_embd * n_tokens;
11112+
11113+
// This thread's contiguous slice [i0, i1) of the flattened output index space.
const int64_t i0 = (n_elem * params->ith) / params->nth;
11114+
const int64_t i1 = (n_elem * (params->ith + 1)) / params->nth;
11115+
11116+
// Byte pointers: element addresses below are built from the nb[] strides,
// which are applied as byte offsets.
const char * x_data = (const char *) x->data;
11117+
const char * w_data = (const char *) weights->data;
11118+
char * y_data = ( char *) dst->data;
11119+
11120+
for (int64_t i = i0; i < i1; ++i) {
11121+
// Unflatten i into d (index along dim 0 of dst) and t (index along dim 1 of dst).
const int64_t d = i % n_embd;
11122+
const int64_t t = i / n_embd;
11123+
11124+
// Accumulate the weighted sum over the n_hc axis for this (d, t).
float acc = 0.0f;
11125+
for (int64_t h = 0; h < n_hc; ++h) {
11126+
const float xv = *(const float *) (x_data + d*x->nb[0] + h*x->nb[1] + t*x->nb[2]);
11127+
const float wv = *(const float *) (w_data + h*weights->nb[0] + t*weights->nb[1]);
11128+
acc += xv * wv;
11129+
}
11130+
11131+
*(float *) (y_data + d*dst->nb[0] + t*dst->nb[1]) = acc;
11132+
}
11133+
}
11134+
1108711135
// ggml_compute_forward_dsv4_hc_expand
1108811136

1108911137
void ggml_compute_forward_dsv4_hc_expand(

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
104104
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
105105
void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
106106
void ggml_compute_forward_dsv4_hc_split_sinkhorn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
107+
void ggml_compute_forward_dsv4_hc_weighted_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
107108
void ggml_compute_forward_dsv4_hc_expand(const struct ggml_compute_params * params, struct ggml_tensor * dst);
108109
void ggml_compute_forward_dsv4_fp8_kv_quantize(const struct ggml_compute_params * params, struct ggml_tensor * dst);
109110
void ggml_compute_forward_dsv4_rope_tail(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_split_si
475475
return res;
476476
}
477477

478+
// Returns the Metal pipeline for the "kernel_dsv4_hc_weighted_sum" kernel.
// Asserts that both sources and the op result are F32, looks the pipeline up
// in `lib` by name, and compiles it if the lookup returned no pipeline.
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_weighted_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
479+
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
480+
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
481+
GGML_ASSERT(op->type == GGML_TYPE_F32);
482+
483+
// The same string is passed as both name arguments to the compile call below.
const char * name = "kernel_dsv4_hc_weighted_sum";
484+
485+
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
486+
if (!res.pipeline) {
487+
// Lookup returned no pipeline: compile it now.
res = ggml_metal_library_compile_pipeline(lib, name, name, nullptr);
488+
}
489+
490+
return res;
491+
}
492+
478493
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_expand(ggml_metal_library_t lib, const ggml_tensor * op) {
479494
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
480495
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad
122122
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
123123
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
124124
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_split_sinkhorn(ggml_metal_library_t lib, const struct ggml_tensor * op);
125+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_weighted_sum(ggml_metal_library_t lib, const struct ggml_tensor * op);
125126
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_expand (ggml_metal_library_t lib, const struct ggml_tensor * op);
126127
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_fp8_kv_quantize(ggml_metal_library_t lib, const struct ggml_tensor * op);
127128
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_rope_tail (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,14 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
11891189
op->src[1]->type == GGML_TYPE_F32 &&
11901190
op->src[2]->type == GGML_TYPE_F32 &&
11911191
op->type == GGML_TYPE_F32;
1192+
case GGML_OP_DSV4_HC_WEIGHTED_SUM:
1193+
return op->src[0]->type == GGML_TYPE_F32 &&
1194+
op->src[1]->type == GGML_TYPE_F32 &&
1195+
op->type == GGML_TYPE_F32 &&
1196+
op->src[0]->ne[0] == op->ne[0] &&
1197+
op->src[0]->ne[1] == op->src[1]->ne[0] &&
1198+
op->src[0]->ne[2] == op->ne[1] &&
1199+
op->src[1]->ne[1] == op->ne[1];
11921200
case GGML_OP_DSV4_HC_EXPAND:
11931201
return op->src[0]->type == GGML_TYPE_F32 &&
11941202
op->src[1]->type == GGML_TYPE_F32 &&

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,19 @@ typedef struct {
795795
float eps;
796796
} ggml_metal_kargs_dsv4_hc_split_sinkhorn;
797797

798+
// Kernel arguments for the DSV4 HC weighted-sum Metal kernel.
// NOTE(review): field meanings are inferred from their names, because the code
// that fills this struct is not visible here. n_embd / n_hc / n_tokens look like
// the operand dimensions; nb_x*, nb_w* and nb* look like byte strides of x,
// weights and the destination respectively. Confirm against the op encoder and
// the kernel before relying on this.
typedef struct {
799+
int64_t n_embd;
800+
int64_t n_hc;
801+
int64_t n_tokens;
802+
uint64_t nb_x0;
803+
uint64_t nb_x1;
804+
uint64_t nb_x2;
805+
uint64_t nb_w0;
806+
uint64_t nb_w1;
807+
uint64_t nb0;
808+
uint64_t nb1;
809+
} ggml_metal_kargs_dsv4_hc_weighted_sum;
810+
798811
typedef struct {
799812
int64_t n_embd;
800813
int64_t n_hc;

0 commit comments

Comments
 (0)