11 changes: 11 additions & 0 deletions ggml/include/ggml.h
@@ -2390,6 +2390,17 @@ extern "C" {
struct ggml_tensor * c,
struct ggml_tensor * parent_ids);

// dflash extension: tree-mode ssm_conv that also writes each token's
// (K-1)-element post-state to persist_inter so the driver can roll the
// live conv state back to the accepted DFS node. persist_inter must be
// contiguous F32 with shape [K-1, d_inner, n_tokens, n_seqs] (K-1 fastest).
GGML_API struct ggml_tensor * ggml_ssm_conv_tree_persist(
struct ggml_context * ctx,
struct ggml_tensor * sx,
struct ggml_tensor * c,
struct ggml_tensor * parent_ids,
struct ggml_tensor * persist_inter);

GGML_API struct ggml_tensor * ggml_ssm_scan(
struct ggml_context * ctx,
struct ggml_tensor * s,
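The header comment above fixes the contract: the op behaves like ggml_ssm_conv in tree mode, and additionally writes each token's (K-1)-element post-state into persist_inter. A minimal sketch of driver-side wiring, assuming K = d_conv; build_tree_conv and its argument names are illustrative, not part of this PR:

// Hypothetical wiring for ggml_ssm_conv_tree_persist (sketch, not PR code).
struct ggml_tensor * build_tree_conv(
        struct ggml_context * ctx,
        struct ggml_tensor  * sx,          // {d_conv - 1 + n_tokens, d_inner, n_seqs}
        struct ggml_tensor  * conv_w,      // {d_conv, d_inner}
        struct ggml_tensor  * parent_ids,  // I32, one parent index per token per seq
        int64_t               n_tokens,
        int64_t               n_seqs) {
    const int64_t d_conv  = conv_w->ne[0];
    const int64_t d_inner = conv_w->ne[1];

    // [K-1, d_inner, n_tokens, n_seqs], contiguous F32, K-1 fastest,
    // exactly as the header comment requires.
    struct ggml_tensor * persist_inter = ggml_new_tensor_4d(
            ctx, GGML_TYPE_F32, d_conv - 1, d_inner, n_tokens, n_seqs);

    // The conv output is the op result; persist_inter is filled as a side
    // effect, so the driver can later copy the accepted token's post-state
    // back over the live conv state.
    return ggml_ssm_conv_tree_persist(ctx, sx, conv_w, parent_ids, persist_inter);
}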
105 changes: 100 additions & 5 deletions ggml/src/ggml-cpu/ops.cpp
@@ -11,6 +11,7 @@
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

// ggml_compute_forward_dup

@@ -9254,6 +9255,8 @@ static void ggml_compute_forward_ssm_conv_f32(
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // conv_x
const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const ggml_tensor * src2 = dst->src[2]; // parent_ids, optional tree mode
const ggml_tensor * src3 = dst->src[3]; // persist conv state, optional

const int ith = params->ith;
const int nth = params->nth;
@@ -9269,6 +9272,17 @@
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));

if (src2 != nullptr) {
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src2));
GGML_ASSERT(ggml_nelements(src2) == n_t * n_s);
}
if (src3 != nullptr) {
GGML_ASSERT(src3->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src3));
GGML_ASSERT(ggml_nelements(src3) >= (int64_t)(nc - 1) * nr * n_t * n_s);
}

// rows per thread
const int dr = (nr + nth - 1)/nth;

@@ -9278,25 +9292,51 @@
const int ir = ir1 - ir0;

for (int i3 = 0; i3 < n_s; ++i3) {
const int32_t * parent_ids = src2 ? (const int32_t *) src2->data + (int64_t)i3 * n_t : nullptr;
float * persist_seq = src3 ? (float *) src3->data + (int64_t)i3 * n_t * nr * (nc - 1) : nullptr;

for (int i2 = 0; i2 < n_t; ++i2) {
            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}

            int ancestors[GGML_MAX_DIMS] = {};
            GGML_ASSERT(nc <= GGML_MAX_DIMS); // also bounds window[] in the row loop below
            if (parent_ids != nullptr) {
                ancestors[nc - 1] = i2;
for (int k = nc - 2; k >= 0; --k) {
const int prev = ancestors[k + 1];
ancestors[k] = prev >= 0 ? parent_ids[prev] : prev - 1;
}
}

// TODO: transpose the output for smaller strides for big batches?
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// rowwise dot product
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
float sumf = 0.0f;
float window[GGML_MAX_DIMS] = {};

// d_conv
for (int i0 = 0; i0 < nc; ++i0) {
const int sx_slot = parent_ids ? (nc - 1 + ancestors[i0]) : (i2 + i0);
const float s = *(const float *) ((const char *) src0->data
+ (int64_t)sx_slot * src0->nb[0]
+ (int64_t)(ir0 + i1) * src0->nb[1]
+ (int64_t)i3 * src0->nb[2]);
const float c = *(const float *) ((const char *) src1->data
+ (int64_t)i0 * src1->nb[0]
+ (int64_t)(ir0 + i1) * src1->nb[1]);
window[i0] = s;
sumf += s * c;
}
x[i1] = sumf;

if (persist_seq != nullptr) {
float * persist_token = persist_seq + ((int64_t)i2 * nr + (ir0 + i1)) * (nc - 1);
for (int i0 = 0; i0 < nc - 1; ++i0) {
persist_token[i0] = window[i0 + 1];
}
}
}
}
}
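For each token the kernel reconstructs its d_conv-wide input window by walking parent_ids back toward the root; once the walk falls off the front of the batch (parent -1), it keeps stepping left through the carried-in conv state. A self-contained sketch of that walk with hypothetical inputs (parent_ids = {-1, 0, 0}, d_conv = 4):

#include <assert.h>

// Mirror of the kernel's ancestor walk for one token (illustration only;
// assumes d_conv <= 8).
static void walk_ancestors(const int * parent_ids, int token, int d_conv, int * slots) {
    int anc[8];
    anc[d_conv - 1] = token;
    for (int k = d_conv - 2; k >= 0; --k) {
        const int prev = anc[k + 1];
        anc[k] = prev >= 0 ? parent_ids[prev] : prev - 1; // step into carried-in state
    }
    for (int i = 0; i < d_conv; ++i) {
        slots[i] = d_conv - 1 + anc[i]; // slot along dim 0 of sx
    }
}

int main(void) {
    // token 0 is a root (parent -1); tokens 1 and 2 both branch off token 0
    const int parent_ids[3] = { -1, 0, 0 };
    int slots[4];
    walk_ancestors(parent_ids, /*token=*/2, /*d_conv=*/4, slots);
    // window = { carried state, carried state, token 0, token 2 }
    assert(slots[0] == 1 && slots[1] == 2 && slots[2] == 3 && slots[3] == 5);
    return 0;
}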
@@ -10439,6 +10479,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
ggml_tensor * src_g = dst->src[3];
ggml_tensor * src_beta = dst->src[4];
ggml_tensor * src_state = dst->src[5];
ggml_tensor * src_parent = dst->src[6];
ggml_tensor * src_persist = dst->src[7];

const int64_t S_v = src_v->ne[0];
const int64_t H = src_v->ne[1];
@@ -10454,6 +10496,16 @@

GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
GGML_ASSERT(src_beta->ne[0] == 1);
if (src_parent != nullptr) {
GGML_ASSERT(src_parent->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src_parent));
GGML_ASSERT(ggml_nelements(src_parent) == n_tokens * n_seqs);
}
if (src_persist != nullptr) {
GGML_ASSERT(src_persist->type == GGML_TYPE_F32 || src_persist->type == GGML_TYPE_F16);
GGML_ASSERT(ggml_is_contiguous(src_persist));
GGML_ASSERT(ggml_nelements(src_persist) >= S_v * S_v * H * n_tokens * n_seqs);
}

GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
@@ -10477,8 +10529,10 @@
// attn_scores: S_v * H * n_tokens * n_seqs floats
// new_states: S_v * S_v * H * n_seqs floats
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
const int64_t state_elems = S_v * S_v * H * n_seqs;
float * attn_out_base = (float *)dst->data;
float * state_out_base = (float *)dst->data + attn_score_elems;
float * inter_out_base = state_out_base + state_elems;

const float * state_in_base = (const float *)src_state->data;

@@ -10505,10 +10559,45 @@
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
memcpy(s_out, s_in, S_v * S_v * sizeof(float));

const int32_t * parent_ids = src_parent ? (const int32_t *) src_parent->data + iv3 * n_tokens : nullptr;
auto load_inter = [&](int64_t token, int64_t elem) -> float {
if (src_persist != nullptr) {
const int64_t off = ((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem;
if (src_persist->type == GGML_TYPE_F32) {
return ((const float *) src_persist->data)[off];
}
return GGML_FP16_TO_FP32(((const ggml_fp16_t *) src_persist->data)[off]);
}
return inter_out_base[((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem];
};
auto store_inter = [&](int64_t token, int64_t elem, float value) {
if (src_persist != nullptr) {
const int64_t off = ((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem;
if (src_persist->type == GGML_TYPE_F32) {
((float *) src_persist->data)[off] = value;
} else {
((ggml_fp16_t *) src_persist->data)[off] = GGML_FP32_TO_FP16(value);
}
} else {
inter_out_base[((iv3 * n_tokens + token) * H + iv1) * S_v * S_v + elem] = value;
}
};

// attn output pointer for first token of this (head, seq)
float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;

for (int64_t t = 0; t < n_tokens; t++) {
if (parent_ids != nullptr && t > 0) {
const int32_t parent_t = parent_ids[t];
if (parent_t < 0) {
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
} else if (parent_t != t - 1) {
for (int64_t elem = 0; elem < S_v * S_v; ++elem) {
s_out[elem] = load_inter(parent_t, elem);
}
}
}

const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
@@ -10552,6 +10641,12 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
}

attn_data += S_v * H; // advance to next token

if (parent_ids != nullptr || src_persist != nullptr) {
for (int64_t elem = 0; elem < S_v * S_v; ++elem) {
store_inter(t, elem, s_out[elem]);
}
}
}
}
}
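The rollback rule the token loop implements: t == 0 or parent == t - 1 continues the running state in place; parent < 0 restarts from the chunk-initial state; any other parent resumes from that parent's persisted post-state. A standalone sketch of just that branch logic; select_state and inter are illustrative names, with inter[t] standing for token t's persisted S_v x S_v state:

#include <stdint.h>
#include <string.h>

// Sketch of the per-token state selection above (illustrative names).
// s_out: running recurrent state for one (head, seq) pair.
// s_in:  chunk-initial state. inter: persisted post-states, one per token.
static void select_state(float * s_out, const float * s_in, const float * inter,
                         const int32_t * parent_ids, int64_t t, int64_t state_elems) {
    if (parent_ids == NULL || t == 0) {
        return; // sequential path: keep the running state
    }
    const int32_t parent = parent_ids[t];
    if (parent < 0) {
        memcpy(s_out, s_in, state_elems * sizeof(float)); // new root: restart
    } else if (parent != t - 1) {
        // branch: resume from the parent's persisted post-state
        memcpy(s_out, inter + parent * state_elems, state_elems * sizeof(float));
    }
    // parent == t - 1: s_out already holds the parent's post-state
}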
92 changes: 73 additions & 19 deletions ggml/src/ggml-cuda/fattn-chunked.cu
@@ -67,13 +67,29 @@ struct chunked_scratch {
};
static chunked_scratch g_chunked_bufs[GGML_CUDA_MAX_DEVICES];

static bool try_ensure_buf(float ** p, size_t * cur_bytes, size_t need_bytes) {
    if (need_bytes <= *cur_bytes && *p != nullptr) return true;
    if (*p != nullptr) CUDA_CHECK(cudaFree(*p));
    *p = nullptr;

const cudaError_t err = cudaMalloc(p, need_bytes);
if (err != cudaSuccess) {
// Clear the sticky CUDA error so the caller can retry with a smaller
// chunk instead of aborting the whole process.
(void) cudaGetLastError();
*p = nullptr;
*cur_bytes = 0;
return false;
}

*cur_bytes = need_bytes;
return true;
}

static void free_buf(float ** p, size_t * cur_bytes) {
if (*p != nullptr) CUDA_CHECK(cudaFree(*p));
*p = nullptr;
*cur_bytes = 0;
}
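The cudaGetLastError() drain matters because a failed cudaMalloc stays latched in CUDA's per-thread last-error slot, where a later post-launch CUDA_CHECK(cudaGetLastError()) would resurface it far from the allocation site. An illustration; my_kernel and stream are hypothetical stand-ins, not PR code:

// Sketch: why the sticky error must be drained (hypothetical kernel/stream).
float * p = nullptr;
if (cudaMalloc(&p, (size_t)1 << 60) != cudaSuccess) { // deliberately absurd size
    (void) cudaGetLastError(); // reset the latched error at the failure site
}
// Without the drain, the stale out-of-memory error would surface here and
// mask the actual launch status of my_kernel.
my_kernel<<<1, 1, 0, stream>>>();
CUDA_CHECK(cudaGetLastError());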

void ggml_cuda_flash_attn_ext_chunked(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -104,28 +120,66 @@ void ggml_cuda_flash_attn_ext_chunked(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
size_t free_bytes = 0, total_bytes = 0;
CUDA_CHECK(cudaMemGetInfo(&free_bytes, &total_bytes));
const int vram_chunk = chunked_pf_compute_chunk_size(free_bytes, nh_q, nh_kv, q_batch_size, D);
    int tbq_chunk = chunked_chunk_env(vram_chunk);

const int device = ctx.device;
GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES);
chunked_scratch & sc = g_chunked_bufs[device];

    const size_t O_bytes     = (size_t)nh_q * nq * D * sizeof(float);
    const size_t l_bytes     = (size_t)nh_q * nq * sizeof(float);
    const size_t m_bytes     = (size_t)nh_q * nq * sizeof(float);
    const size_t Q_f32_bytes = (size_t)nh_q * nq * D * sizeof(float);

float * O_acc = nullptr;
float * l_acc = nullptr;
float * m_acc = nullptr;
float * S = nullptr;
float * k_tmp = nullptr;
float * v_tmp = nullptr;
float * Q_f32 = nullptr;

const int requested_tbq_chunk = tbq_chunk;
for (;;) {
const size_t S_bytes = (size_t)nh_q * q_batch_size * tbq_chunk * sizeof(float);
// Per-chunk K/V dequant: [nh_kv, tbq_chunk, D] fp32. The final chunk may
// be shorter; we still size the buffer for the max and only write chunk_len.
const size_t kv_bytes = (size_t)nh_kv * tbq_chunk * D * sizeof(float);

const bool ok =
try_ensure_buf(&sc.O_acc, &sc.O_bytes, O_bytes) &&
try_ensure_buf(&sc.l_acc, &sc.l_bytes, l_bytes) &&
try_ensure_buf(&sc.m_acc, &sc.m_bytes, m_bytes) &&
try_ensure_buf(&sc.S, &sc.S_bytes, S_bytes) &&
try_ensure_buf(&sc.k_tmp, &sc.k_bytes, kv_bytes) &&
try_ensure_buf(&sc.v_tmp, &sc.v_bytes, kv_bytes) &&
try_ensure_buf(&sc.Q_f32, &sc.Q_bytes, Q_f32_bytes);

if (ok) {
O_acc = sc.O_acc;
l_acc = sc.l_acc;
m_acc = sc.m_acc;
S = sc.S;
k_tmp = sc.k_tmp;
v_tmp = sc.v_tmp;
Q_f32 = sc.Q_f32;
break;
}

if (tbq_chunk <= CHUNKED_PF_MIN) {
GGML_ABORT("chunked prefill: failed to allocate scratch buffers");
}

tbq_chunk >>= 1;
// Release chunk-dependent scratch allocated for the failed, larger
// chunk. Otherwise retry can keep the old large buffers alive and fail
// again despite the smaller chunk size.
free_buf(&sc.S, &sc.S_bytes);
free_buf(&sc.k_tmp, &sc.k_bytes);
free_buf(&sc.v_tmp, &sc.v_bytes);
GGML_LOG_WARN("chunked prefill: scratch allocation failed, retrying with chunk=%d (requested=%d)\n",
tbq_chunk, requested_tbq_chunk);
}

cublasHandle_t cublas_handle = ctx.cublas_handle();
CUBLAS_CHECK(cublasSetStream(cublas_handle, stream));