diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3686af5c11c..7b4c647d975 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2390,6 +2390,17 @@ extern "C" { struct ggml_tensor * c, struct ggml_tensor * parent_ids); + // dflash extension: tree-mode ssm_conv that also writes each token's + // (K-1)-element post-state to persist_inter so the driver can roll the + // live conv state back to the accepted DFS node. persist_inter must be + // contiguous F32 with shape [K-1, d_inner, n_tokens, n_seqs] (K-1 fastest). + GGML_API struct ggml_tensor * ggml_ssm_conv_tree_persist( + struct ggml_context * ctx, + struct ggml_tensor * sx, + struct ggml_tensor * c, + struct ggml_tensor * parent_ids, + struct ggml_tensor * persist_inter); + GGML_API struct ggml_tensor * ggml_ssm_scan( struct ggml_context * ctx, struct ggml_tensor * s, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a9bc21da6f0..51a71b1a5e7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -11,6 +11,7 @@ #include #include #include +#include // ggml_compute_forward_dup @@ -9254,6 +9255,8 @@ static void ggml_compute_forward_ssm_conv_f32( ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; // conv_x const ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const ggml_tensor * src2 = dst->src[2]; // parent_ids, optional tree mode + const ggml_tensor * src3 = dst->src[3]; // persist conv state, optional const int ith = params->ith; const int nth = params->nth; @@ -9269,6 +9272,17 @@ static void ggml_compute_forward_ssm_conv_f32( GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + if (src2 != nullptr) { + GGML_ASSERT(src2->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src2)); + GGML_ASSERT(ggml_nelements(src2) == n_t * n_s); + } + if (src3 != nullptr) { + GGML_ASSERT(src3->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src3)); + GGML_ASSERT(ggml_nelements(src3) >= (int64_t)(nc - 1) * nr * n_t * n_s); + } + // rows per thread const int dr = (nr + nth - 1)/nth; @@ -9278,25 +9292,51 @@ static void ggml_compute_forward_ssm_conv_f32( const int ir = ir1 - ir0; for (int i3 = 0; i3 < n_s; ++i3) { + const int32_t * parent_ids = src2 ? (const int32_t *) src2->data + (int64_t)i3 * n_t : nullptr; + float * persist_seq = src3 ? (float *) src3->data + (int64_t)i3 * n_t * nr * (nc - 1) : nullptr; + for (int i2 = 0; i2 < n_t; ++i2) { - // {d_conv - 1 + n_t, d_inner, n_seqs} - // sliding window - const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + int ancestors[GGML_MAX_DIMS] = {}; + if (parent_ids != nullptr) { + GGML_ASSERT(nc <= GGML_MAX_DIMS); + ancestors[nc - 1] = i2; + for (int k = nc - 2; k >= 0; --k) { + const int prev = ancestors[k + 1]; + ancestors[k] = prev >= 0 ? parent_ids[prev] : prev - 1; + } + } + // TODO: transpose the output for smaller strides for big batches? // d_inner for (int i1 = 0; i1 < ir; ++i1) { // rowwise dot product // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision float sumf = 0.0f; + float window[GGML_MAX_DIMS] = {}; // d_conv for (int i0 = 0; i0 < nc; ++i0) { - sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; + const int sx_slot = parent_ids ? 
(nc - 1 + ancestors[i0]) : (i2 + i0); + const float s = *(const float *) ((const char *) src0->data + + (int64_t)sx_slot * src0->nb[0] + + (int64_t)(ir0 + i1) * src0->nb[1] + + (int64_t)i3 * src0->nb[2]); + const float c = *(const float *) ((const char *) src1->data + + (int64_t)i0 * src1->nb[0] + + (int64_t)(ir0 + i1) * src1->nb[1]); + window[i0] = s; + sumf += s * c; } x[i1] = sumf; + + if (persist_seq != nullptr) { + float * persist_token = persist_seq + ((int64_t)i2 * nr + (ir0 + i1)) * (nc - 1); + for (int i0 = 0; i0 < nc - 1; ++i0) { + persist_token[i0] = window[i0 + 1]; + } + } } } } @@ -10439,6 +10479,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( ggml_tensor * src_g = dst->src[3]; ggml_tensor * src_beta = dst->src[4]; ggml_tensor * src_state = dst->src[5]; + ggml_tensor * src_parent = dst->src[6]; + ggml_tensor * src_persist = dst->src[7]; const int64_t S_v = src_v->ne[0]; const int64_t H = src_v->ne[1]; @@ -10454,6 +10496,16 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v); GGML_ASSERT(src_beta->ne[0] == 1); + if (src_parent != nullptr) { + GGML_ASSERT(src_parent->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src_parent)); + GGML_ASSERT(ggml_nelements(src_parent) == n_tokens * n_seqs); + } + if (src_persist != nullptr) { + GGML_ASSERT(src_persist->type == GGML_TYPE_F32 || src_persist->type == GGML_TYPE_F16); + GGML_ASSERT(ggml_is_contiguous(src_persist)); + GGML_ASSERT(ggml_nelements(src_persist) >= S_v * S_v * H * n_tokens * n_seqs); + } GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); @@ -10477,8 +10529,10 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( // attn_scores: S_v * H * n_tokens * n_seqs floats // new_states: S_v * S_v * H * n_seqs floats const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_elems = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + float * inter_out_base = state_out_base + state_elems; const float * state_in_base = (const float *)src_state->data; @@ -10505,10 +10559,45 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + const int32_t * parent_ids = src_parent ? 
(const int32_t *) src_parent->data + iv3 * n_tokens : nullptr; + auto load_inter = [&](int64_t token, int64_t elem) -> float { + if (src_persist != nullptr) { + const int64_t off = ((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem; + if (src_persist->type == GGML_TYPE_F32) { + return ((const float *) src_persist->data)[off]; + } + return GGML_FP16_TO_FP32(((const ggml_fp16_t *) src_persist->data)[off]); + } + return inter_out_base[((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem]; + }; + auto store_inter = [&](int64_t token, int64_t elem, float value) { + if (src_persist != nullptr) { + const int64_t off = ((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem; + if (src_persist->type == GGML_TYPE_F32) { + ((float *) src_persist->data)[off] = value; + } else { + ((ggml_fp16_t *) src_persist->data)[off] = GGML_FP32_TO_FP16(value); + } + } else { + inter_out_base[((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem] = value; + } + }; + // attn output pointer for first token of this (head, seq) float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v; for (int64_t t = 0; t < n_tokens; t++) { + if (parent_ids != nullptr && t > 0) { + const int32_t parent_t = parent_ids[t]; + if (parent_t < 0) { + memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + } else if (parent_t != t - 1) { + for (int64_t elem = 0; elem < S_v * S_v; ++elem) { + s_out[elem] = load_inter(parent_t, elem); + } + } + } + const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1); const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1); const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); @@ -10552,6 +10641,12 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (parent_ids != nullptr || src_persist != nullptr) { + for (int64_t elem = 0; elem < S_v * S_v; ++elem) { + store_inter(t, elem, s_out[elem]); + } + } } } } diff --git a/ggml/src/ggml-cuda/fattn-chunked.cu b/ggml/src/ggml-cuda/fattn-chunked.cu index 8035b1c5b96..3240b5648cb 100644 --- a/ggml/src/ggml-cuda/fattn-chunked.cu +++ b/ggml/src/ggml-cuda/fattn-chunked.cu @@ -67,13 +67,29 @@ struct chunked_scratch { }; static chunked_scratch g_chunked_bufs[GGML_CUDA_MAX_DEVICES]; -static float * ensure_buf(float ** p, size_t * cur_bytes, size_t need_bytes) { - if (need_bytes <= *cur_bytes && *p != nullptr) return *p; +static bool try_ensure_buf(float ** p, size_t * cur_bytes, size_t need_bytes) { + if (need_bytes <= *cur_bytes && *p != nullptr) return true; if (*p != nullptr) CUDA_CHECK(cudaFree(*p)); *p = nullptr; - CUDA_CHECK(cudaMalloc(p, need_bytes)); + + const cudaError_t err = cudaMalloc(p, need_bytes); + if (err != cudaSuccess) { + // Clear the sticky CUDA error so the caller can retry with a smaller + // chunk instead of aborting the whole process. 
+ (void) cudaGetLastError(); + *p = nullptr; + *cur_bytes = 0; + return false; + } + *cur_bytes = need_bytes; - return *p; + return true; +} + +static void free_buf(float ** p, size_t * cur_bytes) { + if (*p != nullptr) CUDA_CHECK(cudaFree(*p)); + *p = nullptr; + *cur_bytes = 0; } void ggml_cuda_flash_attn_ext_chunked(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -104,28 +120,66 @@ void ggml_cuda_flash_attn_ext_chunked(ggml_backend_cuda_context & ctx, ggml_tens size_t free_bytes = 0, total_bytes = 0; CUDA_CHECK(cudaMemGetInfo(&free_bytes, &total_bytes)); const int vram_chunk = chunked_pf_compute_chunk_size(free_bytes, nh_q, nh_kv, q_batch_size, D); - const int tbq_chunk = chunked_chunk_env(vram_chunk); + int tbq_chunk = chunked_chunk_env(vram_chunk); const int device = ctx.device; GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES); chunked_scratch & sc = g_chunked_bufs[device]; - const size_t O_bytes = (size_t)nh_q * nq * D * sizeof(float); - const size_t l_bytes = (size_t)nh_q * nq * sizeof(float); - const size_t m_bytes = (size_t)nh_q * nq * sizeof(float); - const size_t S_bytes = (size_t)nh_q * q_batch_size * tbq_chunk * sizeof(float); - // Per-chunk K/V dequant: [nh_kv, tbq_chunk, D] fp32. The final chunk may - // be shorter; we still size the buffer for the max and only write chunk_len. - const size_t kv_bytes = (size_t)nh_kv * tbq_chunk * D * sizeof(float); + const size_t O_bytes = (size_t)nh_q * nq * D * sizeof(float); + const size_t l_bytes = (size_t)nh_q * nq * sizeof(float); + const size_t m_bytes = (size_t)nh_q * nq * sizeof(float); const size_t Q_f32_bytes = (size_t)nh_q * nq * D * sizeof(float); - float * O_acc = ensure_buf(&sc.O_acc, &sc.O_bytes, O_bytes); - float * l_acc = ensure_buf(&sc.l_acc, &sc.l_bytes, l_bytes); - float * m_acc = ensure_buf(&sc.m_acc, &sc.m_bytes, m_bytes); - float * S = ensure_buf(&sc.S, &sc.S_bytes, S_bytes); - float * k_tmp = ensure_buf(&sc.k_tmp, &sc.k_bytes, kv_bytes); - float * v_tmp = ensure_buf(&sc.v_tmp, &sc.v_bytes, kv_bytes); - float * Q_f32 = ensure_buf(&sc.Q_f32, &sc.Q_bytes, Q_f32_bytes); + float * O_acc = nullptr; + float * l_acc = nullptr; + float * m_acc = nullptr; + float * S = nullptr; + float * k_tmp = nullptr; + float * v_tmp = nullptr; + float * Q_f32 = nullptr; + + const int requested_tbq_chunk = tbq_chunk; + for (;;) { + const size_t S_bytes = (size_t)nh_q * q_batch_size * tbq_chunk * sizeof(float); + // Per-chunk K/V dequant: [nh_kv, tbq_chunk, D] fp32. The final chunk may + // be shorter; we still size the buffer for the max and only write chunk_len. + const size_t kv_bytes = (size_t)nh_kv * tbq_chunk * D * sizeof(float); + + const bool ok = + try_ensure_buf(&sc.O_acc, &sc.O_bytes, O_bytes) && + try_ensure_buf(&sc.l_acc, &sc.l_bytes, l_bytes) && + try_ensure_buf(&sc.m_acc, &sc.m_bytes, m_bytes) && + try_ensure_buf(&sc.S, &sc.S_bytes, S_bytes) && + try_ensure_buf(&sc.k_tmp, &sc.k_bytes, kv_bytes) && + try_ensure_buf(&sc.v_tmp, &sc.v_bytes, kv_bytes) && + try_ensure_buf(&sc.Q_f32, &sc.Q_bytes, Q_f32_bytes); + + if (ok) { + O_acc = sc.O_acc; + l_acc = sc.l_acc; + m_acc = sc.m_acc; + S = sc.S; + k_tmp = sc.k_tmp; + v_tmp = sc.v_tmp; + Q_f32 = sc.Q_f32; + break; + } + + if (tbq_chunk <= CHUNKED_PF_MIN) { + GGML_ABORT("chunked prefill: failed to allocate scratch buffers"); + } + + tbq_chunk >>= 1; + // Release chunk-dependent scratch allocated for the failed, larger + // chunk. Otherwise retry can keep the old large buffers alive and fail + // again despite the smaller chunk size. 
+ free_buf(&sc.S, &sc.S_bytes); + free_buf(&sc.k_tmp, &sc.k_bytes); + free_buf(&sc.v_tmp, &sc.v_bytes); + GGML_LOG_WARN("chunked prefill: scratch allocation failed, retrying with chunk=%d (requested=%d)\n", + tbq_chunk, requested_tbq_chunk); + } cublasHandle_t cublas_handle = ctx.cublas_handle(); CUBLAS_CHECK(cublasSetStream(cublas_handle, stream)); diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index e6ce26f7212..e5cefbb0496 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -116,7 +116,11 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, // Each successive walk beyond -1 decrements by 1, so virtual slot -k maps to // sx slot (K-1 - k), which indexes into the old state region [0, K-1). This // matches SGLang's causal_conv1d_triton HAS_EAGLE_TREE_CUSTOM_ATTN_MASK path. -template +// dflash27b_ggml: tree-mode + per-token persistent conv state. When +// WITH_PERSIST is true, every token writes its (K-1)-element conv "post-state" +// (the last K-1 cols of its parent-chain window) into persist_inter so the +// driver can roll the live conv state back to the accepted DFS node. +template static __global__ void ssm_conv_tree_f32( const float * __restrict__ src0, // sx: [K-1+n_t, d_inner, n_s] const float * __restrict__ src1, // c: [K, d_inner] @@ -125,6 +129,8 @@ static __global__ void ssm_conv_tree_f32( const int src1_nb1, float * __restrict__ dst, // [d_inner, n_t, n_s] const int dst_nb0, const int dst_nb1, const int dst_nb2, + float * __restrict__ persist_inter, // [K-1, d_inner, n_t, n_s] when WITH_PERSIST, else nullptr + const int64_t d_inner_total, // full d_inner for persist row stride const int64_t n_t) { GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; @@ -151,6 +157,10 @@ static __global__ void ssm_conv_tree_f32( const int * parent_ids_seq = parent_ids + bidx * n_t; + // Channel index this thread owns within the full d_inner dimension. + // Used both for indexing persist_inter (when enabled) and as bookkeeping. + const int channel = (int)(bidy * split_d_inner) + tid; + for (int64_t i = 0; i < n_t; i++) { // Walk the parent chain K-1 times to fill the conv window. // ancestor_virtual[k] gives the "virtual slot" for kernel position k, @@ -175,14 +185,31 @@ static __global__ void ssm_conv_tree_f32( } float sumf = 0.0f; + // Cache window values so we can both convolve and (optionally) persist + // them without re-reading from global memory. + float window[d_conv]; #pragma unroll for (size_t k = 0; k < d_conv; k++) { // Map virtual slot → sx slot: sx_slot = (K-1) + ancestors[k]. const int sx_slot = (int)(d_conv - 1) + ancestors[k]; - const float x_val = x_block[tid * stride_x + sx_slot]; - sumf += x_val * w[k]; + window[k] = x_block[tid * stride_x + sx_slot]; + sumf += window[k] * w[k]; } y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; + + if constexpr (WITH_PERSIST) { + // Per-token "post-state": the (K-1) most recent cols of this token's + // window — i.e. ancestors[1..K-1]. Layout matches the live conv state + // tensor (r_l): [K-1, d_inner, ...] with K-1 fastest. Persist memory + // layout: persist_inter[s][t][channel][k] flat = ((s*n_t + t)*d_inner + channel) * (K-1) + k. + float * persist_token = persist_inter + + ((bidx * n_t + i) * d_inner_total + channel) * (int64_t)(d_conv - 1); +#pragma unroll + for (size_t k = 0; k < d_conv - 1; k++) { + // ancestors[1] is the oldest col we keep; ancestors[K-1] = self. 
+ persist_token[k] = window[k + 1]; + } + } } } @@ -190,7 +217,8 @@ template static void ssm_conv_tree_f32_cuda(const float * src0, const float * src1, const int * parent_ids, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, - const int dst_nb2, const int64_t nc, const int64_t nr, + const int dst_nb2, float * persist_inter, + const int64_t nc, const int64_t nr, const int64_t n_t, const int64_t n_s, cudaStream_t stream) { const int threads = 128; GGML_ASSERT(nr % threads == 0); @@ -198,9 +226,15 @@ static void ssm_conv_tree_f32_cuda(const float * src0, const float * src1, const const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); auto launch_kernel = [&](auto NC) { constexpr int kNC = decltype(NC)::value; - ssm_conv_tree_f32<<>>( - src0, src1, parent_ids, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, n_t); + if (persist_inter != nullptr) { + ssm_conv_tree_f32<<>>( + src0, src1, parent_ids, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, persist_inter, nr, n_t); + } else { + ssm_conv_tree_f32<<>>( + src0, src1, parent_ids, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, nullptr, nr, n_t); + } }; switch (nc) { @@ -275,16 +309,26 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g if (parent_ids != nullptr) { GGML_ASSERT(parent_ids->type == GGML_TYPE_I32); const int * parent_ids_d = (const int *) parent_ids->data; + // dflash27b_ggml: optional src[3] = persist_inter (F32) buffer where + // each token's [K-1, d_inner] post-state is written for SSM rollback. + const struct ggml_tensor * persist_inter = dst->src[3]; + float * persist_d = nullptr; + if (persist_inter != nullptr) { + GGML_ASSERT(persist_inter->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(persist_inter)); + GGML_ASSERT(ggml_nelements(persist_inter) >= (int64_t)(nc - 1) * nr * n_t * n_s); + persist_d = (float *) persist_inter->data; + } if (fuse_silu) { ssm_conv_tree_f32_cuda(src0_d, src1_d, parent_ids_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], - nc, nr, n_t, n_s, stream); + persist_d, nc, nr, n_t, n_s, stream); } else { ssm_conv_tree_f32_cuda(src0_d, src1_d, parent_ids_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], - nc, nr, n_t, n_s, stream); + persist_d, nc, nr, n_t, n_s, stream); } return; } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 26b51748cf9..76aa235c111 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5514,6 +5514,34 @@ struct ggml_tensor * ggml_ssm_conv_tree( return result; } +// dflash: tree-mode + external persistent conv post-state buffer. Same op as +// ggml_ssm_conv_tree but the kernel ALSO writes each token's (K-1)-element +// "post-state" (last K-1 cols of its parent-chain window) into persist_inter, +// matching the [K-1, d_inner, n_tokens, n_seqs] layout used by the live conv +// state in the recurrent memory. 
+struct ggml_tensor * ggml_ssm_conv_tree_persist( + struct ggml_context * ctx, + struct ggml_tensor * sx, + struct ggml_tensor * c, + struct ggml_tensor * parent_ids, + struct ggml_tensor * persist_inter) { + struct ggml_tensor * result = ggml_ssm_conv_tree(ctx, sx, c, parent_ids); + + GGML_ASSERT(persist_inter != NULL); + GGML_ASSERT(persist_inter->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(persist_inter)); + + const int64_t d_conv = c->ne[0]; + const int64_t d_inner = c->ne[1]; + const int64_t n_t = sx->ne[0] - d_conv + 1; + const int64_t n_s = sx->ne[2]; + GGML_ASSERT(ggml_nelements(persist_inter) >= (d_conv - 1) * d_inner * n_t * n_s); + + result->src[3] = persist_inter; + + return result; +} + // ggml_ssm_scan struct ggml_tensor * ggml_ssm_scan(
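
Usage sketch for the new graph API. A minimal sketch of how a driver could wire up ggml_ssm_conv_tree_persist; the helper name build_tree_conv and the concrete shapes are illustrative, and the only hard requirements are the ones asserted in ggml.c (persist_inter contiguous F32 with at least (d_conv-1)*d_inner*n_tokens*n_seqs elements):

    #include "ggml.h"

    // Sketch only: shapes follow the comments in ggml.h / ops.cpp; build_tree_conv
    // is a hypothetical driver helper, not part of this patch.
    static struct ggml_tensor * build_tree_conv(struct ggml_context * ctx,
                                                struct ggml_tensor  * conv_w,   // [d_conv, d_inner]
                                                int64_t n_tokens, int64_t n_seqs) {
        const int64_t d_conv  = conv_w->ne[0];
        const int64_t d_inner = conv_w->ne[1];

        // sx packs the (d_conv-1) old conv-state columns in front of the n_tokens new columns.
        struct ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
                                                     d_conv - 1 + n_tokens, d_inner, n_seqs);

        // One parent id per token per sequence; negative ids walk back into the
        // old conv-state columns (the draft-tree root continues the live state).
        struct ggml_tensor * parent_ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_tokens, n_seqs);

        // Per-token post-state written by the kernel: [d_conv-1, d_inner, n_tokens, n_seqs],
        // d_conv-1 fastest, matching the live conv-state layout.
        struct ggml_tensor * persist_inter = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
                                                                d_conv - 1, d_inner, n_tokens, n_seqs);

        return ggml_ssm_conv_tree_persist(ctx, sx, conv_w, parent_ids, persist_inter);
    }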
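
The virtual-slot mapping implemented by both the CPU path in ops.cpp and the CUDA kernel is easiest to check in isolation: ancestors[K-1] is the token itself, ancestors[k] is the parent of ancestors[k+1] with negative ids continuing to decrement, and sx_slot = (K-1) + ancestors[k]. A standalone reference of just that mapping; the 5-token example tree is made up:

    #include <stdio.h>

    // Given a token index and the per-token parent ids, compute the K sx columns
    // that feed its convolution. Mirrors the ancestor walk in
    // ggml_compute_forward_ssm_conv_f32 / ssm_conv_tree_f32.
    static void conv_window_slots(int token, const int * parent_ids, int d_conv, int * sx_slot /*[d_conv]*/) {
        int ancestors[8];                       // assumes d_conv <= 8 for this sketch
        ancestors[d_conv - 1] = token;          // kernel position K-1 is the token itself
        for (int k = d_conv - 2; k >= 0; --k) {
            const int prev = ancestors[k + 1];
            // Walking past the tree root keeps decrementing, so virtual slot -m
            // lands in the old-state region [0, K-1) after the shift below.
            ancestors[k] = prev >= 0 ? parent_ids[prev] : prev - 1;
        }
        for (int k = 0; k < d_conv; ++k) {
            sx_slot[k] = (d_conv - 1) + ancestors[k];   // map virtual slot -> sx column
        }
    }

    int main(void) {
        // Hypothetical 5-token draft tree: token 0 is the root, 1 and 2 branch off 0,
        // 3 continues 1, 4 continues 2. The root's parent (-1) is the live conv state.
        const int parent_ids[5] = { -1, 0, 0, 1, 2 };
        const int d_conv = 4;
        int slots[4];
        for (int t = 0; t < 5; ++t) {
            conv_window_slots(t, parent_ids, d_conv, slots);
            printf("token %d window sx columns: %d %d %d %d\n", t, slots[0], slots[1], slots[2], slots[3]);
        }
        return 0;
    }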
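
Rollback on the driver side is flat offset arithmetic against the persisted buffers. A sketch of those offsets, assuming the live conv state of a sequence is a contiguous [d_conv-1, d_inner] block with the same d_conv-1-fastest layout (as the ssm-conv.cu comment states for r_l); rollback_conv_state and the parameter names are illustrative:

    #include <stdint.h>
    #include <string.h>

    // persist_inter (conv): [K-1, d_inner, n_tokens, n_seqs], K-1 fastest.
    // Offset of element k of channel ch for token t of sequence s.
    static inline int64_t conv_persist_off(int64_t s, int64_t t, int64_t ch, int64_t k,
                                           int64_t d_conv, int64_t d_inner, int64_t n_tokens) {
        return ((s * n_tokens + t) * d_inner + ch) * (d_conv - 1) + k;
    }

    // persist buffer (gated delta net): one S_v*S_v state per (head, token, seq),
    // matching the load_inter/store_inter lambdas in ops.cpp.
    static inline int64_t gdn_persist_off(int64_t s, int64_t t, int64_t h, int64_t elem,
                                          int64_t S_v, int64_t H, int64_t n_tokens) {
        return ((s * n_tokens + t) * H + h) * S_v * S_v + elem;
    }

    // Roll the live conv state of one sequence back to the accepted DFS node:
    // copy that token's [K-1, d_inner] post-state over the live state.
    static void rollback_conv_state(float * live_state, const float * persist_inter,
                                    int64_t s, int64_t accept_tok,
                                    int64_t d_conv, int64_t d_inner, int64_t n_tokens) {
        const int64_t n = (d_conv - 1) * d_inner;
        memcpy(live_state,
               persist_inter + conv_persist_off(s, accept_tok, 0, 0, d_conv, d_inner, n_tokens),
               (size_t)(n * sizeof(float)));
    }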
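
On the gated-delta-net side, the per-token branch added in ggml_compute_forward_gated_delta_net_one_chunk boils down to three cases. A condensed reference of just that selection (state_select is a made-up name; the cases mirror the parent_ids handling in ops.cpp):

    #include <stdint.h>
    #include <stddef.h>

    enum state_source { STATE_KEEP, STATE_RESET, STATE_LOAD_PARENT };

    // For token t of a chunk, decide which recurrent state it continues from.
    // After the token is processed, its own post-state is stored (store_inter)
    // so later siblings and the final accept-path rollback can read it.
    static enum state_source state_select(int64_t t, const int32_t * parent_ids) {
        if (parent_ids == NULL || t == 0) return STATE_KEEP;    // first token continues the live state
        const int32_t parent = parent_ids[t];
        if (parent < 0)      return STATE_RESET;                // new root: restart from the chunk-entry state s_in
        if (parent != t - 1) return STATE_LOAD_PARENT;          // branch: load the parent's persisted post-state
        return STATE_KEEP;                                      // sequential: keep rolling s_out forward
    }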
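
The fattn-chunked.cu change is an instance of a generic allocate-or-shrink loop. Below is a self-contained version of just that pattern for poking at outside the ggml build; cudaMalloc leaves a sticky error behind on failure, so the cudaGetLastError() call is what keeps later CUDA calls usable. try_alloc, MIN_CHUNK and the buffer names are illustrative:

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstddef>

    static bool try_alloc(float ** p, size_t * cur_bytes, size_t need_bytes) {
        if (need_bytes <= *cur_bytes && *p != nullptr) return true;
        if (*p != nullptr) cudaFree(*p);
        *p = nullptr;
        *cur_bytes = 0;
        if (cudaMalloc((void **) p, need_bytes) != cudaSuccess) {
            (void) cudaGetLastError();   // clear the sticky error so the retry can proceed
            *p = nullptr;
            return false;
        }
        *cur_bytes = need_bytes;
        return true;
    }

    int main() {
        const size_t MIN_CHUNK = 256;                // illustrative lower bound (CHUNKED_PF_MIN in the patch)
        size_t chunk = size_t(1) << 36;              // deliberately too large so the loop has to shrink
        float * S = nullptr; size_t S_bytes = 0;     // chunk-dependent scratch
        float * O = nullptr; size_t O_bytes = 0;     // chunk-independent scratch

        for (;;) {
            const bool ok = try_alloc(&O, &O_bytes, 1024 * sizeof(float)) &&
                            try_alloc(&S, &S_bytes, chunk * sizeof(float));
            if (ok) break;
            if (chunk <= MIN_CHUNK) { fprintf(stderr, "scratch allocation failed\n"); return 1; }
            chunk >>= 1;
            // Drop the chunk-dependent buffer before retrying, otherwise the old
            // (failed-size) reservation can keep the smaller retry from succeeding.
            if (S != nullptr) { cudaFree(S); S = nullptr; S_bytes = 0; }
            fprintf(stderr, "retrying with chunk=%zu floats\n", chunk);
        }

        cudaFree(S);
        cudaFree(O);
        return 0;
    }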