diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 59dde99e362..c9eead18aa3 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -757,14 +757,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -877,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } +bool llm_arch_supports_rs_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + bool llm_arch_supports_sm_tensor(const llm_arch & arch) { switch (arch) { case LLM_ARCH_GROK: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index e37d548c98e..89cf16cc37c 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); +bool llm_arch_supports_rs_rollback(const llm_arch & arch); diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 6554a89b28a..f10397747b0 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -73,7 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, - { "hunyuan-ocr", LLM_CHAT_TEMPLATE_HUNYUAN_OCR }, + { "hunyuan-vl", LLM_CHAT_TEMPLATE_HUNYUAN_VL }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -218,7 +218,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { - return LLM_CHAT_TEMPLATE_HUNYUAN_OCR; + return LLM_CHAT_TEMPLATE_HUNYUAN_VL; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -825,8 +825,8 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_OCR) { - // tencent/HunyuanOCR + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_VL) { + // tencent/HunyuanOCR & tencent/HunyuanVL ss << "<|hy_begin▁of▁sentence|>"; for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 13f936a946c..ea6540c0be7 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -53,7 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, - LLM_CHAT_TEMPLATE_HUNYUAN_OCR, + LLM_CHAT_TEMPLATE_HUNYUAN_VL, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 3d9714ab166..ad36c06667d 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2,6 +2,7 @@ #include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -21,6 +22,14 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -42,13 +51,22 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; + cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; + cparams.embeddings_pre_norm_masked = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -65,6 +83,8 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_type = params.ctx_type; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -206,6 +226,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -278,6 +299,7 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, + /*.ctx_type= */ cparams.ctx_type, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -860,6 +882,42 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_pre_norm() { + output_reorder(); + + return embd_pre_norm.data; +} + +float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { + output_reorder(); + + try { + if (embd_pre_norm.data == nullptr) { + throw std::runtime_error("no pre-norm embeddings"); + } + + const uint32_t n_embd = model.hparams.n_embd; + + if (!cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + } + return embd_pre_norm.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); + return embd_pre_norm.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1040,6 +1098,13 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } +void llama_context::set_embeddings_pre_norm(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); + + cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm_masked = masked; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1072,6 +1137,19 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + static bool warned = false; + if (!warned) { + LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__); + warned = true; + } + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } + const bool can_offload = sampler && sampler->iface->backend_init && @@ -1241,7 +1319,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1312,8 +1392,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1379,6 +1460,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1531,7 +1622,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::mapget_ubatch(); @@ -1689,7 +1783,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + + const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -1727,8 +1822,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1809,6 +1905,25 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + { + const bool masked = cparams.embeddings_pre_norm_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; + + if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + } + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1823,6 +1938,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -1893,10 +2009,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_pre_norm = cparams.embeddings_pre_norm; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1908,8 +2026,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + + if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_pre_norm.size = (size_t) n_embd * n_batch; + } // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); @@ -1925,8 +2050,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1942,6 +2067,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + embd_pre_norm.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1970,6 +2096,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + embd_pre_norm = has_embd_pre_norm ? buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0}; + offset += embd_pre_norm.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2034,6 +2163,12 @@ void llama_context::output_reorder() { } } + if (embd_pre_norm.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -2121,7 +2256,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3100,7 +3235,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3201,8 +3336,10 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -3306,6 +3443,13 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.nextn_predict_layers == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); + return nullptr; + } + + try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3347,6 +3491,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } @@ -3436,6 +3584,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_pre_norm(value, masked); +} + +float * llama_get_embeddings_pre_norm(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm(); +} + +float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm_ith(i); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 92d1b0cf95a..d03f681d4a1 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -84,6 +84,9 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_pre_norm(); + float * get_embeddings_pre_norm_ith(int32_t i); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -107,6 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_pre_norm(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); @@ -278,6 +282,11 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; + // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_pre_norm is enabled and the model graph + // sets llm_graph_result::t_h_pre_norm + buffer_view embd_pre_norm = {nullptr, 0}; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map samplers; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 9d359474132..20ec59fe335 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -12,6 +12,7 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -27,6 +28,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; @@ -40,6 +43,7 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 8ce29d217cb..edfa71c207c 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// Set whether the context outputs pre-norm embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); + +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fe155c92dea..fc027de8b39 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -500,15 +500,21 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { } void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { - mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } - mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -534,14 +540,21 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { bool res = true; - res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } - res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; } @@ -848,6 +861,9 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -2528,7 +2544,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 5cb1756c6a9..bf6778237e6 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -32,6 +32,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; enum llm_ffn_op_type { @@ -580,7 +581,8 @@ struct llm_graph_params { ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && ( (!ubatch.token && !other.ubatch.token) || - (!ubatch.embd && !other.ubatch.embd) + (!ubatch.embd && !other.ubatch.embd) || + (ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd) ); // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same @@ -644,6 +646,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -672,6 +675,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map t_sampled_logits; std::map t_candidates; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 002d15d415f..2239309c8fb 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { + if (kv_only_nextn) { + // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; + // the leading trunk blocks are not executed in this graph. + return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); + } + if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 0160a89caa2..e2d051edc6c 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -92,6 +92,8 @@ struct llama_hparams { uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; + bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) + float f_norm_eps; float f_norm_rms_eps; float f_norm_group_eps; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 10e6b459797..72f5c2fea72 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_base()->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.h b/examples/talk-llama/llama-memory-hybrid-iswa.h index 807c8aac96c..c9d3f9f57c5 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.h +++ b/examples/talk-llama/llama-memory-hybrid-iswa.h @@ -34,6 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 4ce1af592c1..33b3b395e0c 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.h b/examples/talk-llama/llama-memory-hybrid.h index 558cafdf984..484eafb7499 100644 --- a/examples/talk-llama/llama-memory-hybrid.h +++ b/examples/talk-llama/llama-memory-hybrid.h @@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index c07f1d969cb..ec5dc5835dd 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; @@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent( size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); cells.resize(mem_size); @@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); - ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -115,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } @@ -138,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) { ggml_backend_buffer_clear(buf.get(), 0); } } + + std::fill(rs_idx.begin(), rs_idx.end(), 0); } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -152,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } + const bool rm_all = p0 == 0 && p1 == std::numeric_limits::max(); + if (rm_all) { + if (seq_id >= 0) { + set_rs_idx(seq_id, 0); + } else { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } + } + // models like Mamba or RWKV can't have a state partially erased at the end // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { @@ -161,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -368,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -388,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + if (n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } } if (ubatch.n_tokens == 0) { @@ -703,6 +737,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq GGML_UNUSED(flags); std::vector> cell_ranges; // ranges, from inclusive, to exclusive + std::vector> cell_ranges_data; // logical source row ranges uint32_t cell_count = 0; // Count the number of cells with the specified seq_id @@ -712,6 +747,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq const auto & cell = cells[i]; if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { ++cell_count; + uint32_t rs_idx_cur = 0; + + if (n_rs_seq != 0) { + if (seq_id != -1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size()); + rs_idx_cur = rs_idx[seq_id]; + } else { + bool has_rs_idx = false; + for (const llama_seq_id cell_seq_id : cell.seq_id) { + GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size()); + + const uint32_t seq_rs_idx = rs_idx[cell_seq_id]; + if (!has_rs_idx) { + rs_idx_cur = seq_rs_idx; + has_rs_idx = true; + } else if (rs_idx_cur != seq_rs_idx) { + GGML_ABORT("cannot write shared recurrent state with different rollback indices"); + } + } + } + } + + const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i); + if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) { + cell_ranges_data.emplace_back(cell_id, cell_id + 1); + } else { + cell_ranges_data.back().second++; + } + if (cell_range_begin == size) { cell_range_begin = i; } @@ -726,7 +790,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } - if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) { GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); } @@ -737,10 +801,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq } GGML_ASSERT(cell_count == cell_count_check); + cell_count_check = 0; + for (const auto & range : cell_ranges_data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + io.write(&cell_count, sizeof(cell_count)); state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + state_write_data(io, cell_ranges_data); } void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -762,6 +832,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i } throw std::runtime_error("failed to restore kv cache"); } + + if (n_rs_seq != 0) { + if (seq_id == -1) { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } else { + set_rs_idx(seq_id, 0); + } + } } void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { @@ -804,7 +882,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Write each range of cells of r_size_row length + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -825,7 +904,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Write each range of S tensor rows + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -852,9 +932,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Write GQA embedding size io.write(&n_embd_s, sizeof(n_embd_s)); - // For each row, we get the element values of each cell + // For each row, we get the element values of each logical cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; @@ -1163,5 +1242,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * mem->size) + src0; } diff --git a/examples/talk-llama/llama-memory-recurrent.h b/examples/talk-llama/llama-memory-recurrent.h index 47f01d73912..b13b7b748f5 100644 --- a/examples/talk-llama/llama-memory-recurrent.h +++ b/examples/talk-llama/llama-memory-recurrent.h @@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i { bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,14 @@ class llama_memory_recurrent : public llama_memory_i { uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + + // per-seq rollback index + std::vector rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4a157b91fdb..4ad1612e45b 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include #include @@ -20,6 +21,8 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; }; enum llama_memory_status { diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 4e65a45a50d..c645d0785ab 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); } if (n_tensors_moved > 0) { LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 7b3d6703c03..c476026d3e5 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -184,7 +184,7 @@ struct llama_model_loader { struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index e83056557bf..528e4c9c069 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -393,6 +393,8 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model->output); add_tensor(model->output_b); add_tensor(model->output_norm_enc); + add_tensor(model->output_s); + add_tensor(model->output_in_s); add_tensor(model->cls); add_tensor(model->cls_b); add_tensor(model->cls_out); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index ff30a2ae7a6..0d21b2a53c5 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1334,6 +1334,12 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_s && layer.ssm_beta) { layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // input scales if (!layer.wq_in_s && layer.wq) { @@ -1393,11 +1399,30 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_in_s && layer.ssm_beta) { layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_in_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_in_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } + // output scales + if (output && output->type == GGML_TYPE_NVFP4) { + // weight scale + if (!output_s) { + output_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "scale"), {1}, TENSOR_NOT_REQUIRED); + } + // input scale + if (!output_in_s) { + output_in_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "input_scale"), {1}, TENSOR_NOT_REQUIRED); + } } } - ml.done_getting_tensors(); + GGML_ASSERT(!(output && tok_embd && + strcmp(output->name, tok_embd->name) == 0 && + output->type == GGML_TYPE_NVFP4)); // populate tensors_by_name for (auto & [_, ctx_ptr] : ml.ctx_map) { for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { @@ -1934,6 +1959,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. + const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, @@ -1942,8 +1973,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, + cparams.n_rs_seq, nullptr); - } else if (llm_arch_is_hybrid(arch)) { + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -1958,6 +1990,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t il) { return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter_attn = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && !hparams.is_recurrent(il); + }; + filter_recr = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && hparams.is_recurrent(il); + }; } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { @@ -1975,6 +2015,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_s */ GGML_TYPE_F32, /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -1993,6 +2034,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2000,6 +2042,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_filter_cb filter = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2011,6 +2054,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } + if (mtp_on_hybrid_qwen35) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2026,7 +2074,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, - nullptr, + filter, reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2043,7 +2091,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, - nullptr, + filter, nullptr); } } @@ -2146,6 +2194,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d63c689185a..398a0aa725c 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -202,12 +202,16 @@ struct llama_layer_shortconv { }; struct llama_layer_nextn { - struct ggml_tensor * eh_proj = nullptr; - struct ggml_tensor * embed_tokens = nullptr; - struct ggml_tensor * enorm = nullptr; - struct ggml_tensor * hnorm = nullptr; - struct ggml_tensor * shared_head_head = nullptr; - struct ggml_tensor * shared_head_norm = nullptr; + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * eh_proj_s = nullptr; + struct ggml_tensor * eh_proj_in_s = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_head_s = nullptr; + struct ggml_tensor * shared_head_head_in_s = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; }; struct llama_layer { @@ -533,6 +537,11 @@ struct llama_model { struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + + // NVFP4 per-tensor scale2, input_scale for LM head + struct ggml_tensor * output_s = nullptr; + struct ggml_tensor * output_in_s = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index f43cf546ca0..a5cf148b268 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -530,6 +530,8 @@ struct llm_tokenizer_bpe : llm_tokenizer { struct llm_tokenizer_bpe_session { llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + virtual ~llm_tokenizer_bpe_session() = default; + static void append(const llama_token token_id, std::vector & output) { output.push_back(token_id); } @@ -567,7 +569,7 @@ struct llm_tokenizer_bpe_session { } } - void tokenize(const std::string & text, std::vector & output) { + virtual void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); @@ -1579,6 +1581,88 @@ struct llm_tokenizer_plamo2_session { const llm_tokenizer_plamo2 & tokenizer; }; +// reserved suffix (U+E000) that keeps DNA k-mers distinct from identical +// base-vocab BPE tokens (e.g. CCCCCC) in token_to_id; erased from id_to_token +// text at load +static const std::string dna_kmer_marker = "\xee\x80\x80"; + +struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { + llm_tokenizer_hybriddna_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector & output) override { + static const std::string open_tag = ""; + static const std::string close_tag = ""; + + const auto dna_begin_id = vocab.text_to_token(open_tag); + const auto dna_end_id = vocab.text_to_token(close_tag); + const auto dna_oov_id = vocab.text_to_token(""); + + // Fall back to plain BPE if the DNA pieces aren't in the vocab. + if (dna_begin_id == LLAMA_TOKEN_NULL || dna_end_id == LLAMA_TOKEN_NULL || dna_oov_id == LLAMA_TOKEN_NULL) { + llm_tokenizer_bpe_session::tokenize(text, output); + return; + } + + const size_t k = 6; + size_t pos = 0; + + while (pos < text.size()) { + const size_t start = text.find(open_tag, pos); + if (start == std::string::npos) { + if (pos < text.size()) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos), output); + } + break; + } + if (start > pos) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos, start - pos), output); + } + output.push_back(dna_begin_id); + + const size_t content_start = start + open_tag.size(); + const size_t end = text.find(close_tag, content_start); + const size_t content_end = (end == std::string::npos) ? text.size() : end; + + emit_dna_kmers(text.substr(content_start, content_end - content_start), k, dna_oov_id, output); + + if (end == std::string::npos) { + break; + } + output.push_back(dna_end_id); + pos = end + close_tag.size(); + } + } + +private: + void emit_dna_kmers(const std::string & raw, size_t k, llama_token oov_id, std::vector & output) { + std::string seq = raw; + for (char & c : seq) { + if (c >= 'a' && c <= 'z') { + c = char(c - 32); + } + } + + // k-mers carry the reserved marker suffix; a non-ACGT k-mer simply + // isn't in the vocab and falls back to + auto kmer_token = [&](const std::string & kmer) { + const auto tok = vocab.text_to_token(kmer + dna_kmer_marker); + return tok != LLAMA_TOKEN_NULL ? tok : oov_id; + }; + + size_t i = 0; + for (; i + k <= seq.size(); i += k) { + output.push_back(kmer_token(seq.substr(i, k))); + } + if (i < seq.size()) { + std::string kmer = seq.substr(i); + kmer.append(k - kmer.size(), 'A'); + output.push_back(kmer_token(kmer)); + } + } + + const llama_vocab & vocab; +}; + // // impl // @@ -1808,7 +1892,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks @@ -2266,6 +2350,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } GGML_ASSERT(id_to_token.size() == token_to_id.size()); + // hybriddna: the marker suffix kept k-mer ids distinct in token_to_id; erase + // it from id_to_token so the k-mers detokenize to the bare DNA sequence. The + // k-mers are the block right after , so only scan from there. + if (tokenizer_model == "hybriddna") { + const auto idx = token_to_id.find(""); + if (idx != token_to_id.end()) { + auto it = id_to_token.begin() + idx->second + 1; + for (; it != id_to_token.end(); ++it) { + std::string & text = it->text; + if (text.size() > dna_kmer_marker.size() + && text.compare(text.size() - dna_kmer_marker.size(), dna_kmer_marker.size(), dna_kmer_marker) == 0) { + text.erase(text.size() - dna_kmer_marker.size()); + } + } + } + } + init_tokenizer(type); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' @@ -3144,11 +3245,19 @@ std::vector llama_vocab::impl::tokenize( } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get())); // it calls some other methods that are not exist in llm_tokenizer, // here just cast it to bpe tokenizer object + const llm_tokenizer_bpe * tok_bpe = static_cast(tokenizer.get()); + + std::unique_ptr session; + if (vocab.get_tokenizer_model() == "hybriddna") { + session = std::make_unique(vocab, *tok_bpe); + } else { + session = std::make_unique(vocab, *tok_bpe); + } + if (add_special) { - session.append_bos(output); + session->append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -3161,15 +3270,15 @@ std::vector llama_vocab::impl::tokenize( #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif - session.tokenize(text, output); + session->tokenize(text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - session.append(fragment.token, output); + session->append(fragment.token, output); } } if (add_special) { - session.append_eos(output); - session.check_double_bos_eos(output); + session->append_eos(output); + session->check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 308e8ba9dbd..e8374c53b70 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -198,6 +198,11 @@ extern "C" { LLAMA_SPLIT_MODE_TENSOR = 3, }; + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -333,9 +338,11 @@ extern "C" { uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings @@ -530,6 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); @@ -866,7 +874,8 @@ extern "C" { // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 -// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load) +// Keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load). +// Getting the state for a seq_id with this flag invalidates all prior states gotten for that seq_id with this flag. #define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 typedef uint32_t llama_state_seq_flags; diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 602e3176afd..a7c77ee5d28 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -277,7 +277,7 @@ llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 136ff702957..bec7136521c 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -160,7 +160,7 @@ llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 70e86d41130..d086c4717ff 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -148,7 +148,7 @@ llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index d8653a44639..27deadffeb7 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -171,7 +171,7 @@ llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 79aa8c90899..9bd04127b25 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -193,7 +193,7 @@ llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 4e55290e4e5..4d26081cd5d 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -146,7 +146,7 @@ llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 030dd4f42a4..fe1ae10864b 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -171,7 +171,7 @@ llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index e7fe3d5b45a..2f0d44a6259 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -210,7 +210,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index b600fb0c954..30b0f3d07d0 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -142,7 +142,7 @@ llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 8510b9e29f8..4bceaefd63b 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -181,7 +181,7 @@ llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_with_img_logits", -1); // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index e898eff7939..6766fa71c15 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -151,7 +151,7 @@ llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index e9e85d96713..274dd3342a7 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -143,7 +143,7 @@ llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 79236121bd5..2e231bb3f93 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -150,7 +150,7 @@ llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/cohere2.cpp b/examples/talk-llama/models/cohere2.cpp index 12edbae1094..a514cf88fc6 100644 --- a/examples/talk-llama/models/cohere2.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -146,7 +146,7 @@ llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index decb89f547b..adf7fcaa20f 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -131,7 +131,7 @@ llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index bce6b04bcf9..af71c775365 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -145,7 +145,7 @@ llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 9f1a959c32c..567e3535276 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -181,7 +181,7 @@ llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index c7946059662..f52ec9518b6 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -185,7 +185,7 @@ llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 6bc989c9509..4f4c7cac7a8 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -1,6 +1,7 @@ #include "models.h" #include "llama-impl.h" +#include "llama-memory-recurrent.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -397,7 +398,9 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -443,3 +446,162 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + const auto mem_size = mctx_cur->get_size(); + + const int64_t n_seqs = ubatch.n_seqs; + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + + const size_t row_size = ggml_row_size(conv_states_all->type, row_count); + + if (cparams.n_rs_seq == 0) { + const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0]; + const int64_t s_slot = 0; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + cb(conv_state_last, "conv_state_last", il); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, conv_states_all, + row_count, n_seqs, conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + cb(conv_state_update, "conv_state_update", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } else { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are + // inside the same ubatch. currently with `split_equal()` this is not correct + + const int64_t K = (int64_t) cparams.n_rs_seq + 1; + + for (int64_t t = 1; t <= K; ++t) { + const int64_t s_idx = std::max(0, conv_input->ne[0] - conv_states->ne[0] - K + t); + const int64_t s_slot = K - t; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, + conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + const bool keep = cparams.n_rs_seq > 0; + + if (!keep) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = cparams.n_rs_seq + 1; + + // TODO: remove pad + simplify + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); + ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); + + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + if (n_seq_tokens > 1) { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + } else { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il); + } + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + for (int64_t k_i = 0; k_i < K; ++k_i) { + const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); + ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, + S_v, S_v, H_v, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * S_v), + ggml_row_size(gdn_out->type, S_v * S_v * H_v), + ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, + hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + ((size_t) cache_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + + return output; +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 93cbcf9d931..435d27281c6 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -183,7 +183,7 @@ llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 60a3f0ec285..12ac6f1ce88 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -128,7 +128,7 @@ llama_model_dream::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 2bd01a2c512..8d9ff138676 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -124,7 +124,7 @@ llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index fa989fe92cd..9b39c605e35 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -155,7 +155,7 @@ llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 54bb3ca86b3..76d91982fc5 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -237,7 +237,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 75d5f60631c..c7e9960d718 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -127,7 +127,7 @@ llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 5506e76424d..499e22dde81 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -163,7 +163,7 @@ llama_model_exaone4::graph::graph(const llama_model & model, const llm_gra res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index d353befdb8e..94b65a3c7c9 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -200,7 +200,7 @@ llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 75f2cfef560..ad546ef2db5 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -152,7 +152,7 @@ llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 06731670007..1519682fdf6 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -130,7 +130,7 @@ llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp index 6255bf740fc..ae3f9ffb530 100644 --- a/examples/talk-llama/models/gemma2.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -163,7 +163,7 @@ llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // final logit soft-capping cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index ee510fe38b0..63a2b380e71 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -207,7 +207,7 @@ llama_model_gemma3::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3n.cpp b/examples/talk-llama/models/gemma3n.cpp index 881499b0ca7..6ec3a006081 100644 --- a/examples/talk-llama/models/gemma3n.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -296,7 +296,7 @@ llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); { // final logit soft-capping diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index f45ae4cad59..4f9d8b18bc7 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -380,7 +380,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 45886b51ac1..27654b8cba3 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -275,7 +275,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index d6ef76e26d6..7c242fed298 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -185,7 +185,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index ba49c31b56b..e2dcc8b1521 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -138,7 +138,7 @@ llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 33ebe2d8800..443e35addf2 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -209,7 +209,7 @@ llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 12e4790ae24..27f6706ea10 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -186,7 +186,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits if (hparams.f_logit_scale) { diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 5e7c7b68181..cda4aa231fa 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -145,7 +145,7 @@ llama_model_granite::graph::graph( res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 0bc49d00206..7c46ec1c0f2 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -206,7 +206,7 @@ llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index feef815165b..1cab75adc7f 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -184,7 +184,7 @@ llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 44af42412f7..deb3c9671f3 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -179,7 +179,7 @@ llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-vl.cpp b/examples/talk-llama/models/hunyuan-vl.cpp index 5fb9154bec0..da9bb74de7e 100644 --- a/examples/talk-llama/models/hunyuan-vl.cpp +++ b/examples/talk-llama/models/hunyuan-vl.cpp @@ -181,7 +181,7 @@ llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index f0c5580a6f4..f9ee37a24b6 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -129,7 +129,7 @@ llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index a6451dca095..2ba162605f1 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -123,7 +123,7 @@ llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index ad59b953e8d..8966131441c 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -152,7 +152,7 @@ llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index e1b8d137e38..84ea63c3136 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -189,7 +189,7 @@ llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index df6a8028736..29081344b24 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -262,7 +262,7 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index b60f67f6c4b..9722dde9f17 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -153,7 +153,7 @@ llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index fa21c5fe32c..58b2c466e17 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -147,7 +147,7 @@ llama_model_llada::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 8ddb5936820..cef66d054b0 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -235,7 +235,7 @@ llama_model_llama::graph::graph(const llama_model & model, const llm_grap if constexpr (!embed) { // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 899611d53f6..0ff5376d571 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -260,7 +260,7 @@ llama_model_llama4::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 3dbd82fd362..84cfe399027 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -141,7 +141,7 @@ llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index b7708d7fdd1..887a1fa509a 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -128,7 +128,7 @@ llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp index 71996616611..d0295ec116f 100644 --- a/examples/talk-llama/models/mimo2.cpp +++ b/examples/talk-llama/models/mimo2.cpp @@ -231,7 +231,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index ff5eb6ffa5f..1ffc54fa7c6 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -251,7 +251,7 @@ llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_pa cb(cur, "lmhead_scaling", -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 0dee8934692..22e291d73a3 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -158,7 +158,7 @@ llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 708da49af1f..4e6ebef82cb 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -222,7 +222,7 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 6d5f18a8e20..7e551eb965b 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator + // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) std::pair build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -65,6 +65,29 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { @@ -1739,6 +1762,10 @@ struct llama_model_qwen35 : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; @@ -1781,6 +1808,10 @@ struct llama_model_qwen35moe : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index cfc60e8de29..0229d20ed36 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -161,7 +161,7 @@ llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 865461f61db..a82f9c170b4 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -174,7 +174,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 0c72ed297aa..5d4a3b5c69e 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -140,7 +140,7 @@ llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index 161035e72bc..cfcf17bcb03 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -133,7 +133,7 @@ llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 9633f269965..7cc262f5504 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -198,7 +198,7 @@ llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 4bb9013054c..7976ae44a51 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -164,7 +164,7 @@ llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 13a590ce646..15b6c8c1205 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -160,7 +160,7 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index b4128e116e7..9f76350fd4d 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -162,7 +162,7 @@ llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index 7ace0a5139d..bcb4bbba4b1 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -132,7 +132,7 @@ llama_model_orion::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 1c0eadefa98..d39220bd778 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -98,7 +98,7 @@ llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp index 41b7e2ac23e..7593f879b24 100644 --- a/examples/talk-llama/models/pangu-embed.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -148,7 +148,7 @@ llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index a333602c72d..8f3ed5f7b7d 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -130,7 +130,7 @@ llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 0a65e91fefa..f8a4a4d5aa5 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -179,7 +179,7 @@ llama_model_phi3::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cb(cur, "result_output_no_bias", -1); diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 4c16c20a0d4..c7ed1211c31 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -127,7 +127,7 @@ llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index 29c8702606a..b713889fe72 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -185,7 +185,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); // Explicitly mark as output tensor to ensure proper backend assignment diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 849f1579e63..29f3e803d68 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -186,7 +186,7 @@ llama_model_plamo3::graph::graph(const llama_model & model, const llm_grap cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index 57f5995103b..ce050919e6a 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -204,7 +204,7 @@ llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index cdc076cdf77..00467dbad7d 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -131,7 +131,7 @@ llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 6320458a13b..a5147460bae 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -141,7 +141,7 @@ llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 7587c802c68..7cb03859deb 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -184,7 +184,7 @@ llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 1a40fa89be4..d79db682cd4 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -134,7 +134,7 @@ llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index fa656c84ea0..41b97fed956 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -147,7 +147,7 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index f276be61ba8..04ecc18fcdc 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -12,16 +12,22 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -29,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -43,50 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -111,7 +157,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -128,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -160,6 +208,13 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -167,7 +222,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -297,8 +352,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -328,41 +381,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -413,7 +439,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -424,18 +450,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -471,3 +486,151 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // hparams.n_layer includes both main model layers and MTP layers. The MTP + // layer is stored immediately after the main layers in model.layers[]. + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index cf05dc9d61c..dc24f6ed537 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -15,16 +15,22 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -32,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -46,60 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -124,7 +180,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -141,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -173,6 +231,13 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -180,7 +245,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -310,8 +375,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -341,41 +404,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -426,7 +462,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -437,18 +473,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -525,3 +550,183 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). + ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 4440b83aa45..a4f8e1379c9 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -168,7 +168,7 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index cb1b4814caf..1d873427db5 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -176,7 +176,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 7871f8f7952..5defd893944 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -163,7 +163,7 @@ llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp index b99143c8908..5b77df57122 100644 --- a/examples/talk-llama/models/qwen3vlmoe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -180,7 +180,7 @@ llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index f14f10917ff..bf3949a9092 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -150,7 +150,7 @@ llama_model_refact::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index 325ee73ba5c..ca8e009615e 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -167,7 +167,7 @@ llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 2944711acec..ba2a9dfa0db 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -176,7 +176,7 @@ llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index 6f7d1f5722f..566b8cdcb54 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -158,7 +158,7 @@ llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index b205e3935e1..7574b252621 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -202,7 +202,7 @@ llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 83e114740b6..806cba574be 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -141,7 +141,7 @@ llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 3214e7cbad3..4231cccc666 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -178,7 +178,7 @@ llama_model_smallthinker::graph::graph(const llama_model & model, const ll res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 7adaf34c534..90e7d473eaf 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -143,7 +143,7 @@ llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 8f613e55947..4da7f7aefcf 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -163,7 +163,7 @@ llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index 58cf0ac0edc..e131af058bc 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -135,7 +135,7 @@ llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index 45dae0602d4..9c207c02885 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -148,7 +148,7 @@ llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index c4789752d21..3b68e68707a 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -261,7 +261,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 27a0711ba41..73e32741406 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -265,7 +265,7 @@ llama_model_t5::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index a873e5d2e8f..214fed99bad 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -253,7 +253,7 @@ llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_ LLM_NORM, -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index e4d111e622a..d6d1c7a2e5d 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -126,7 +126,7 @@ llama_model_xverse::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index dc13e53f09f..b02ecdc930f 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -605,6 +605,136 @@ static std::vector unicode_regex_split_custom_qwen2(const std::string & return bpe_offsets; } +// Qwen3.5 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +// Compared to Qwen2, letter-runs also consume Unicode combining marks (\p{M}): [\p{L}\p{M}]+ instead of \p{L}+ +static std::vector unicode_regex_split_custom_qwen35(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?[\p{L}\p{M}]+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || flags.is_accent_mark || _get_flags(pos + 1).is_accent_mark || _get_flags(pos+1).is_letter) { + pos++; + while (_get_flags(pos).is_letter || _get_flags(pos).is_accent_mark) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{M}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + template static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { using BidirIt = typename std::basic_string::const_iterator; @@ -929,6 +1059,9 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen35(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4e65cd68b4e..03020888f97 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 11) +set(GGML_VERSION_MINOR 12) set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") @@ -249,6 +249,7 @@ option(GGML_SYCL "ggml: use SYCL" option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON) +option(GGML_SYCL_SUPPORT_LEVEL_ZERO "ggml: use Level Zero API in SYCL backend" ON) option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 91c9d5cd343..23a3066f56d 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -6,6 +6,7 @@ include(CMakeFindDependencyMacro) find_dependency(Threads) if (NOT GGML_SHARED_LIB) + set(GGML_BASE_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_OPTIONS "") @@ -20,7 +21,15 @@ if (NOT GGML_SHARED_LIB) if (GGML_OPENMP_ENABLED) find_dependency(OpenMP) - list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + set(GGML_OPENMP_INTERFACE_LINK_LIBRARIES "") + if (TARGET OpenMP::OpenMP_C) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C) + endif() + if (TARGET OpenMP::OpenMP_CXX) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_CXX) + endif() + list(APPEND GGML_BASE_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) endif() if (GGML_CPU_HBM) @@ -122,7 +131,8 @@ if(NOT TARGET ggml::ggml) add_library(ggml::ggml-base UNKNOWN IMPORTED) set_target_properties(ggml::ggml-base PROPERTIES - IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}" + INTERFACE_LINK_LIBRARIES "${GGML_BASE_INTERFACE_LINK_LIBRARIES}") set(_ggml_all_targets "") if (NOT GGML_BACKEND_DL) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 78aa059dde3..a7926a21a9a 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -76,6 +76,7 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i // Utils // Create a buffer and allocate all the tensors in a ggml_context // ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft +// ggml_backend_alloc_ctx_tensors_from_buft returns NULL on failure or if all tensors in ctx are already allocated or zero-sized GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d9985..f6725265504 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1189,8 +1189,8 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // a - x - // b - dy + // a - dy + // b - x GGML_API struct ggml_tensor * ggml_silu_back( struct ggml_context * ctx, struct ggml_tensor * a, @@ -2541,6 +2541,11 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): + // K == 1: output carries the final state only. + // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 3e48860bfc8..c26c3f1470d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -222,6 +222,23 @@ if (GGML_SCHED_NO_REALLOC) target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) endif() +if (GGML_OPENMP) + find_package(OpenMP) + if (OpenMP_FOUND) + set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") + else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") + message(WARNING "OpenMP not found") + endif() +else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") +endif() + +if (GGML_OPENMP_ENABLED) + target_compile_definitions(ggml-base PRIVATE GGML_USE_OPENMP) + target_link_libraries(ggml-base PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) +endif() + add_library(ggml ggml-backend-dl.cpp ggml-backend-reg.cpp) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index a4b01ccf8a1..3bda9abbe03 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -150,7 +150,7 @@ static void ggml_dyn_tallocr_insert_block(struct tallocr_chunk * chunk, size_t o static void ggml_dyn_tallocr_remove_block(struct tallocr_chunk * chunk, int idx) { // shift all elements after idx by 1 to the left, overwriting the element at idx - for (int i = idx; i < chunk->n_free_blocks; i++) { + for (int i = idx; i < chunk->n_free_blocks - 1; i++) { chunk->free_blocks[i] = chunk->free_blocks[i+1]; } chunk->n_free_blocks--; diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c0ffd9a048b..5f9ae9c1bc5 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, + // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; @@ -1273,6 +1275,9 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg for (size_t j = 0; j < n_bufs; j++) { ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } const size_t simple_offset = i_start * chunk_size_j; ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1380,6 +1385,9 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co for (size_t j = 0; j < n_bufs; j++){ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } const size_t simple_offset = i_start * chunk_size_j; ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1443,6 +1451,7 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac buf_ctx->buf_configs.reserve(n_simple_bufts); for (size_t i = 0; i < n_simple_bufts; i++) { ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); + GGML_ASSERT(simple_buf != nullptr); max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); } @@ -1472,8 +1481,27 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc t->data = (void *) 0x2000000000000000; // FIXME } for (size_t i = 0; i < n_simple_bufts; i++) { - meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft( - meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i)); + ggml_context * ctx = meta_buf_ctx->buf_configs[i].ctx; + ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i); + + // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL. + // For those edge cases, allocate a dummy buffer instead. + bool any_nonzero_slice = false; + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + if (ggml_nelements(t) != 0) { + any_nonzero_slice = true; + break; + } + } + if (any_nonzero_slice) { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft); + } else { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_buft_alloc_buffer(simple_buft, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf_ctx->buf_configs[i].buf; + } + } + GGML_ASSERT(meta_buf_ctx->buf_configs[i].buf != nullptr); meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); } return meta_buf; @@ -1603,6 +1631,9 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1644,6 +1675,9 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -2140,4 +2174,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; return backend_ctx->backend_configs[index].backend; } - diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 4e36909f45e..87615921c09 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -306,7 +306,7 @@ void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_ GGML_ASSERT(tensor); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + if (n_copies <= 1 || backend->iface.get_tensor_2d_async == NULL) { for (size_t i = 0; i < n_copies; i++) { ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); } @@ -317,7 +317,7 @@ void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_ } GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); } @@ -379,7 +379,7 @@ void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf != NULL && "tensor buffer not set"); - if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + if (n_copies <= 1 || buf->iface.get_tensor_2d == NULL) { for (size_t i = 0; i < n_copies; i++) { ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); } diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 869c7b238bf..8c735a045b3 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -72,17 +72,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() endif() - if (GGML_OPENMP) - find_package(OpenMP) - if (OpenMP_FOUND) - set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") - target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) - - target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) - else() - set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") - message(WARNING "OpenMP not found") - endif() + if (GGML_OPENMP_ENABLED) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) + target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) endif() if (GGML_LLAMAFILE) @@ -450,12 +442,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/repack.cpp ) if (GGML_CPU_RISCV64_SPACEMIT) + include(ggml-cpu/cmake/FindSMTIME.cmake) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) list(APPEND GGML_CPU_SOURCES ggml-cpu/spacemit/ime.cpp ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/spine_mem_pool.cpp + ggml-cpu/spacemit/spine_mem_pool.h + ggml-cpu/spacemit/repack.cpp + ggml-cpu/spacemit/repack.h + ggml-cpu/spacemit/ime_env.cpp + ggml-cpu/spacemit/ime_env.h ggml-cpu/spacemit/ime1_kernels.cpp + ggml-cpu/spacemit/ime2_kernels.cpp ggml-cpu/spacemit/ime_kernels.h + ggml-cpu/spacemit/rvv_kernels.cpp + ggml-cpu/spacemit/rvv_kernels.h ) endif() if(NOT GGML_CPU_ALL_VARIANTS) @@ -485,6 +487,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_RV_ZBA) + string(APPEND MARCH_STR "_zba") + endif() if (GGML_CPU_RISCV64_SPACEMIT) # `xsmtvdotii' is only required for GCC >= 15. if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND diff --git a/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake new file mode 100644 index 00000000000..c8a4d4b4ec9 --- /dev/null +++ b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake @@ -0,0 +1,32 @@ +include(CheckCSourceRuns) + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)" AND GGML_CPU_RISCV64_SPACEMIT) + set(SMT_MARCH_STR "-march=rv64gcv_zfh_zvfh_zba_zicbop") + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND SMT_MARCH_STR "_xsmtvdotii") + endif() + set(CMAKE_REQUIRED_FLAGS "${SMT_MARCH_STR}") + + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vfwmadot v2, v0, v1, fp16\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFWMADOT_FP16) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot1 v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOTN) + check_c_source_compiles("int main() {__asm__ volatile(\"vpack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK) + check_c_source_compiles("int main() {__asm__ volatile(\"vnspack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + unset(CMAKE_REQUIRED_FLAGS) + + list(APPEND RISCV64_SPACEMIT_IME_SPEC "") + if (SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + set(RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME1") + endif() + + if (SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4 AND SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK AND SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + list(APPEND RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME2") + endif() + + message("RISCV64_SPACEMIT_IME_SPEC: ${RISCV64_SPACEMIT_IME_SPEC}") +endif() diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b7acafdaa8..cd5c61a8187 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -50,6 +50,10 @@ #include "llamafile/sgemm.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + // Note: once we move threading into a separate C++ file // will use std::hardware_destructive_interference_size instead of hardcoding it here // and we'll use C++ attribute syntax. @@ -2939,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - cur = S_v * sizeof(float) * n_tasks; + const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; } break; case GGML_OP_COUNT: { @@ -3011,7 +3017,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const struct ggml_cgraph * cgraph = tp->cgraph; const struct ggml_cplan * cplan = tp->cplan; +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(state->ith); +#else set_numa_thread_affinity(state->ith); +#endif struct ggml_compute_params params = { /*.ith =*/ state->ith, @@ -3068,6 +3078,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_barrier(state->threadpool); +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(state->ith); +#endif + return 0; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150ce..7485ba4fc86 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // scratch layout per thread: [delta(S_v)] - const int64_t scratch_per_thread = S_v; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int64_t K = src_state->ne[1]; + GGML_ASSERT(K >= 1); + // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; - float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; // output layout: [attn_scores | new_states] - // attn_scores: S_v * H * n_tokens * n_seqs floats - // new_states: S_v * S_v * H * n_seqs floats - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int64_t shift = n_tokens - K; + const float * state_in_base = (const float *)src_state->data; //const int64_t rq1 = nev1 / neq1; @@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; - float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; - // copy input state into output buffer and operate in-place - const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + // copy input state into the working buffer and operate in-place + // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); // attn output pointer for first token of this (head, seq) @@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (K > 1) { + const int64_t target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state_o = state_out_base + target_slot * state_size_per_snap + + (iv3 * H + iv1) * S_v * S_v; + memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float)); + } + } } } } diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 91fe1925eaa..9563ea3e4bd 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -3,19 +3,32 @@ #include "ime.h" +#include "binary-ops.h" +#include "common.h" #include "ggml-backend-impl.h" #include "ggml-common.h" #include "ggml-cpu.h" +#include "ime_env.h" #include "ime_kernels.h" +#include "ops.h" +#include "repack.h" +#include "rvv_kernels.h" +#include "spine_mem_pool.h" #include "traits.h" +#include "vec.h" + +#include +#include +#include #include +#include #include +#include #include #include // for GGML_ASSERT #include #include - // clang-format off #if defined(__riscv) @@ -25,13 +38,17 @@ #include #endif -#if !defined(__riscv_zfh) -#error "riscv zfh extension not enabled" +#if !defined(__riscv_zfh) || !defined(__riscv_zvfh) +#error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1" #endif -#if defined(RISCV64_SPACEMIT_IME1) +#if !defined(__riscv_zba) +#error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2) #else -#error "RISCV64_SPACEMIT_IME1 not defined" +#error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined" #endif #else @@ -46,382 +63,490 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif -#if defined(RISCV64_SPACEMIT_IME1) -#define QGEMM_STRIDEN_THREAD_ALIGN 16 -#else -#define QGEMM_STRIDEN_THREAD_ALIGN 32 -#endif - // clang-format on -struct qnbitgemm_spacemit_ime_args { - const float * a_ptr = nullptr; - size_t lda = 0; - const std::byte * packed_quant_b_data = nullptr; - const float * quant_b_scale = nullptr; - const void * quant_b_zp = nullptr; - const float * quant_b_blksum = nullptr; - const float * bias = nullptr; - float * c_ptr = nullptr; - size_t ldc = 0; -}; - -constexpr size_t div_round_up(size_t up, size_t down) { - return (up + down - 1) / down; -} - -constexpr size_t q8_blk_size(size_t blk_len) { - const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); - // Currently, the strictest alignment requirement of a block is for a float. - // Ensure contiguous blocks are suitably aligned. - assert(blk_size % alignof(float) == 0); - return blk_size; +extern "C" { +extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value); +extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value); } namespace ggml::cpu::riscv64_spacemit { -const int num_ai_cores = std::thread::hardware_concurrency() / 2; - -} // namespace ggml::cpu::riscv64_spacemit +struct TLSContext { + int cpu_id{ -1 }; + cpu_set_t cpuset; + void * tcm_buffer{ nullptr }; + size_t tcm_buffer_size{ 0 }; +}; -static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, - const size_t gemm_k, - const qnbitgemm_spacemit_ime_args * gemm_args, - void * const per_gemm_ws, - const size_t m_start, - const size_t m_count, - const size_t n_start, - const size_t n_count) { - constexpr size_t scale_stride = sizeof(uint16_t); - constexpr size_t blk_bitwidth = 4; +thread_local TLSContext tls_context; + +template constexpr size_t get_repacked_block_type_size() { + if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(block_q8_0); + } else if constexpr (std::is_same_v) { + return sizeof(block_q4_0) * INTER_SIZE / QK4_0; + } else if constexpr (std::is_same_v || std::is_same_v) { + return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1; + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q2_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q3_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_mxfp4<1>); + } else if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_1<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_0<1>); + } else { + assert(false); + return 0; + } +} - const size_t k_blks = div_round_up(gemm_k, blk_len); +template constexpr bool block_type_has_zp() { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return false; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return true; + } else { + assert(false); + return false; + } +} - const size_t lda = k_blks * q8_blk_size(blk_len); - const size_t ldc = gemm_args->ldc; - const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); - const std::byte * quant_a_ptr = static_cast(per_gemm_ws) + m_start * lda; +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0; +}; - const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; - const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); - const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; +template class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; + size = GGML_PAD(size, sizeof(int64_t)); - size_t count_n = 0; - const size_t compute_block_count_n = m_count == 1 ? n_count : 16; - for (size_t n = 0; n < n_count; n += count_n) { - count_n = std::min(n_count - n, compute_block_count_n); + return true; + } + case GGML_OP_MUL_MAT_ID: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - const std::byte * a_row = quant_a_ptr; - const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; - const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; - float * c_blk = c_ptr + n; + size = GGML_PAD(size, sizeof(int64_t)); - int32_t rows_remaining = m_count; + const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert + const int64_t ne12 = op->src[1]->ne[2]; // n_tokens - while (rows_remaining > 0) { - const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( - blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, - scale_stride); + const size_t sizeof_mmid_row_mapping = sizeof(int64_t); + size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t); - c_blk += rows_handled * ldc; - a_row += rows_handled * lda; + size = GGML_PAD(size, sizeof(int64_t)); - rows_remaining -= rows_handled; + return true; + } + default: + // GGML_ABORT("fatal error"); + break; } + return false; } -} -template constexpr int QK_0() { - if constexpr (K == 4) { - return QK4_0; - } - if constexpr (K == 8) { - return QK8_0; + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT"); + return false; + } + break; + case GGML_OP_MUL_MAT_ID: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID"); + return false; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; } - return -1; -} -template struct block { - ggml_half d[N]; // deltas for N qK_0 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks -}; + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; -template struct block_with_zp { - ggml_half d[N]; // deltas for N qK_1 blocks - uint8_t zp[N]; // zero points for N qK_1 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks -}; + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; -// control size -static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); -static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), - "wrong block_with_zp<4,16> size/padding"); -static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + GGML_TENSOR_BINARY_OP_LOCALS -using block_q4_0x16 = block<4, 16>; -using block_q4_1x16 = block_with_zp<4, 16>; -using block_q8_0x16 = block<8, 16>; + int ith = params->ith; + int nth = params->nth; -static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + [[maybe_unused]] const enum ggml_type type = src0->type; - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } + void * w_data = (void *) src0->data; + const float * feature = (const float *) src1->data; + float * output = (float *) dst->data; - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + const int64_t gemm_m = ne11 * ne12 * ne13; + const int64_t gemm_k = ne10; + const int64_t gemm_n = ne01; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::quantize_a_row_def quantize_a_4row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + bool set_kernel_impl = false; + + int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + set_kernel_impl = true; + } } - } +#endif - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); } - } - return out; -} + const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len); + const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len); -static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { - block_q4_1x16 out; - GGML_ASSERT(QK4_1 / blck_size_interleave == 2); - - for (int i = 0; i < 16; i++) { - float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); - float mid = -std::nearbyintf(m / d); - mid = std::min(15.0f, std::max(0.0f, mid)); - out.d[i] = GGML_FP32_TO_FP16(d); - out.zp[i] = static_cast(mid); - } + const int64_t row_stride_a = a_k_blks * block_stride_a; + const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t)); - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + if (ith == 0 && params->wsize < gemm_workspace_size) { + GGML_ABORT("wsize less than gemm_workspace_size"); } - } - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); - } - } + uintptr_t ws_ptr = reinterpret_cast(params->wdata); - return out; -} + void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 16); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); - constexpr int nrows_interleaved = 16; + constexpr int64_t row_align = 4; + const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align); - block_q4_0x16 * dst = (block_q4_0x16 *) t->data; - const block_q4_0 * src = (const block_q4_0 *) data; - block_q4_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; + const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const int64_t per_mb_rows_wsize = row_align * row_stride_a; + const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + const int64_t barrier_idx = static_cast(ith / 2); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { - return -1; - } + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + if (gemm_m == 1) { + int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth); + int a_blk_start = ith * task_per_thread; + int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks); + if (a_blk_start < a_blk_end) { + quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len, + quant_a_buffer + a_blk_start * block_stride_a); + } + } else { + int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth); + int m_row_blk_start = ith * task_per_thread; + int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks); + for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) { + int m_idx = m_row_blk * row_align; + int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx); + + if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + } else { + while (rows_tobe_handled) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + rows_tobe_handled -= 1; + m_idx += 1; + } + } } - *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); } - src += nrows_interleaved * nblocks; - } - return 0; - GGML_UNUSED(data_size); -} + ggml_barrier(params->threadpool); -static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_1); - GGML_ASSERT(interleave_block == 16); + const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16; + const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth); - constexpr int nrows_interleaved = 16; + int64_t gemm_n_stride = gemm_n; + if (max_gemm_n_stride < gemm_n) { + gemm_n_stride = + std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_1 * src = (const block_q4_1 *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_1; + if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) { + for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) { + uint8_t * b_col = reinterpret_cast(w_data); + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + int64_t m_row_real = std::min(gemm_m - m_start, row_align); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { - return -1; - } + spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a, + m_row_real * row_stride_a); - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS); + + uint8_t * a_row_ptr = (uint8_t *) tcm_buffer; + float * c_blk = output + m_start * gemm_n + ni; + + int32_t rows_remaining = m_row_real; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining, + n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_ptr += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + } } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; + } else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) { + uint8_t * a_row = quant_a_buffer; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + gemm_workspace_size; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - GGML_UNUSED(data_size); -} + int64_t ni = ith * NB_COLS; + int64_t nb_real = std::min(gemm_n - ni, NB_COLS); -static inline void get_scale_min_k4(int j, - const uint8_t * GGML_RESTRICT q, - uint8_t * GGML_RESTRICT d, - uint8_t * GGML_RESTRICT m) { - if (j < 4) { - *d = q[j] & 63; - *m = q[j + 4] & 63; - } else { - *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); - } -} + if (ith % 2 == 0 && nb_real > 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + } -static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 16); - GGML_ASSERT(QK_K / QK4_1 == 8); + spine_barrier_wait(cur_barrier); - constexpr int nrows_interleaved = 16; + if (ith % 2 != 0 && nb_real > 0) { + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_K * src = (const block_q4_K *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + for (; ni < gemm_n; ni += NB_COLS * nth) { + int64_t rows_remaining = gemm_m; + float * c_blk = output + ni; + auto * a_row_cur = a_row; - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { - return -1; - } + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int j = 0; j < 8; j++) { - for (int i = 0; i < nrows_interleaved; i++) { - uint8_t sc, m; - const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - const float min = - GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); - get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); - const float d1 = d * sc; - const float m1 = min * m; - - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); - // src -> [b0, b32] [b1, b33] ... [b31, b63] - // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] - const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; - if (j % 2 == 0) { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); - } - } else { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); - } - } + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining, + nb_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_cur += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS * nth; + if (next_ni < gemm_n) { + nb_real = std::min(gemm_n - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + next_ni * row_stride_b, + nb_real * row_stride_b); } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); } - } - src += nrows_interleaved * nblocks; - } - return 0; + } else { + const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride); - GGML_UNUSED(data_size); -} + int64_t task_count = task_count_m * task_count_n; + int64_t task_per_thread = (task_count + nth - 1) / nth; + int64_t start = ith * task_per_thread; + int64_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int64_t compute_idx = start; compute_idx < end; compute_idx++) { + const auto tid_n = compute_idx / task_count_m; + const auto tid_m = compute_idx % task_count_m; -namespace ggml::cpu::riscv64_spacemit { + const int64_t m_start = tid_m * gemm_m_stride; + const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride); -template -int repack(struct ggml_tensor *, const void *, size_t); + const int64_t n_start = tid_n * gemm_n_stride; + const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride); -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); -} + const int64_t n_blk = m_count == 1 ? n_count : NB_COLS; -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); -} + uint8_t * b_col = reinterpret_cast(w_data) + n_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); -} + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(n_count - ni, n_blk); -class tensor_traits_base : public ggml::cpu::tensor_traits { - public: - virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; -}; + uint8_t * a_row = quant_a_buffer + m_start * row_stride_a; -template class tensor_traits : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; - size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); - return true; - default: - // GGML_ABORT("fatal error"); - break; - } - return false; - } + float * c_blk = output + m_start * gemm_n + n_start + ni; - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->type == GGML_TYPE_Q4_0 || // - op->src[0]->type == GGML_TYPE_Q4_1 || // - op->src[0]->type == GGML_TYPE_Q4_K) { - forward_mul_mat_q4(params, op); - return true; + int64_t rows_remaining = m_count; + + uint8_t * b_col_cur = b_col; + uint8_t * b_col_zp_cur = b_col_zp; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk, + rows_remaining, n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } } - default: - // GGML_ABORT("fatal error"); - break; + } } - return false; } - void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; + const ggml_tensor * src0 = op->src[0]; const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * ids = op->src[2]; ggml_tensor * dst = op; GGML_TENSOR_BINARY_OP_LOCALS @@ -429,133 +554,381 @@ template class tensor_ int ith = params->ith; int nth = params->nth; - [[maybe_unused]] const enum ggml_type type = src0->type; + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2; + bool set_kernel_impl = false; + size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5; + set_kernel_impl = true; + } + } +#endif - void * w_data = (void *) src0->data; - const float * feature = (const float *) src1->data; - float * output = (float *) dst->data; +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); + } - const size_t batch_feature = ne12 * ne13; - [[maybe_unused]] const size_t batch_weight = ne02 * ne03; - const size_t gemm_m = ne11; - const size_t gemm_k = ne10; - const size_t gemm_n = ne01; + const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len); + const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len); - GGML_ASSERT(batch_weight == 1); + const size_t nbw1 = a_k_blks * block_stride_a; + const size_t nbw2 = ne11 * nbw1; + const size_t nbw3 = nbw2 * ne12; + const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t)); - const size_t block_count_k = div_round_up(gemm_k, QK4_0); - const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); - const size_t per_gemm_workspace_stride = - div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); - const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; - const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; + const uintptr_t ws_ptr = reinterpret_cast(params->wdata); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); - if (ith == 0 && params->wsize < desired_wsize) { - throw std::runtime_error("wsize less than desired_wsize"); + if (ne11 == 1) { + for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) { + int64_t i12 = ii / a_k_blks; + int64_t ak_blk_id = ii % a_k_blks; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len, + a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a); + } + } else { + for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) { + int64_t i12 = ii / ne11; + int64_t i11 = ii % ne11; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10, + quant_a_buffer + i12 * nbw2 + i11 * nbw1); + } } - std::vector qnbitgemm_args(batch_feature); +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)] - for (size_t i = 0; i < batch_feature; i++) { - qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; - qnbitgemm_args[i].lda = gemm_k; - qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; - qnbitgemm_args[i].quant_b_scale = nullptr; + int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size); + int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as); + int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1); + int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1); + mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as); - if constexpr (std::is_same_v) { - qnbitgemm_args[i].quant_b_zp = nullptr; - } else { - qnbitgemm_args[i].quant_b_zp = w_data; + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + + // group rows by src0 matrix + for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int32_t id = 0; id < n_ids; ++id) { + const int32_t i02 = + *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; + matrix_row_counts[i02] += 1; + } } - qnbitgemm_args[i].bias = nullptr; - qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; - qnbitgemm_args[i].ldc = gemm_n; + int32_t valid_ep_count_t = 0; + int32_t valid_act_count_t = 0; + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + valid_matrix_row_counts[valid_ep_count_t] = cur_a; + valid_act_count_t += cne1; + valid_ep_count_t += 1; + } + valid_ep_count[0] = valid_ep_count_t; + valid_act_count[0] = valid_act_count_t; } - const uintptr_t ws_ptr = reinterpret_cast(params->wdata); - void * ws = reinterpret_cast((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); - const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + const int64_t barrier_idx = static_cast(ith / 2); - { - constexpr size_t block_size_m = 4; - size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); - int32_t task_count = batch_feature * per_gemm_block_count_m; - int32_t task_per_thread = (task_count + nth - 1) / nth; - int32_t start = ith * task_per_thread; - int32_t end = std::min((ith + 1) * task_per_thread, task_count); - for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { - int32_t gemm_idx = compute_idx / per_gemm_block_count_m; - int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m; - int32_t m_idx = block_idx_in_gemm * block_size_m; - const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; - int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); - - if (rows_tobe_handled == block_size_m) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = - static_cast(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - } else { - while (rows_tobe_handled) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = static_cast(ws) + - gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - rows_tobe_handled -= 1; - m_idx += 1; + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; + + ggml_barrier(params->threadpool); + + const size_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const size_t expert_b_stride = ne01 * row_stride_b; + const size_t per_nb_cols_wsize = NB_COLS * row_stride_b; + + std::array src_workspaces; + std::array dst_workspaces; + + auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; + + const auto valid_ep_count_t = valid_ep_count[0]; + const auto valid_act_count_t = valid_act_count[0]; + + int nth_es = 1; + int nth_n = nth; + + int ith_es = ith % nth_es; + int ith_n = (ith / nth_es) % nth_n; + + if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as && + valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) { + for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride; + + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0); + const int id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + const int64_t i1 = id; + const int64_t i2 = i12; + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)); + + uint8_t * a_row = src1_col; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + nbw1; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; + + if (ith % 2 == 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + } + + spine_barrier_wait(cur_barrier); + + if (ith % 2 != 0) { + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + } + + int64_t nb_real = std::min(ne01, NB_COLS); + for (int64_t ni = 0; ni < ne01; ni += NB_COLS) { + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } + + gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01); + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS; + if (next_ni < ne01) { + nb_real = std::min(ne01 - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d( + b_col, reinterpret_cast(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize); } } } - } + } else { + for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + const int64_t cne1 = matrix_row_counts[cur_a]; - ggml_barrier(params->threadpool); + int64_t src1_cur_start = 0; + int64_t src1_cur_end = cne1; - if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { - return; - } - nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); - - size_t threads_per_gemm = nth / batch_feature; - constexpr size_t gemm_m_stride = 128; - size_t nc = gemm_n; - const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); - const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); - if (max_nc < nc) { - nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); - } - const size_t gemm_n_stride = nc; - const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); - const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); - threads_per_gemm = thread_count_m * thread_count_n; + int64_t src0_cur_start = (ith_n * ne01) / nth_n; + int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01); - { - int task_count = batch_feature * threads_per_gemm; - int task_per_thread = (task_count + nth - 1) / nth; - int start = ith * task_per_thread; - int end = std::min((ith + 1) * task_per_thread, task_count); - for (int compute_idx = start; compute_idx < end; compute_idx++) { - const auto gemm_i = compute_idx / threads_per_gemm; - const auto blk_i = compute_idx % threads_per_gemm; - const auto * data = &qnbitgemm_args[gemm_i]; + if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) { + continue; + } + + src0_cur_start = + (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = + (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + + size_t extra_tcm_buffer_size = tcm_buffer_size; + void * extra_tcm_buffer = tcm_buffer; + if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 && + (src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) { + spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur, + (src0_cur_end - src0_cur_start) * row_stride_b); + src0_cur = reinterpret_cast(tcm_buffer); + b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b; + extra_tcm_buffer = reinterpret_cast(reinterpret_cast(tcm_buffer) + + (src0_cur_end - src0_cur_start) * row_stride_b); + } - const auto tid_n = blk_i / thread_count_m; - const auto tid_m = blk_i % thread_count_m; + int ir1 = src1_cur_start; - const size_t m_start = tid_m * gemm_m_stride; - const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride); + if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) { + int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1; + do { + quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1); - const size_t n_start = tid_n * gemm_n_stride; - const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); + uint8_t * quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); - void * per_gemm_ws = reinterpret_cast(ws) + gemm_i * per_gemm_workspace_stride; + int iir1 = ir1; + for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1); - sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1); + quant_a_tile_buffer = quant_a_tile_buffer + nbw1; + } + + quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); + iir1 = ir1; + + if (moe_gemm_kernel_m2 != nullptr) { + for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1); + + src_workspaces[0] = quant_a_tile_buffer; + src_workspaces[1] = quant_a_tile_buffer + nbw1; + + dst_workspaces[0] = + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start; + dst_workspaces[1] = (float *) ((char *) dst->data + + ((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) + + src0_cur_start; + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, + ne01); + } + } + + for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + + gemm_kernel( + b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start, + 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + + ir1 += quant_a_tile_size; + } while (ir1 < src1_cur_end); + } else { + if (moe_gemm_kernel_m2 != nullptr) { + for (; ir1 < src1_cur_end - 1; ir1 += 2) { + for (int iir1 = 0; iir1 < 2; ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + dst_workspaces[iir1] = + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start; + } + + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } + + for (; ir1 < src1_cur_end; ir1++) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1, + src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } } } +#undef MMID_MATRIX_ROW } - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + int repack(ggml_tensor * t, const void * data, size_t data_size) override { GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), (int) NB_COLS, (int) INTER_SIZE); return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); @@ -563,309 +936,464 @@ template class tensor_ }; class tensor_traits_common : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override { switch (op->op) { - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - size = 0; + case GGML_OP_FLASH_ATTN_EXT: + { + const int n_tasks = n_threads; + const int64_t neq2 = op->src[0]->ne[2]; // number of query heads + const int64_t DK = op->src[1]->ne[0]; + const int64_t DV = op->src[2]->ne[0]; // DV + + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float) * + (GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV + + GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) * + n_tasks; + + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV)); + + size = MAX(prefill, decode); + } return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { switch (op->op) { case GGML_OP_NORM: - forward_norm_f32(params, op); - return true; + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } case GGML_OP_RMS_NORM: - forward_rms_norm_f32(params, op); + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_rms_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } + case GGML_OP_ADD: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_add(params, op); + return true; + } + case GGML_OP_SUB: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_sub(params, op); + return true; + } + case GGML_OP_MUL: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_mul(params, op); + return true; + } + case GGML_OP_DIV: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_div(params, op); + return true; + } + case GGML_OP_FLASH_ATTN_EXT: + forward_flash_attn_ext_f16(params, op); + return true; + case GGML_OP_CONT: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] && + op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) { + spacemit_kernels::rvv::forward_cont_with_permute(params, op); + } else { + ggml_compute_forward_cont(params, op); + } + return true; + } + case GGML_OP_CPY: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] && + ggml_nelements(src0) == ggml_nelements(op)) { + spacemit_kernels::rvv::forward_cpy_with_permute(params, op); + } else { + ggml_compute_forward_cpy(params, op); + } + return true; + } + case GGML_OP_REPEAT: + { + const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op); + const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0]; + + if (rows_equal && broadcast_or_equal) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + default: + break; + } + } + + if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_repeat(params, op); + } + return true; + case GGML_OP_SUM_ROWS: + { + if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) { + spacemit_kernels::rvv::forward_sum_rows(params, op); + } else { + ggml_compute_forward_sum_rows(params, op); + } + } + return true; + case GGML_OP_GET_ROWS: + { + if (op->src[0]->type == op->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_get_rows(params, op); + } return true; + case GGML_OP_CONCAT: + { + const int32_t dim = ggml_get_op_params_i32(op, 0); + if (dim == 0 && op->type == op->src[0]->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_concat(params, op); + } + return true; + // TODO For GGML_OP_GATED_DELTA_NET + // case GGML_OP_GATED_DELTA_NET: + // return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); + void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + + const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT); + const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16); + const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128); + const bool supported_vlen = (__riscv_vlenb() == 128); + + if (!(supported_prec && supported_types && supported_shape && supported_vlen)) { + ggml_compute_forward_flash_attn_ext(params, dst); + return; + } + + // total rows in q + const int64_t nr = neq1 * neq2 * neq3; + // rows per thread const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ); - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); + // 4x chunks per thread + // int nth_scaled = nth * 4; + // int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + // int64_t nchunk = (nr + chunk_size - 1) / chunk_size; - GGML_ASSERT(epsilon > 0.0f); + // if (nth == 1 || nchunk < nth) { + // nchunk = nth; + // } - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; + int64_t nchunk = nth; - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_chunk_set(params->threadpool, nth); + } - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); + ggml_barrier(params->threadpool); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - size_t gvl = __riscv_vsetvlmax_e32m4(); - vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + // The number of elements in each chunk + const int64_t dr = (nr + nchunk - 1) / nchunk; - sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); - p_input += gvl; - p_temp_output += gvl; - length -= gvl; + if (use_tiled) { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); + } else { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); } - gvl = __riscv_vsetvlmax_e32m1(); - - float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - vfloat32m1_t mean_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); - mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); - mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); - mean /= hidden_size; - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - mean_square = sqrt(mean_square - mean * mean + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } } - void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); - - GGML_ASSERT(epsilon > 0.0f); - - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; - - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); - - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - - size_t gvl = __riscv_vsetvlmax_e32m4(); - // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + int repack(ggml_tensor * t, const void * data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); +// Impl By IME1 +static const tensor_traits q4_0_16x32_q8_0; +static const tensor_traits q4_1_16x32_q8_0; +static const tensor_traits q4_k_16x32_q8_0; +// Impl By IME2 +static const tensor_traits q2_k_32x256_q8_0; +static const tensor_traits q3_k_32x256_q8_0; +static const tensor_traits q4_0_32x32_q8_0; +static const tensor_traits q4_1_32x32_q8_0; +static const tensor_traits q4_0_32x256_q8_0; +static const tensor_traits q4_1_32x256_q8_0; +static const tensor_traits q4_k_32x32_q8_0; +static const tensor_traits q6_k_32x32_q8_0; +static const tensor_traits q8_0_32x32_q8_0; +static const tensor_traits mxfp4_32x32_q8_0; +static const tensor_traits q5_k_32x32_q8_0; +static const tensor_traits q5_1_32x32_q8_0; +static const tensor_traits q5_0_32x32_q8_0; +// Impl By RVV +static const tensor_traits_common rvv_impl; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); +} // namespace ggml::cpu::riscv64_spacemit - p_input += gvl; - p_temp_output += gvl; - length -= gvl; +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) { + switch (cur->type) { + case GGML_TYPE_Q2_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0; + } +#endif } + break; + case GGML_TYPE_Q3_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0; + } - gvl = __riscv_vsetvlmax_e32m1(); - - // float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0; + } +#endif - mean_square = sqrt(mean_square + epsilon); +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + // (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0; + // } + + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0; + } +#endif - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0; + } +#endif - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0; } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q6_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0; } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q8_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0; } +#endif } - } - } - - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { - memcpy(t->data, data, data_size); - return 0; - } -}; - -static const tensor_traits q4_0_16x8_q8_0; -static const tensor_traits q4_1_16x8_q8_0; -static const tensor_traits q4_k_16x8_q8_0; -static const tensor_traits_common rvv_impl; - -} // namespace ggml::cpu::riscv64_spacemit - -static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { - if (cur->type == GGML_TYPE_Q4_0) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_1) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_K) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_F32) { - return &ggml::cpu::riscv64_spacemit::rvv_impl; + break; + case GGML_TYPE_MXFP4: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0; + // } +#endif + } + break; + case GGML_TYPE_Q5_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0; + } +#endif + } + break; + default: + break; } return nullptr; } static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor) { + ggml_tensor * tensor) { tensor->extra = (void *) const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); @@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba return GGML_STATUS_SUCCESS; } +static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + if (base == nullptr) { + return; + } + + ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base); +} + +static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + return base; +} + +static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + GGML_ASSERT(tensor); + memset((char *) tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + memset(base, value, buffer->size); +} + static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor, + ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_ GGML_UNUSED(buffer); } +static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = { + /* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer, + /* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base, + /* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear, + /* .reset = */ nullptr, +}; + static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { return "CPU_RISCV64_SPACEMIT"; @@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_ static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - - if (buffer == nullptr) { + void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64); + if (base == nullptr) { return nullptr; } - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; + return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size); } static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b GGML_UNUSED(buft); } -static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, - const struct ggml_tensor * tensor) { +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { if (tensor->ne[i] <= 0) { return 0; } } - size_t nbytes; + GGML_UNUSED(buft); + + const auto plain_nbytes = [&]() { + size_t total = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * tensor->nb[i]; + } + return total; + }; + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { - nbytes = ggml_type_size(tensor->type); - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + return plain_nbytes(); + } + + const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + + const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size; } - } else { - nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; - if (tensor->type == GGML_TYPE_Q4_K) { - GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); - nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - } - } else { - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; - } + return total; + }; + + const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) { + GGML_ASSERT(row_nbytes % src_block_size == 0); + + size_t total = + add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size); + + if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) { + total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size; } + + return total; + }; + + size_t nbytes = row_nbytes; + switch (tensor->type) { + case GGML_TYPE_Q4_K: + nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8); + break; + case GGML_TYPE_Q6_K: + nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32); + break; + case GGML_TYPE_Q8_0: + nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32); + break; + case GGML_TYPE_Q2_K: + nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>)); + break; + case GGML_TYPE_Q3_K: + nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>)); + break; + case GGML_TYPE_MXFP4: + nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>)); + break; + case GGML_TYPE_Q5_K: + nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8); + break; + case GGML_TYPE_Q5_1: + nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>)); + break; + case GGML_TYPE_Q5_0: + nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>)); + break; + default: + nbytes = add_strided_nbytes(row_nbytes, 1, 1); + break; } - GGML_UNUSED(buft); return nbytes; } namespace ggml::cpu::riscv64_spacemit { class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && @@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } } break; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - if (op->src[0]->type == GGML_TYPE_F32) { - return true; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } } break; default: @@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } - ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_REPEAT: + case GGML_OP_SUM_ROWS: + case GGML_OP_GET_ROWS: + case GGML_OP_CONCAT: + // case GGML_OP_GATED_DELTA_NET: return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); default: // GGML_ABORT("fatal error"); @@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } // namespace ggml::cpu::riscv64_spacemit ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { /* .iface = */ { /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, @@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { return &ggml_backend_cpu_buffer_type_riscv64_spacemit; } + +extern "C" { +static int bind_ai_thread() { + int fd, bytes; + char str[32]; + + fd = open("/proc/set_ai_thread", O_WRONLY); + if (fd < 0) { + GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n"); + return -1; + } + + snprintf(str, 16, "%d", 0); + bytes = write(fd, str, strlen(str)); + if (bytes < 0) { + GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n"); + close(fd); + return -1; + } + + close(fd); + return 0; +} + +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) { + int cpu_id = sched_getcpu(); + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 && + !((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) { + GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid()); + bind_ai_thread(); + } + + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm && + ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) { + CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + pthread_t main_thread = pthread_self(); + const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids; + if (thread_n < 0 || static_cast(thread_n) >= perfer_core_ids.size()) { + GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size()); + } + auto perfer_cpu_id = perfer_core_ids[static_cast(thread_n)]; + CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + int s = + pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + if (s != 0) { + GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id); + } + + int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset; + ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id; + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id); + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size = + ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size; + } + + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + void * rt = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt == nullptr) { + GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) { + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release( + ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt != 0) { + GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} +} diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h index 800d91acdae..6849dd95e05 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -8,6 +8,14 @@ extern "C" { ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n); + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n); + +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment); + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index cbbb6cd9160..6acc6819dfb 100644 --- a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1,8 +1,26 @@ +#include "ggml-impl.h" #include "ggml.h" #include "ime_kernels.h" +#include "rvv_kernels.h" #include #include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +# error "RISCV64_SPACEMIT_IME1 not defined" +#endif // clang-format off #if defined(__GNUC__) @@ -11,7 +29,7 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif // clang-format on -namespace sqnbitgemm_spacemit_ime { +namespace spacemit_kernels { #define QUANTIZEM4ROW_KERNEL \ "vmv.s.x v16, zero \n\t" \ @@ -76,1093 +94,208 @@ namespace sqnbitgemm_spacemit_ime { "vse8.v v31, (s1) \n\t" namespace ime1 { -void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + uint8_t * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, %[BlkLen], e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "sub t2, t2, t0 \n\t" - "slli t1, t0, 2 \n\t" - "add %[SRC], %[SRC], t1 \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE - - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL - - "addi t3, %[BlkLen], 0 \n\t" - "addi s2, s1, 0 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "SET_ZERO%=: \n\t" - "vse8.v v8, (s2) \n\t" - "addi s2, s2, 32 \n\t" - "addi t3, t3, -8 \n\t" - "bnez t3, SET_ZERO%= \n\t" - - QUANTIZEM4ROW_STORE - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 128) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "addi t2, t2, -128 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v16, v24, v16 \n\t" - "vfredmax.vs v24, v16, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e64, m4 \n\t" - "vsse64.v v16, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "sub t2, t2, t2 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "jal x0, QUANTIZE%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 256) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "addi t2, t2, -256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v24, v16 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t1, t2, 0 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - - "TAIL_LOOP%=: \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t2, e32, m1 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 32 \n\t" - "vfmul.vf v1, v0, f11 \n\t" - "vfcvt.x.f.v v2, v1 \n\t" - "vsetvli t0, zero, e16, mf2 \n\t" - "vnclip.wx v3, v2, zero \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vnclip.wx v3, v3, zero \n\t" - "vse8.v v3, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bnez t2, TAIL_LOOP%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK), + [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); } } -void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { const float * SRC = A; - std::byte * DST = QuantA; + uint8_t * DST = QuantA; constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; - if (CountK <= BlkLen) { - float max_abs_A = 0.0f; - for (size_t k = 0; k < CountK; k++) { - max_abs_A = std::max(max_abs_A, fabsf(A[k])); - } - float scale_A = max_abs_A * range_max_reciprocal; - - ((float *) QuantA)[0] = scale_A; - - auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float)); - - for (size_t k = 0; k < CountK; k++) { - QuantAData_offset[k] = - (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits::lowest(), - (float) std::numeric_limits::max()); - } - for (size_t k = CountK; k < BlkLen; k++) { - QuantAData_offset[k] = 0; - } - - return; - } - - if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { - __asm__ volatile( - "vsetvli t0, zero, e8, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "LOOP%=: \n\t" - "vsetvli t0, %[CNT], e8, m8 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "sub %[CNT], %[CNT], t0 \n\t" - "bnez %[CNT], LOOP%= \n\t" - : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset) - : - : "cc", "t0"); - } - if (BlkLen == 16) { - float buffer[64] = { 0.0f }; - __asm__ volatile( - "addi t3, zero, 16*8 \n\t" - "addi t2, zero, 16 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m2 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v2, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v4, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v6, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v10, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v12, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v14, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v18, v2 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v22, v6 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v26, v10 \n\t" - "vfabs.v v28, v12 \n\t" - "vfabs.v v30, v14 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v18, v18, v19 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v22, v22, v23 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v26, v26, v27 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vse32.v v16, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v18, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v20, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v22, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v24, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v26, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v28, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v30, (a1) \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f11, f3, f7 \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f11, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f12, f3, f7 \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fsw f12, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f13, f3, f7 \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f13, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f14, f3, f7 \n\t" - "fmul.s f14, f14, %[RMAXREC] \n\t" - "fsw f14, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f14, %[FONE], f14 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f15, f3, f7 \n\t" - "fmul.s f15, f15, %[RMAXREC] \n\t" - "fsw f15, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f15, %[FONE], f15 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f16, f3, f7 \n\t" - "fmul.s f16, f16, %[RMAXREC] \n\t" - "fsw f16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f16, %[FONE], f16 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f17, f3, f7 \n\t" - "fmul.s f17, f17, %[RMAXREC] \n\t" - "fsw f17, (%[DST]) \n\t" - "addi %[DST], %[DST], -136 \n\t" - "fdiv.s f17, %[FONE], f17 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v18, v2, f11 \n\t" - "vfmul.vf v20, v4, f12 \n\t" - "vfmul.vf v22, v6, f13 \n\t" - "vfmul.vf v24, v8, f14 \n\t" - "vfmul.vf v26, v10, f15 \n\t" - "vfmul.vf v28, v12, f16 \n\t" - "vfmul.vf v30, v14, f17 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v18, v18 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v22, v22 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v26, v26 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vfcvt.x.f.v v30, v30 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v18, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v20, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v22, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v26, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v28, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v30, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vse32.v v16, (%[BUFFER]) \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m2 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) - : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", - "f13", "f14", "f15", "f16", "f17"); - } else if (BlkLen == 32) { - __asm__ volatile( - "addi t3, zero, 32*4 \n\t" - "addi t2, zero, 32 \n\t" - - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 128 \n\t" - "addi a3, %[SRC], 256 \n\t" - "addi a4, %[SRC], 384 \n\t" - - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 36 \n\t" - "addi s3, %[DST], 72 \n\t" - "addi s4, %[DST], 108 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v4, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vle32.v v8, (a3) \n\t" - "addi a3, a3, 512 \n\t" - "vle32.v v12, (a4) \n\t" - "addi a4, a4, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v28, v12 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v20, v20, v22 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vfmax.vv v28, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v21, v20, v21 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfredmax.vs v29, v28, v29 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v21 \n\t" - "vfmv.f.s f12, v25 \n\t" - "vfmv.f.s f13, v29 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fsw f12, (s3) \n\t" - "addi s3, s3, 4 \n\t" - "fsw f13, (s4) \n\t" - "addi s4, s4, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v20, v4, f11 \n\t" - "vfmul.vf v24, v8, f12 \n\t" - "vfmul.vf v28, v12, f13 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vsetvli t0, t1, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 140 \n\t" - "vse8.v v20, (s2) \n\t" - "addi s2, s2, 140 \n\t" - "vse8.v v24, (s3) \n\t" - "addi s3, s3, 140 \n\t" - "vse8.v v28, (s4) \n\t" - "addi s4, s4, 140 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m4 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 128 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); - } else if (BlkLen == 64) { - __asm__ volatile( - "addi t3, zero, 64*2 \n\t" - "addi t2, zero, 64 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 68 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vfmax.vv v24, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v25 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 132 \n\t" - "vse8.v v24, (s2) \n\t" - "addi s2, s2, 132 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 256 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); - } else if (BlkLen == 128) { - __asm__ volatile( - "addi t2, zero, 128 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "sub %[K], %[K], t2 \n\t" - "QUANT%=: \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "vfmax.vv v28, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v30, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vfredmax.vs v31, v30, v31 \n\t" - "vfmv.f.s f10, v31 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v8, (a2) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "jal x0, QUANT%= \n\t" - "END%=: \n\t" - - : [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) - : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); - } else { - float buffer[8] = { 0.0f }; - size_t cnt = BlkLen / 256; - - __asm__ volatile( - "slli t3, %[BLK], 2 \n\t" - "blt %[K], %[BLK], LOOP_TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_CMP%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v16, v16, v24 \n\t" - "vfmax.vv v0, v0, v16 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, LOOP_CMP%= \n\t" - "sub %[SRC], %[SRC], t3 \n\t" - "addi t6, %[CNT], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_QUANT%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "vse8.v v4, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bnez t6, LOOP_QUANT%= \n\t" - "sub %[K], %[K], %[BLK] \n\t" - "bge %[K], %[BLK], LOOP_MAIN%= \n\t" - "blez %[K], END%= \n\t" - "LOOP_TAIL%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[K], 0 \n\t" - "addi s1, %[SRC], 0 \n\t" - "TAIL_CMP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t6, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "sub t6, t6, t0 \n\t" - "vfabs.v v0, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, TAIL_CMP%= \n\t" - "addi t6, %[K], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[K], 0 \n\t" - "TAIL_QUANT%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t1, t6, e32, m8 \n\t" - "vle32.v v0, (s1) \n\t" - "addi s1, s1, 256 \n\t" - "sub t6, t6, t1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 64 \n\t" - "bnez t6, TAIL_QUANT%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), - [CNT] "r"(cnt) - : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); - } + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); } } // namespace ime1 @@ -1451,1746 +584,444 @@ namespace { "vadd.vi v1, v1, -12 \n\t" template -void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { size_t LDC = ldc * sizeof(float); const size_t INNER = BlkLen / 16; float tmp[4 * 16]; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } -} -template -void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - size_t LDC = ldc * sizeof(float); - const size_t INNER = BlkLen / 16; - float tmp[4 * 16]; - - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } } template -void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { + GGML_UNUSED(ldc); size_t INNER = BlkLen / 16; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - // zp offset - "addi s7, %[B], 32 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s7, %[B], 32 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} -template -void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - const size_t INNER = BlkLen / 16; - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - // zp offset - "addi s7, %[B], 64 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - - // load scale - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - // load a scale - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - // a scale * b scale - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s7, %[B], 64 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} - -template -inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias, ldc); - - } else if (scalestride == 2) { - SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( - BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); - } -} -template -inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias); - } else if (scalestride == 2) { - SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } } } - } // namespace namespace ime1 { -size_t gemm_kernel_i8i4(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float * Bias, - const size_t ScaleStride) { - GGML_UNUSED(CountM); - GGML_UNUSED(CountK); - GGML_UNUSED(ldc); - if (CountM >= 4) { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { + if (quant_b_zp != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 4; } else { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + if (quant_b_zp != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 1; } } } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp new file mode 100644 index 00000000000..0c7a036a92a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp @@ -0,0 +1,5768 @@ +#include "ggml-impl.h" +#include "ggml.h" +#include "ime_kernels.h" +#include "rvv_kernels.h" +#include "string.h" + +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME2) +#else +# error "RISCV64_SPACEMIT_IME2 not defined" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels { +namespace ime2 { + +template +void gemm_kernel_i8i2k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + uint8_t * zeros16 = (uint8_t *) (quant_b_blk_data->zeros16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + _Float16 * zeros_fp16 = (_Float16 *) zeros16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + uint8_t * scales_temp = scales; + uint8_t * zps_temp = scales; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, zps_temp++) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + int16_t a_sum = a_sum_row[mi * 16 + kii]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0.0; + + uint8_t b_zp = zps_temp[ci * 16] >> 4; + uint8_t b_scale = scales_temp[ci] & 0x0F; + for (size_t bi = 0; bi < 16; bi++) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + uint8_t b0 = b_data_col[ci * 16 + bi]; + acc_0 += static_cast(a0) * static_cast((b0 >> b_shift) & 0x03); + } + + _Float16 scale_item = + static_cast<_Float16>(b_scale) * static_cast<_Float16>(factor_scale) * scales_fp16[ci]; + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + output[ci + mi * NB_COLS] += b_zp * a_sum * a_scale_row[mi] * zeros_fp16[ci]; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i3k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * b_hmask = quant_b_blk_data->hmask; + int8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + int8_t * scales_temp = scales; + uint16_t * b_mask_col = (uint16_t *) b_hmask; + + float acc_0_max = 0.0f; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, b_mask_col += NB_COLS) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0; + // blk 2 * kii + 0 + uint16_t b_shift_mask = 1; + for (size_t bi = 0; bi < 16; bi++, b_shift_mask <<= 1) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + int8_t b0 = static_cast((b_data_col[ci * 16 + bi] >> b_shift) & 0x03); + b0 -= b_mask_col[ci] & b_shift_mask ? 0 : 4; + acc_0 += static_cast(a0) * static_cast(b0); + } + + _Float16 scale_item = static_cast<_Float16>(scales_temp[ci]) * scales_fp16[ci] * + static_cast<_Float16>(factor_scale); + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t kblks_per_blk = 16; + GGML_ASSERT(k_blks % kblks_per_blk == 0); + + int64_t b_blk_stride = (sizeof(_Float16) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(_Float16); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + output_f16[ci + mi * NB_COLS] = static_cast<_Float16>(0.0f); + } + } + + size_t kii = 0; + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + _Float16 * b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + _Float16 a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_scale_fp16[ci]; + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output_f16[ci + mi * NB_COLS] += + static_cast(acc) * static_cast(a_scale) * static_cast(b_scale); + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } else { + kii++; + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_hp_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * MB_ROWS; + const size_t a_subblk_stride = q8_hp_blk_size(32, false, false) * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const uint8_t * b_tile_base = quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride) { + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + const uint8_t * b_superblk_ptr = b_tile_base + ki * b_superblk_stride; + const block_q4_0x32_layout * b_blocks = reinterpret_cast(b_superblk_ptr); + const uint8_t * b_zps = + quant_b_zp ? b_superblk_ptr + sizeof(block_q4_0x32_layout) * k_subblks_per_superblk : nullptr; + + _Float16 * a_sum_row = (_Float16 *) (a_data + a_subblk_stride * k_subblks_per_superblk); + _Float16 * a_scale_avg_row = (_Float16 *) (a_data + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + _Float16 scale_factor = a_scale_avg_row[0]; + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + int8_t * a_subblk = a_data + a_subblk_stride * ksi + MB_ROWS * sizeof(_Float16); + const _Float16 a_scale = a_scale_row[0]; + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + const uint8_t * b_qs = b_block.qs + ci * 16; + _Float16 b_scale = b_block.d[ci] * a_scale; + + int16_t acc = 0; + for (size_t bi = 0; bi < 16; bi++) { + uint8_t b = b_qs[bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + + acc += static_cast(a_subblk[mi * 32 + 2 * bi]) * static_cast(b0) + + static_cast(a_subblk[mi * 32 + 2 * bi + 1]) * static_cast(b1); + } + + const _Float16 scaled_acc = static_cast<_Float16>(acc) * b_scale; + output_f16[ci + mi * NB_COLS] += scaled_acc; + } + } + } + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + const uint8_t * b_zp_row = b_zps ? b_zps + ksi * NB_COLS : nullptr; + const _Float16 a_scale = a_scale_row[0]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + const _Float16 a_sum = a_sum_row[mi * k_subblks_per_superblk + ksi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_block.d[ci] * a_scale; + _Float16 a_sum_bzp = a_sum; + if (b_zp_row) { + a_sum_bzp = a_sum * static_cast<_Float16>(0.125f) * static_cast<_Float16>(b_zp_row[ci]); + } + + const _Float16 scaled_acc = a_sum_bzp * b_scale; + output[ci + mi * NB_COLS] += scaled_acc * scale_factor; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + auto val = static_cast(output_f16[ci + mi * NB_COLS]) * static_cast(scale_factor); + output[ci + mi * NB_COLS] += val; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void moe_gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] = c_ptr[mi]; + } + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = (a_data[mi])[2 * bi]; + int8_t a1 = (a_data[mi])[2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + (c_data[mi])[ci] = output[mi * NB_COLS + ci]; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] += NB_COLS; + } + } +} + +template +void moe_gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + // blk_len is expected to be 32 for Q5 types. + int64_t a_blk_stride = q8_blk_size(blk_len, true); + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + c_data[mi] = c_ptr[mi]; + } + + if (quant_b_zp) { + using blk_type = nrow_block_q5_1; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } else { + using blk_type = nrow_block_q5_0; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } +} + +template +void gemm_kernel_i8i8_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + blk_len); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + int8_t * b_data = (int8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len; bi++) { + int8_t a0 = a_data[mi * blk_len + bi]; + int8_t b0 = b_data[ci * blk_len + bi]; + acc += static_cast(a0) * static_cast(b0); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 for Q5 types + // quant_b_zp != nullptr => nrow_block_q5_1 (has zp) + // quant_b_zp == nullptr => nrow_block_q5_0 (no zp) + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + if (quant_b_zp) { + // nrow_block_q5_1: scales16[NB_COLS] + zp[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_1; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + // qh is packed as 4 bytes per column (32 bits for 32 elements) + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } else { + // nrow_block_q5_0: scales16[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_0; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + // Q5_0 has no zp, use default offset 16 (midpoint of 5-bit unsigned range) + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } +} + +template +void gemm_kernel_i8mxfp4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 (QK_MXFP4) + // quant_b_zp is unused for MXFP4 (symmetric quantization) + GGML_UNUSED(quant_b_zp); + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + using blk_type = nrow_block_mxfp4; + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = GGML_E8M0_TO_FP32_HALF(quant_b_blk_data->e[ci]); + + // Read 32 sign bits for this column + uint32_t sign_bits; + memcpy(&sign_bits, &quant_b_blk_data->qh[ci * 4], 4); + + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + + // qs[ci*16 + bi] stores abs(vals[bi*2]) in low 4 bits + // and abs(vals[bi*2+1]) in high 4 bits + uint8_t qs_byte = quant_b_blk_data->qs[ci * 16 + bi]; + int8_t b_abs0 = static_cast(qs_byte & 0x0F); + int8_t b_abs1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract sign bits: bit (2*bi) for vals[2*bi], bit (2*bi+1) for vals[2*bi+1] + int8_t b0 = (sign_bits >> (2 * bi)) & 1 ? -b_abs0 : b_abs0; + int8_t b1 = (sign_bits >> (2 * bi + 1)) & 1 ? -b_abs1 : b_abs1; + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +void gemm_kernel_i8i2k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "addi %[A], %[A], 4 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 32 \n\t" // A data addr + "addi s3, %[B], 0 \n\t" + "vxor.vv v30, v29, v29 \n\t" // tmp result + + "INNER_K_LOOP%=: \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vxor.vv v2, v2, v2 \n\t" + "vxor.vv v3, v3, v3 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + + // A data, 1x64@i8 + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v4, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v5, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v6, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetvli t0, x0, e64, mf2 \n\t" + "vslideup.vi v3, v4, 2 \n\t" + "vslideup.vi v28, v5, 4 \n\t" + "vslideup.vi v29, v6, 6 \n\t" + + // init the accumu to zero + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v18, v18 \n\t" + "vxor.vv v22, v18, v18 \n\t" + "vxor.vv v24, v18, v18 \n\t" + "vxor.vv v26, v18, v18 \n\t" + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu v20, v2, v8, i8 \n\t" + "vmadotsu v22, v2, v12, i8 \n\t" + "vmadotsu v24, v2, v16, i8 \n\t" + "vmadotsu v26, v2, v4, i8 \n\t" + + "vmadotsu v20, v3, v9, i8 \n\t" + "vmadotsu v22, v3, v13, i8 \n\t" + "vmadotsu v24, v3, v17, i8 \n\t" + "vmadotsu v26, v3, v5, i8 \n\t" + + "vmadotsu v20, v28, v10, i8 \n\t" + "vmadotsu v22, v28, v14, i8 \n\t" + "vmadotsu v24, v28, v18, i8 \n\t" + "vmadotsu v26, v28, v6, i8 \n\t" + + "vmadotsu v20, v29, v11, i8 \n\t" + "vmadotsu v22, v29, v15, i8 \n\t" + "vmadotsu v24, v29, v19, i8 \n\t" + "vmadotsu v26, v29, v7, i8 \n\t" + + "vand.vi v10, v0, 0xf \n\t" // scale + "vwadd.vx v12, v10, x0 \n\t" + "vsetvli t0, x0, e16, m2 \n\t" + "vwadd.vx v16, v12, x0 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v6, v2, v4, 3 \n\t" // 0,1 + "vpack.vv v8, v3, v5, 3 \n\t" // 2,3 + + // mul scale + "vmacc.vv v30, v6, v16 \n\t" + "vmacc.vv v30, v7, v17 \n\t" + "vmacc.vv v30, v8, v18 \n\t" + "vmacc.vv v30, v9, v19 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + // load zp B + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v4, (s3) \n\t" + "vsrl.vi v8, v4, 4 \n\t" // zp + + // asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + + "vsetvli t0, x0, e16, mf4 \n\t" + "vle16.v v2, (%[A]) \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vnsrl.wi v12, v2, 0 \n\t" // low 8 + "vnsra.wi v13, v2, 8 \n\t" // high 8 + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v20, v13, v8, i8 \n\t" + "vmadotsu v22, v13, v9, i8 \n\t" + "vmadotsu v24, v13, v10, i8 \n\t" + "vmadotsu v26, v13, v11, i8 \n\t" + + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vsll.vi v26, v26, 8 \n\t" + + "vmadotu v20, v12, v8, i8 \n\t" + "vmadotu v22, v12, v9, i8 \n\t" + "vmadotu v24, v12, v10, i8 \n\t" + "vmadotu v26, v12, v11, i8 \n\t" + + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v28, v2, v4, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v0, (t2) \n\t" // scale16 + "addi t2, t2, 64 \n\t" + "vle16.v v1, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v2, v0 \n\t" + "vfwcvt.f.f.v v4, v1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v30, v30 \n\t" + "vfcvt.f.x.v v28, v28 \n\t" + "addi %[B], t2, 64 \n\t" + "mv %[A], t3 \n\t" + + "vfmul.vv v30, v30, v2 \n\t" // mul scale16 + "vfmacc.vv v30, v28, v4 \n\t" // + mul zero16 + "vfmacc.vf v31, fa0, v30 \n\t" + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3"); + } +} + +void gemm_kernel_i8i2k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + _Float16 scale = 0.0625f; + _Float16 scale_1 = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v31, v31 \n\t" // init result + "vxor.vv v29, v31, v31 \n\t" + "vxor.vv v30, v31, v31 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 128 \n\t" // A data addr + "addi s4, t2, 1024 \n\t" // scale16 addr + "addi s4, s4, 1024 \n\t" // TODO + "addi s3, %[B], 0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s4) \n\t" // load scale16 + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v22, v1, v1, 3 \n\t" + + "addi s4, t3, 256 \n\t" // addr 1 + "addi s5, t3, 512 \n\t" // addr 2 + "addi s6, t3, 768 \n\t" // addr 3 + + // init the accu to 0 + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v25, v25, v25 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v27, v27, v27 \n\t" + + "INNER_K_LOOP%=: \n\t" + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + "vand.vi v1, v1, 0xf \n\t" + + "vfwcvt.f.x.v v20, v1 \n\t" // f16 scale B + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v0, v20, v22 \n\t" // mul scale16 + "vfmul.vv v1, v21, v22 \n\t" // mul scale16 + "vfmul.vf v0, v0, %[SCALE] \n\t" // mul magic + "vfmul.vf v1, v1, %[SCALE] \n\t" // mul magic + + // A data, 4x64@i8 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 64 \n\t" + "vle8.v v3, (s4) \n\t" + "addi s4, s4, 64 \n\t" + "vle8.v v4, (s5) \n\t" + "addi s5, s5, 64 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 64 \n\t" + + // 4x64 => 4x16x4 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v6, v2, v3, 1 \n\t" + "vpack.vv v8, v4, v5, 1 \n\t" + "vpack.vv v2, v6, v8, 2 \n\t" // 0, 2 + + "vpack.vv v20, v2, v2, 3 \n\t" // 1 + "vor.vv v23, v21, v21 \n\t" + "vpack.vv v20, v3, v3, 3 \n\t" // 3 + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu.hp v24, v2, v8, v0, 0, i8 \n\t" + "vmadotsu.hp v25, v2, v12, v0, 1, i8 \n\t" + "vmadotsu.hp v26, v2, v16, v0, 2, i8 \n\t" + "vmadotsu.hp v27, v2, v4, v0, 3, i8 \n\t" + + "vmadotsu.hp v24, v23, v9, v0, 4, i8 \n\t" + "vmadotsu.hp v25, v23, v13, v0, 5, i8\n\t" + "vmadotsu.hp v26, v23, v17, v0, 6, i8\n\t" + "vmadotsu.hp v27, v23, v5, v0, 7, i8 \n\t" + + "vmadotsu.hp v24, v3, v10, v1, 0, i8 \n\t" + "vmadotsu.hp v25, v3, v14, v1, 1, i8 \n\t" + "vmadotsu.hp v26, v3, v18, v1, 2, i8 \n\t" + "vmadotsu.hp v27, v3, v6, v1, 3, i8 \n\t" + + "vmadotsu.hp v24, v21, v11, v1, 4, i8\n\t" + "vmadotsu.hp v25, v21, v15, v1, 5, i8\n\t" + "vmadotsu.hp v26, v21, v19, v1, 6, i8\n\t" + "vmadotsu.hp v27, v21, v7, v1, 7, i8 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v2, v24, v25, 1 \n\t" + "vpack.vv v4, v26, v27, 1 \n\t" + "vpack.vv v6, v2, v4, 2 \n\t" // 0,1,2,3 + + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + // load zp B, 16x8x4@int4 + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v0, (s3) \n\t" + "vsrl.vi v0, v0, 4 \n\t" // zp + + // 4x16@int16 + "vsetvli t0, x0, e16, m1 \n\t" // a sum + "vle16.v v12, (%[A]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnsrl.wi v10, v12, 0 \n\t" // low 8 + "vnsra.wi v11, v12, 8 \n\t" // high 8 + + // asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v18, v11, v0, i8 \n\t" + "vmadotsu v20, v11, v1, i8 \n\t" + "vmadotsu v22, v11, v2, i8 \n\t" + "vmadotsu v24, v11, v3, i8 \n\t" + "vsll.vi v18, v18, 8 \n\t" + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vmadotu v18, v10, v0, i8 \n\t" + "vmadotu v20, v10, v1, i8 \n\t" + "vmadotu v22, v10, v2, i8 \n\t" + "vmadotu v24, v10, v3, i8 \n\t" + + "vpack.vv v10, v18, v20, 2 \n\t" + "vpack.vv v12, v22, v24, 2 \n\t" + "vpack.vv v14, v10, v12, 3 \n\t" + "vpack.vv v16, v11, v13, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t2, t2, 64 \n\t" + "vle16.v v20, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v22, v20 \n\t" + + // mul 1/magic + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmul.vf v0, v6, %[SCALE_1] \n\t" + "vfwmul.vf v2, v7, %[SCALE_1] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v14, v14 \n\t" + "vfcvt.f.x.v v15, v15 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + + "addi %[B], t2, 64 \n\t" + "mv %[A], s6 \n\t" + + "vfmacc.vv v0, v14, v22 \n\t" // + mul zero16 + "vfmacc.vv v1, v15, v22 \n\t" + "vfmacc.vv v2, v16, v22 \n\t" + "vfmacc.vv v3, v17, v22 \n\t" + + "vfmacc.vf v28, fa0, v0 \n\t" // mul a scale + "vfmacc.vf v29, fa1, v1 \n\t" + "vfmacc.vf v30, fa2, v2 \n\t" + "vfmacc.vf v31, fa3, v3 \n\t" + + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "vse32.v v29, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v30, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v31, (t1) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks), [LDC] "r"(ldc * 4), [SCALE] "f"(scale), [SCALE_1] "f"(scale_1) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i3k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; //only support 32 in ASM + using blk_type = nrow_block_q3_k; + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride; + int64_t b_ncol_block_stride = sizeof(blk_type); + + // Constants used by q3_k scaling in HP branch: + // - k_q3k_scale_step: per-nibble scale factor (1/16). + // - k_a_scale_post_mul: A_scale needs an extra *16 at the end (pairs with 1/16 above). + const _Float16 k_q3k_scale_step = (_Float16) 0.0625f; // 1 / 16 + const float k_a_scale_post_mul = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; +#if 0 + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1 32bit + // Asum int16 * 16 256bit + // A M1K256 int8 2048bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+32 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" // clear acc + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + + // ordinary vmadot: vle*10 vecIns*78 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // K0-15 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load B qs chunk (128B per K16, 16 times => 2048B) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (s3) \n\t" + "addi s3, s3, 64 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K16-31 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // load A scale (fp32) and advance A to next superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // t4 = next B blk base + "addi s3, s2, 4+32 \n\t" + + // load B scales16[32] (fp16) at end of qs region + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v2, (s6) \n\t" + + // pointer modify + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v30 \n\t" + "vfmul.vf v1, v24, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v31, v1, v26 \n\t" + + // next K-superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "s2", "s3", "s4", "s5", "s6", "s7"); +#else + + __asm__ volatile( + // ========================= + // Kernel overview (M1 x N32) + // ========================= + // Process one output row (M=1) and 32 columns (N=32) per call. + // + // Loop structure: + // - Outer loop: K superblocks of size K=256 (k_blks times) + // - Each K256 superblock is broken into 4 x K64 + // - Each K64 is processed as 4 x K16 "sub-blocks" (via unpack+dot) + // + // Data layout (high level): + // A (q8k K=256, per superblock): + // [ fp32 a_scale ][ int16 a_sum[16] ][ int8 a_qs[256] ] + // B (nrow_block_q3_k<32>, per superblock): + // [ int8 scales[32*16] ][ hmask[1024] ][ qs[2048] ][ fp16 scales16[32] ] + // + // Registers/pointers: + // s2: pA (points at A superblock header; used to load fp32 a_scale) + // s3: pA_qs (points at A int8 data within the current superblock) + // s4: pB_scales (points at B int8 per-K16 scales) + // s5: pB_hmask (points at B sign mask area) + // s6: pB_qs (points at B 2-bit packed qs area) + // s8: pB_scales16 (points at B fp16 scales16[32] at the end of block) + // s7: pB_base (base pointer to current B block; used for block-to-block stride) + + // t2 = number of K256 superblocks + "mv t2, %[KBLKS] \n\t" + // t3 = number of K64 chunks per K256 superblock (256 / 64) + "li t3, 4 \n\t" + + // A pointers + "mv s2, %[pA] \n\t" // s2 = pA_superblock (a_scale at +0) + "addi s3, %[pA], 4+32 \n\t" // s3 = pA_qs (skip a_scale + a_sum[16]) + + // B pointers for nrow_block_q3_k<32> + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales[32*16]) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + // scales16 is at the end of the block: qs(2048) after hmask + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" // s8 = pB_scales16 (fp16 scales16[32]) + "mv s7, %[pB] \n\t" // s7 = pB_base (for next-block address calc) + + // v31: final FP32 accumulator for N=32 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ---- Preload B scales16[32] and build FP16 scale vector used by vmadot.hp ---- + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" // load fp16 scales16[32] + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" // broadcast/pack to match lanes + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" // v30 = scales16 * (1/16) + + // v24-v27: fp16 partial accumulators for a K64 chunk (vmadot.hp outputs) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + // HP vmadot: vle*10 vecIns*38 vmadot.hp*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" // loop over K256 superblocks + "K64_LPST%=: \n\t" // loop over 4 x K64 chunks + + // ------------------------------------------------------------ + // K0-15: load B scales + {hmask, qs} + A data; unpack and dot + // ------------------------------------------------------------ + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" // B int8 scales for this K16 + "addi s4, s4, 128 \n\t" + + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" // B hmask for this K16 + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v3, (s3) \n\t" // A int8 data for this K16 + "addi s3, s3, 64 \n\t" + + // Convert B int8 scales to FP16 and apply scales16*(1/16) + "vsetvli t0, x0, e8, m1 \n\t" + "vfwcvt.f.x.v v28, v2 \n\t" // int8 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v1, v28, v30 \n\t" // v1: FP16 scale vector for vmadot.hp + "vfmul.vv v29, v29, v30 \n\t" + + // Unpack B 2-bit qs + hmask -> signed int8 in v12..v15 + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + // (Next K16 unpack path uses a fresh hmask load) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Prepare another group from packed qs (bit shifts) + apply sign from hmask + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vsrl.vi v16, v8, 6 \n\t" + "vsrl.vi v17, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v18, v10, 6 \n\t" + "vsrl.vi v19, v11, 6 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v16, v16, -4, v0.t \n\t" + + // A shift for the second dot within this K64 + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + // Dot products with FP16 scaling (accumulate into v24..v27) + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v12, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v13, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v14, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v15, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v16, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v17, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v18, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v19, v1, 7, i8 \n\t" + + // (K32-47 / K48-63 blocks continue unchanged...) + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vmv.v.v v1, v29 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v3, v3, 4 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + + "vsrl.vi v20, v8, 6 \n\t" + "vsrl.vi v21, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v22, v10, 6 \n\t" + "vsrl.vi v23, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v20, v20, -4, v0.t \n\t" + + // K48-63 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v8, v4, 6 \n\t" + "vsrl.vi v9, v5, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v10, v6, 6 \n\t" + "vsrl.vi v11, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v8, v8, -4, v0.t \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v20, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v21, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v22, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v23, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v8, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v9, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v10, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v11, v1, 7, i8 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ---- End of K64 chunk: reduce fp16 accumulators -> fp32 and scale by A ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v24, v25, 1 \n\t" + "vpack.vv v14, v26, v27, 1 \n\t" + "vpack.vv v16, v12, v14, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v26, v16 \n\t" // fp16 -> fp32 vector (qsum * b_scales) + + // Load A scale and advance A pointer to next K256 superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // next B block base + "addi s3, s2, 4+32 \n\t" // reset A data pointer for next block + + // Advance B pointers to next K256 superblock + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" + "addi s7, t4, 0 \n\t" + "addi t2, t2, -1 \n\t" + + // Final per-block scaling: a_scale * 16.0f + "fmul.s f0, f0, %[a_post_mul] \n\t" + // acc += (qsum * b_scales) * (a_scale*16) + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vf v31, f0, v26 \n\t" + + "beqz t2, BLK_LPND%= \n\t" + + // Preload next block's scales16 and rebuild v30 for vmadot.hp + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" + + // Reset fp16 partial accumulators for next K64 loop(s) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [q3_step] "f"(k_q3k_scale_step), + [a_post_mul] "f"(k_a_scale_post_mul) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"); +#endif + } +} + +void gemm_kernel_i8i3k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q3_k<32>; + constexpr size_t NB_COLS = 32; //only support 32 in ASM + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t b_ncol_block_stride = sizeof(blk_type); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; + + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1* 4row 128bit + // Asum int16 * 16 4row 1024bit + // A M1K256 int8 4row 8192bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 16+128 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // v24-v27: K256 temp accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "vxor.vv v28, v0, v0 \n\t" // v28-v31: final accumulator + "vxor.vv v29, v0, v0 \n\t" + "vxor.vv v30, v0, v0 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ordinary vmadot: vle*13 vecIns*96 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // ========== K0-15: First K16 sub-block ========== + // Load B INT8 scale factors (32 cols × 16 K16 blocks) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v8, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // Load B quantized data (32 cols × 16 elements × 2bit, stored in 4 groups) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // Load B hmask (32 cols × 16bit sign mask) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Load A data (4 rows × 16 elements × INT8) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v12, (s3) \n\t" + "addi s3, s3, 256 \n\t" // Jump to next row + "vle8.v v13, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v14, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v15, (s3) \n\t" + "addi s3, s3, -768+64 \n\t" // Back to first row, advance 16 elements + + // Pack A data: merge 4 rows into 2 vectors + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v13, 1 \n\t" + "vpack.vv v18, v14, v15, 1 \n\t" + "vpack.vv v2, v16, v18, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" // 4 rows × cols 0-7 + "vmadot v18, v2, v13, i8 \n\t" // 4 rows × cols 8-15 + "vmadot v20, v2, v14, i8 \n\t" // 4 rows × cols 16-23 + "vmadot v22, v2, v15, i8 \n\t" // 4 rows × cols 24-31 + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" // Merge cols 0-15 + "vpack.vv v14, v20, v22, 2 \n\t" // Merge cols 16-31 + "vpack.vv v16, v12, v14, 3 \n\t" // Inter-row results (INT16) + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // INT8 → INT16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // INT16 → INT32 + + // Accumulate to K256 accumulator: qsum * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" // Row 0 + "vmacc.vv v25, v17, v23 \n\t" // Row 1 + "vmacc.vv v26, v18, v23 \n\t" // Row 2 + "vmacc.vv v27, v19, v23 \n\t" + + // ========== K16-31, K32-47, K48-63: Similar processing ========== + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 8 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 4 \n\t" + "vsll.vi v13, v5, 4 \n\t" + "vsll.vi v14, v6, 4 \n\t" + "vsll.vi v15, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" + "vmadot v18, v2, v13, i8 \n\t" + "vmadot v20, v2, v14, i8 \n\t" + "vmadot v22, v2, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 2 \n\t" + "vsll.vi v13, v5, 2 \n\t" + "vsll.vi v14, v6, 2 \n\t" + "vsll.vi v15, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v3, v3, 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ========== K256 superblock complete, apply scale factors ========== + // Load A's 4 row scale factors (FP32) + "flw f0, (s2) \n\t" + "flw f1, 4(s2) \n\t" + "flw f2, 8(s2) \n\t" + "flw f3, 12(s2) \n\t" + "add s2, s2, %[A_STR] \n\t" // Advance to next superblock + "add t4, s7, %[B_STR] \n\t" // t4 = next B block address + "addi s3, s2, (4+32)*4 \n\t" + + // Load B FP16 global scale factors (32 cols) + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (s6) \n\t" + + // Update B pointers to next block + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // ========== Type conversion and final scaling ========== + // FP16 → FP32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v9, v8 \n\t" + + // INT32 → FP32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + + // Compute a_scale * b_scale (4 rows) + "vfmul.vf v12, v9, f0 \n\t" + "vfmul.vf v13, v9, f1 \n\t" + "vfmul.vf v14, v9, f2 \n\t" + "vfmul.vf v15, v9, f3 \n\t" + + // Final accumulation: result += qsum * a_scale * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vv v28, v12, v24 \n\t" + "vfmacc.vv v29, v13, v25 \n\t" + "vfmacc.vv v30, v14, v26 \n\t" + "vfmacc.vv v31, v15, v27 \n\t" + + // Prepare for next K superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // Clear K256 accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + + // ========== Store results (4 rows × 32 cols) ========== + "mv t5, %[pC] \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v28, (%[pC]) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v29, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v30, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v31, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [A_STR] "r"(a_nrow_block_stride), [LDC] "r"(ldc * 4) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "f2", "f3", "s2", "s3", "s4", "s5", "s6", "s7"); + } +} + +void gemm_kernel_i8i4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... +#if 0 + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*7 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" // Bzp u8 -> u16 + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vfcvt.f.x.v v16, v26 \n\t" // zp i16 -> fp16 + "vadd.vi v18, v16, 0 \n\t" + "vadd.vi v20, v16, 0 \n\t" + "vadd.vi v22, v16, 0 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + // mac result * b_scale; f16*f16->f32 + "vfwmul.vv v31, v30, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + +#endif + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(uint8_t) + // b zp + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // Bzp uint8_t * N32 256bit; + // || scl*32..(fp16) | zp*32(uint8) | blk0 blk1 ... blk31 || scl*32..(fp16) ... + // || Element || Element ... + + //bias always be nullptr +#if 0 + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*6 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v31, (s4) \n\t" // B zp 32Row*uint8 = 256bit + "addi s4, s4, 32+128*4 \n\t" + + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v31, x0 \n\t" // Bzp u8 -> u16 + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vfwcvt.f.f.v v22, v30 \n\t" // b_scale fp16 -> fp32 + "vfcvt.f.x.v v18, v26 \n\t" // zp i16 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfwadd.vv v20, v18, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // mac result * b_scale; f32*f32->f32 + "vfmul.vv v31, v22, v20 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#endif + } + } +} + +void gemm_kernel_i8i4_hp_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" // init acc to zero + "mv t4, %[BK] \n\t" + "li t0, 0x4c00 \n\t" // 16 in fp16 + "fmv.h.x fa0, t0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + "li t5, 8 \n\t" + "addi t6, %[A], 288 \n\t" // point to blk scale + "flh ft1, (t6) \n\t" + "addi t6, %[A], 272 \n\t" // point to asum + + // init the acc fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v16, v18, v18 \n\t" + "vxor.vv v17, v18, v18 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v18, v18 \n\t" + + "INNER_BLK_LOOP%=: \n\t" + // load a sum and scale + "flh fa1, (t6) \n\t" + "addi t6, t6, 2 \n\t" + "flh ft0, (%[A]) \n\t" + "addi %[A], %[A], 2 \n\t" + // load A + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A]) \n\t" // 1x32@i8 + "addi %[A], %[A], 32 \n\t" + + // load scale B and B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (%[B]) \n\t" // b_scale fp16 + "addi %[B], %[B], 64 \n\t" + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vfmul.vf v8, v8, ft0 \n\t" // scale b * scale a + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" // scale b * scale a * asm + "vfwmacc.vf v31, ft1, v10 \n\t" // asum * scale a * scale b * blk scale + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v3, v28, v28, 3 \n\t" // hi4 of A + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v16, v3, v4, v0, 4, i4 \n\t" // high 4 + "vmadotsu.hp v17, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v18, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v19, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v16, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + + "addi t5, t5, -1 \n\t" + "bgtz t5, INNER_BLK_LOOP%= \n\t" + + "vpack.vv v8, v16, v17, 1 \n\t" + "vpack.vv v12, v18, v19, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t4, t4, -1 \n\t" + "vfwmacc.vf v31, ft1, v20 \n\t" + //"vsetvli t0, x0, e32, m1 \n\t" + //"vfmul.vf v31, v31, ft1 \n\t" // blk scale + + // update A ptr + "addi %[A], t6, 2 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "ft0", "ft1"); + } + } else { + // TODO: support quant_b_zp for i8i4 hp kernel + GGML_ABORT("gemm_kernel_i8i4_hp_m1 with quant_b_zp is not supported yet"); + } +} + +void gemm_kernel_i8i4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; +#if 0 + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vwmul.vx v10, v8, t1 \n\t" // 8*asum + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vrgather.vi v0, v10, 0 \n\t" + "vrgather.vi v1, v10, 1 \n\t" + "vrgather.vi v2, v10, 2 \n\t" + "vrgather.vi v3, v10, 3 \n\t" + + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v1 \n\t" + "vadd.vv v18, v18, v2 \n\t" + "vadd.vv v19, v19, v3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc*4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#else + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vsll.vi v8, v8, 3 \n\t" // asum * 8 + "vfcvt.f.x.v v9, v8 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vrgather.vi v10, v9, 0 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v16, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v17, v16, 4 \n\t" + "vnpack4.vv v12, v16, v17, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v16, v10, v10,0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v20, v16, v16,0 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vpack.vv v18, v20, v20, 0 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + "vfwmul.vv v18, v21, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#endif + } + } else { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2\n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + // load a sum + "lh s1, (%[A]) \n\t" + "lh s2, 2(%[A]) \n\t" + "lh s3, 4(%[A]) \n\t" + "lh s4, 6(%[A]) \n\t" + "addi %[A], %[A], 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v0, v10, s1 \n\t" + "vwmul.vx v2, v10, s2 \n\t" + "vwmul.vx v4, v10, s3 \n\t" + "vwmul.vx v6, v10, s4 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v2 \n\t" + "vadd.vv v18, v18, v4 \n\t" + "vadd.vv v19, v19, v6 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST]\n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3", "s1", "s2", "s3", "s4"); + } + } +} + +void gemm_kernel_i8i4_hp_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_SUBBLKS_PER_SUPERBLK = 8; + constexpr size_t K_SUBBLK_LEN = 32; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + GGML_ASSERT(count_m >= 4); + + // Contract: + // - computes a 4-row x 32-col tile per inner invocation + // - A is q8 HP packed in m4 layout, one logical K256 block at a time + // - B is q4 HP packed in N32 tiles, optionally with a separate zp area + // - tail-N is currently not handled here; the caller must provide full N32 tiles + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * K_SUBBLKS_PER_SUPERBLK + + (quant_b_zp ? NB_COLS * K_SUBBLKS_PER_SUPERBLK * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * 4; + const size_t a_subblk_stride = q8_hp_blk_size(K_SUBBLK_LEN, false, false) * 4; + + if (quant_b_zp != nullptr) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the with-zp path. + // + // A: M4 x K256 q8 HP block + // - split into 8 x K32 subblocks + // - each K32 subblock is 136B: + // 8B = 4 x fp16 row scales + // 128B = 4 x int8[32] row payloads + // - trailer after 8 subblocks is 72B: + // 4 rows x fp16[8] a_sum values, indexed as [row][ksi] + // 4 rows x fp16 scale_avg tail + // + // B: N32 x K256 q4 HP block with explicit zp area + // - each K32 subblock is 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload for 32 columns x 32 k-elements + // - zp is stored separately, not interleaved with the 576B payload block + // - one K256 superblock is laid out as: + // 8 x (scale + qs) blocks = 4608B + // 8 x zp[32] = 256B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // row1/row2/row3 are at +16/+32/+48 bytes + // - s5: current B (scale + qs) K32 subblock base + // - s6: current B zp[32] base for this ksi + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576, B_zp += 32 + // - per ki : skip the 72B A trailer and advance B to the next 4864B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + const _Float16 hp_scale_0125 = (_Float16) 0.125f; + + // VPR grouping used below: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : zp u8 / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "li t1, 4608 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + // load Bzp[32] for the current ksi from the dedicated zp area + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (s6) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + "vfwcvt.f.xu.v v10, v8 \n\t" // uint8 -> fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + "fmul.h ft1, ft1, %[HP0125] \n\t" + "fmul.h ft2, ft2, %[HP0125] \n\t" + "fmul.h ft3, ft3, %[HP0125] \n\t" + "fmul.h ft4, ft4, %[HP0125] \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> mf2 + "vfmul.vv v10, v10, v16 \n\t" // zp * ascale * bscale; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v10, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v10, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v10, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v10, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "addi s6, s6, 32 \n\t" // next zp[32] + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + "mv s5, s6 \n\t" // s6 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), + [HP1] "f"(hp_scale_1), [HP0125] "f"(hp_scale_0125) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s5", "s6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v10", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", + "memory"); + } + return; + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the no-zp path. + // + // A layout is identical to the with-zp branch. + // + // B: N32 x K256 q4 HP block without explicit zp storage + // - each K32 subblock is still 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload + // - zp is implicit and treated as a constant value 8 in the kernel + // - one K256 superblock therefore contains only: + // 8 x (scale + qs) blocks = 4608B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // - s5: current B (scale + qs) K32 subblock base + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576 + // - per ki : skip the 72B A trailer and advance B to the next 4608B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + + // VPR grouping used below matches the with-zp path: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : implicit zp lane / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v16, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v16, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v16, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v16, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" //N32in1register + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + // s5 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), [HP1] "f"(hp_scale_1) + : "t0", "t2", "t3", "t4", "t5", "t6", "s5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v10", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", "memory"); + } + return; + } +} + +void gemm_kernel_i8mxfp4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 1); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // MXFP4 no-zp: per column per k-block stride = scale_e8m0(1B) + qs(16B) + qh(4B) = 21B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // qh sign/high-bit mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // qs packed 4-bit magnitudes: n×k_blks×16 + n * k_blks * sizeof(uint8_t); // scale: n×k_blks×1 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (q8 block with per-block scale and stored sum field): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (u8×32) + // s5 = pB qh (sign/high-bit mask, 128B) + // s6 = pB qs (packed 4-bit magnitudes, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales u8 (loaded as bytes; later widened) + // v0 = qh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = qs unpack/pack temporaries (build signed vmadot lanes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32 \n\t" // s5 = pBh (pB + 32B scale) + "addi s6, %[pB], 32+128 \n\t" // s6 = pBs (pB + 32 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 672B = 32(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load qs (512B = 4 VRF) from s6, advance s6 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = qs N32K32 packed 4-bit magnitudes + "addi s6, s6, 128*4+128+32 \n\t" // s6 += 672 (512+128+32) + + // ---- load B scale (32B = 32×u8) from s4, advance s4 by 672 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_u8 × 32 + "addi s4, s4, 32+128*4+128 \n\t" // s4 += 672 (32+512+128) + + // ---- load qh (128B = 1 VRF) from s5, advance s5 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = qh N32K32 sign/high-bit packed + "addi s5, s5, 128+32+128*4 \n\t" // s5 += 672 (128+32+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + // ---- Decode packed MXFP4 payload into a vmadot-friendly signed-lane layout ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to 0 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- int8 dot products over the decoded MXFP4 lane groups ---- + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v2, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v2, 1 \n\t" + "vadd.vi v28, v2, -1 \n\t" + "vsll.vi v28, v28, 23 \n\t" + "vsll.vv v28, v30, v2, v0.t \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8mxfp4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 4); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + // MXFP4 block layout per K32/N32 tile: + // [scale_e8m0 x 32][qh sign/high-bit mask x 128B][qs packed 4-bit magnitudes x 512B] + // There is no explicit zp stream; qh is combined with qs to reconstruct signed MXFP4 values. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 32 \n\t" + "addi t2, %[B], 160 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "addi %[B], %[B], 672 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + + // Decode the packed MXFP4 payload once for the whole tile and expand it + // into a vmadot-friendly layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v26, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v26, 1 \n\t" + "vadd.vi v24, v26, -1 \n\t" + "vsll.vi v18, v24, 23 \n\t" + "vsll.vv v18, v30, v26, v0.t \n\t" + + // Row 0: dot(A0, decoded MXFP4 lane groups), accumulate in int32 and + // then apply A/B scaling. + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i5_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // ========================================================================= + // i8i5: 8-bit activation × 5-bit weight (4-bit low + 1-bit high mask) + // + // B layout per N32K32 k-block (no-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 191] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [192.. 703] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 704B per k-block stride + // + // B layout per N32K32 k-block (with-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 95 ] : zp_uint8 × 32 (32B) + // [96 .. 223] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [224.. 735] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 736B per k-block stride + // + // Bh format per N8K32 sub-block (32B): + // K rows × N cols × 1bit packed as bytes (8 cols per byte, K groups of 4B) + // Byte k gives 8 mask bits for columns N7..N0 at k-th K-element. + // + // Computation: + // B5bit_signed = (Bs | (Bh << 4)) - zp + // dot(A, B5) = dot(A, Bs_u4) + 16*dot(A, Bh_u1) - zp*asum + // No-zp: implicit zp = 16 (unsigned [0..31] centered at 16) + // With-zp: explicit zp from data + // + // ========================================================================= + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 no-zp: per column per k-block stride = fp16(2B) + i4(16B) + i1(4B) = 22B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // Bh i1 mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // Bs i4 data: n×k_blks×16 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32) + // s5 = pB Bh (i1 mask, 128B) + // s6 = pB Bs (i4 packed, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBh (pB + 64B scale) + "addi s6, %[pB], 32*2+128 \n\t" // s6 = pBs (pB + 64 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 704B = 64(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+64 \n\t" // s6 += 704 (512+128+64) + + // ---- load B scale (64B = 32×fp16) from s4, advance s4 by 704 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64+128*4+128 \n\t" // s4 += 704 (64+512+128) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+64+128*4 \n\t" // s5 += 704 (128+64+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + //// vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 with-zp: per column per k-block stride = fp16(2B)+zp(1B)+i4(16B)+i1(4B)=23B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // Bs i4: n×k_blks×16 + n * k_blks * (blk_len / 8) + // Bh i1: n×k_blks×4 + n * k_blks * sizeof(uint8_t) + // zp: n×k_blks×1 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32); 每个 k-block 先 +64 指向 zp,再 +672 到下一个 block + // s5 = pB Bh (i1 mask, 128B) (offset +96) + // s6 = pB Bs (i4 packed, 512B) (offset +224) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) / later reused to hold Bzp bytes + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBh (pB + 64B scale + 32B zp = pB+96) + "addi s6, %[pB], 32*3+128 \n\t" // s6 = pBs (pB + 96 + 128 = pB+224) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 736B = 64(scale) + 32(zp) + 128(Bh) + 512(Bs) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+96 \n\t" // s6 += 736 (512+128+96) + + // ---- load B scale (64B = 32×fp16) from s4; then s4 points to zp[32] ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64 \n\t" // s4 += 64 (now points to zp) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+96+128*4 \n\t" // s5 += 736 (128+96+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4+128 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8i5_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + + GGML_UNUSED(count_m); + GGML_UNUSED(blk_len); + + // This kernel computes a 4x32 output tile. For each K32 block we decode the + // packed Q5 weights once and reuse the decoded vectors across the 4 A rows. + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + if (quant_b_zp) { + // Q5_1 block layout per K32/N32 tile: + // [scale_fp16 x 32][zp_u8 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> zp stream, t2 -> qh bitmask stream, s5 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 96 \n\t" + "addi s5, %[B], 224 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t2) \n\t" + "vl4r.v v8, (s5) \n\t" + "addi %[B], %[B], 736 \n\t" + + // Decode Q5 payload once for the whole tile: + // 1) split `qs` low/high nibbles, + // 2) repack into bytes, + // 3) use the `qh` mask to inject bit4 (+16) where needed, + // 4) expand into the vmadot-friendly layout reused by all 4 rows. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v3, (t1) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * zp, then scale by A/B scales. + // The widen/mul correction sequence intentionally matches the proven m1 Q5_1 path. + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vwaddu.vx v28, v3, x0 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v12, v28, s1 \n\t" + "vwmul.vx v14, v28, s2 \n\t" + "vwmul.vx v20, v28, s3 \n\t" + "vwmul.vx v22, v28, s4 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v12 \n\t" + "vadd.vv v25, v25, v14 \n\t" + "vadd.vv v26, v26, v20 \n\t" + "vadd.vv v27, v27, v22 \n\t" + "vfcvt.f.x.v v12, v24 \n\t" + "vfcvt.f.x.v v14, v25 \n\t" + "vfcvt.f.x.v v20, v26 \n\t" + "vfcvt.f.x.v v22, v27 \n\t" + "vfmul.vv v12, v12, v18 \n\t" + "vfmul.vv v14, v14, v18 \n\t" + "vfmul.vv v20, v20, v18 \n\t" + "vfmul.vv v22, v22, v18 \n\t" + "vfmacc.vf v4, fa0, v12 \n\t" + "vfmacc.vf v5, fa1, v14 \n\t" + "vfmacc.vf v6, fa2, v20 \n\t" + "vfmacc.vf v7, fa3, v22 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { + // Q5_0 block layout per K32/N32 tile: + // [scale_fp16 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + // There is no explicit zp stream; the implicit midpoint correction is +16. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 192 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + "addi %[B], %[B], 704 \n\t" + + // Decode Q5 payload once for the whole tile and expand it into the vmadot layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * 16 (implicit Q5_0 midpoint correction). + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "slli s1, s1, 4 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "slli s2, s2, 4 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "slli s3, s3, 4 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "slli s4, s4, 4 \n\t" + "vadd.vx v24, v24, s1 \n\t" + "vadd.vx v25, v25, s2 \n\t" + "vadd.vx v26, v26, s3 \n\t" + "vadd.vx v27, v27, s4 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i8_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 2048bit + // 4VRF, N32K32, 8192bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... + + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*64 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4 \n\t" + "vl4r.v v8, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*8 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "addi s2, s2, 6+32 \n\t" // AScale + Asum(FP32+i16) + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v28, v8, v9, 1 \n\t" + "vupack.vv v30, v10, v11, 1 \n\t" + + "vslidedown.vi v4, v3, 4 \n\t" + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v24, i8 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadot v18, v3, v26, i8 \n\t" // M0 N8 - N15 + "vmadot v20, v3, v28, i8 \n\t" // M0 N16 - N23 + "vmadot v22, v3, v30, i8 \n\t" // M0 N24 - N31 + + "vmadot v16, v4, v25, i8 \n\t" + "vmadot v18, v4, v27, i8 \n\t" + "vmadot v20, v4, v29, i8 \n\t" + "vmadot v22, v4, v31, i8 \n\t" + + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i8_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = k_blks * sizeof(ggml_fp16_t) + k_blks * blk_len; + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16+8 \n\t" // Ascl+Asum; FP32*4+i16*4 + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + "vl4r.v v8, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v2, v0, v0, 1 \n\t" + + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v4, v8, v9, 1 \n\t" + "vupack.vv v6, v10, v11, 1 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot v16, v2, v24, i8 \n\t" + "vmadot v18, v2, v26, i8 \n\t" + "vmadot v20, v2, v4, i8 \n\t" + "vmadot v22, v2, v6, i8 \n\t" + "vmadot v16, v3, v25, i8 \n\t" + "vmadot v18, v3, v27, i8 \n\t" + "vmadot v20, v3, v5, i8 \n\t" + "vmadot v22, v3, v7, i8 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz %[BK], BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +} + +void moe_m2_gemm_kernel_i8i4_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, + ldc); +#else + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.x v10, t1 \n\t" + "vmv.v.x v11, t2 \n\t" + + "vpack.vv v18, v10, v11, 1 \n\t" + "vsll.vi v18, v18, 3 \n\t" // mul 8 + "vfcvt.f.x.v v18, v18 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vor.vv v19, v18, v18 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { +# if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +# else + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v19, v19 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v21, v21, v21 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + // asum*zp + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v2, v10, t1 \n\t" + "vwmul.vx v4, v10, t2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "vfcvt.f.x.v v2, v2 \n\t" + "vfcvt.f.x.v v4, v4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwcvt.f.f.v v16, v20 \n\t" + + "vfwcvt.f.f.v v18, v14 \n\t" + + // +asum*zp + "vsetvli t0, x0, e32, m1 \n\t" + "vfadd.vv v16, v16, v2 \n\t" + "vfadd.vv v17, v17, v4 \n\t" + "vfmul.vv v16, v16, v18 \n\t" + "vfmul.vv v17, v17, v18 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +# endif + } +#endif +} + +void moe_m2_gemm_kernel_i8i5_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + GGML_UNUSED(blk_len); + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/Bh/Bs and advance to the next q5_0 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" // v0 = Bh N32K32 1-bit packed + "addi %[B], %[B], 128 \n\t" + "vl4r.v v8, (%[B]) \n\t" // v8..v11 = Bs N32K32 i4 + "addi %[B], %[B], 512 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + "slli t3, t3, 4 \n\t" + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + "vadd.vx v25, v25, t3 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v1 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + "addi t5, %[B], 64 \n\t" // t5 = zp (32B) + "addi t6, %[B], 96 \n\t" // t6 = qh (128B) + "addi s1, %[B], 224 \n\t" // s1 = qs (512B) + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/zp/Bh/Bs and advance to the next q5_1 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 736 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t6) \n\t" // v0 = Bh N32K32 1-bit packed + "addi t6, t6, 736 \n\t" + "vl4r.v v8, (s1) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s1, s1, 736 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // q5_1 uses explicit zp, so keep a_sum unshifted here. + // [4*32]*2 + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (t5) \n\t" // v4 = Bzp N32 uint8 + "addi t5, t5, 736 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v4, x0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + "vwmul.vx v31, v28, t3 \n\t" + + // b_scale fp16 -> fp32 + "vfwcvt.f.f.v v28, v1 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + "vadd.vv v25, v25, v31 \n\t" + + // a_scale * b_scale; + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } +} + +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i2k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i2k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i2k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, + ldc); +#else + gemm_kernel_i8i2k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i3k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i3k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + moe_m2_gemm_kernel_i8i4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i8_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i8_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + //moe_m2_gemm_kernel_i8mxfp4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i5_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i5_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i5_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + moe_m2_gemm_kernel_i8i5_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 2; +} + +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.cpp b/ggml/src/ggml-cpu/spacemit/ime_env.cpp new file mode 100644 index 00000000000..a13ba391da2 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.cpp @@ -0,0 +1,320 @@ +#include "ime_env.h" + +#include "ggml-impl.h" +#include "spine_mem_pool.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +bool spine_core_info::get_spine_core_info(std::vector & result) { + static std::unordered_map spine_march_mapping_ = { + {0x8000000058000001, spine_core_arch_id::core_arch_x60 }, + { 0x8000000041000001, spine_core_arch_id::core_arch_a60 }, + { 0x8000000058000002, spine_core_arch_id::core_arch_x100}, + { 0x8000000041000002, spine_core_arch_id::core_arch_a100}, + }; + + result.clear(); + std::ifstream file("/proc/cpuinfo"); + std::string line; + + std::vector> cpu_info_list; + + uint64_t current_processor = spine_invalid_core_id; + uint64_t current_marchid = 0; + bool has_processor = false; + bool has_marchid = false; + + if (!file.is_open()) { + return false; + } + + while (std::getline(file, line)) { + if (line.substr(0, 9) == "processor") { + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + current_processor = std::stoi(line.substr(colon_pos + 1)); + has_processor = true; + } + + has_marchid = false; + } else if (line.substr(0, 7) == "marchid") { + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + std::string marchid_str = line.substr(colon_pos + 1); + marchid_str.erase(std::remove_if(marchid_str.begin(), marchid_str.end(), isspace), marchid_str.end()); + current_marchid = std::stoull(marchid_str, nullptr, 16); + has_marchid = true; + } + } + } + + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + if (has_processor && has_marchid) { + for (auto & cpu_info : cpu_info_list) { + if (cpu_info[0] != spine_invalid_core_id && + spine_march_mapping_.find(cpu_info[1]) != spine_march_mapping_.end()) { + auto core_info = spine_core_info(); + core_info.core_id = cpu_info[0]; + core_info.arch_id = spine_core_arch_id(spine_march_mapping_[cpu_info[1]]); + + result.push_back(core_info); + } + } + } + + return has_processor && has_marchid; +} + +namespace { +uint16_t hex_string_to_u16(const std::string & hex_str) { + try { + size_t pos = 0; + if (hex_str.substr(0, 2) == "0x" || hex_str.substr(0, 2) == "0X") { + pos = 2; + } + unsigned long result = std::stoul(hex_str.substr(pos), nullptr, 16); + if (result > std::numeric_limits::max()) { + throw std::out_of_range("Converted value is out of range for uint16_t"); + } + return static_cast(result); + } catch (const std::invalid_argument & e) { + throw std::invalid_argument("Invalid hexadecimal string"); + } catch (const std::out_of_range & e) { + throw; + } +} + +const char * spine_mem_pool_backend_to_string(spine_mem_pool_backend backend) { + switch (backend) { + case spine_mem_pool_backend::none: + return "NONE"; + case spine_mem_pool_backend::posix_memalign: + return "POSIX"; + case spine_mem_pool_backend::transparent_hugepage: + return "HPAGE"; + case spine_mem_pool_backend::hugetlb_1g: + return "HPAGE1GB"; + } + + return "unknown"; +} + +spine_mem_pool_backend parse_mem_backend(const char * mem_backend_str) { + if (mem_backend_str == nullptr || mem_backend_str[0] == '\0') { + return spine_mem_pool_backend::transparent_hugepage; + } + + std::string value(mem_backend_str); + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char ch) { return static_cast(std::tolower(ch)); }); + + if (value == "none") { + return spine_mem_pool_backend::none; + } + + if (value == "posix") { + return spine_mem_pool_backend::posix_memalign; + } + + if (value == "hpage") { + return spine_mem_pool_backend::transparent_hugepage; + } + + if (value == "hpage1gb") { + return spine_mem_pool_backend::hugetlb_1g; + } + + throw std::runtime_error("invalid SPACEMIT_MEM_BACKEND: " + value + ", expected NONE, POSIX, HPAGE or HPAGE1GB"); +} +} // namespace + +spine_env_info::spine_env_info() { + num_cores = static_cast(std::thread::hardware_concurrency()); + spine_core_info::get_spine_core_info(core_info_list); + + // special for x60 K1 + if (core_info_list.size() == 8 && core_info_list[0].arch_id == spine_core_arch_id::core_arch_x60) { + for (int i = 0; i < 4; i++) { + core_info_list[i].arch_id = spine_core_arch_id::core_arch_a60; + } + } + + // special for qemu + if (core_info_list.size() == 0) { + char * spine_core_arch_str = getenv("SPACEMIT_CORE_ARCH"); + if (spine_core_arch_str != nullptr) { + auto arch_id = hex_string_to_u16(spine_core_arch_str); + for (int i = 0; i < num_cores; i++) { + auto core_info = spine_core_info(); + core_info.core_id = i; + core_info.arch_id = spine_core_arch_id{ arch_id }; + core_info_list.push_back(core_info); + } + } + } + + if (core_info_list.size() == 0) { + throw std::runtime_error( + "Failed to get SPACEMIT_CORE_ARCH from environment or failed to parse it from /proc/cpuinfo"); + } + + char * spine_perfer_core_arch_str = getenv("SPACEMIT_PERFER_CORE_ARCH"); + if (spine_perfer_core_arch_str != nullptr && spine_perfer_core_arch_str != "") { + perfer_core_arch_id = spine_core_arch_id{ hex_string_to_u16(spine_perfer_core_arch_str) }; + } + + char * spine_perfer_core_id_str = getenv("SPACEMIT_PERFER_CORE_ID"); + std::vector perfer_core_id_vec; + if (spine_perfer_core_id_str != nullptr && spine_perfer_core_id_str != "") { + std::string perfer_core_id_str(spine_perfer_core_id_str); + size_t start = 0; + size_t end = 0; + while ((end = perfer_core_id_str.find(',', start)) != std::string::npos) { + std::string core_id_substr = perfer_core_id_str.substr(start, end - start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + start = end + 1; + } + std::string core_id_substr = perfer_core_id_str.substr(start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + } + + perfer_core_ids.reserve(num_cores); + if (perfer_core_arch_id == spine_core_arch_id::core_arch_none) { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + num_perfer_cores++; + perfer_core_arch_id = core_arch_id; + cpu_mask |= (1ULL << core_info.core_id); + perfer_core_ids.push_back(core_info.core_id); + } + } + } else { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + num_perfer_cores++; + cpu_mask |= (1ULL << core_info.core_id); + + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + perfer_core_ids.push_back(core_info.core_id); + } + } + } + if (num_perfer_cores == 0) { + GGML_ABORT("can not find core with arch id %x for SPACEMIT_PERFER_CORE_ARCH in core info list\n", + (uint16_t) perfer_core_arch_id); + } + } + + if (perfer_core_id_vec.size() > 0) { + perfer_core_ids.clear(); + cpu_mask = 0; + num_perfer_cores = 0; + for (int core_id : perfer_core_id_vec) { + if (core_id < 0 || core_id >= num_cores) { + GGML_ABORT("invalid core id in SPACEMIT_PERFER_CORE_ID: %d, should be between 0 and %d\n", core_id, + num_cores - 1); + } + auto core_info = core_info_list[core_id]; + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + cpu_mask |= (1ULL << core_id); + perfer_core_ids.push_back(core_id); + } else { + GGML_ABORT( + "core id %d in SPACEMIT_PERFER_CORE_ID has arch id %x which does not match " + "SPACEMIT_PERFER_CORE_ARCH %x\n", + core_id, (uint16_t) core_arch_id, (uint16_t) perfer_core_arch_id); + } + } + std::string perfer_core_id_vec_str; + for (int core_id : perfer_core_id_vec) { + perfer_core_id_vec_str += std::to_string(core_id) + ","; + } + perfer_core_id_vec_str.pop_back(); + GGML_LOG_DEBUG("SPACEMIT_PERFER_CORE_ID is set, perferred core ids: %s\n", perfer_core_id_vec_str.c_str()); + num_perfer_cores = static_cast(perfer_core_id_vec.size()); + } + + use_ime1 = perfer_core_arch_id == spine_core_arch_id::core_arch_a60 || + perfer_core_arch_id == spine_core_arch_id::core_arch_x100; + + use_ime2 = perfer_core_arch_id == spine_core_arch_id::core_arch_a100; + + mem_backend = parse_mem_backend(getenv("SPACEMIT_MEM_BACKEND")); + char * spine_disable_tcm_str = getenv("SPACEMIT_DISABLE_TCM"); + auto user_disable_tcm = spine_disable_tcm_str != nullptr && strcmp(spine_disable_tcm_str, "0") != 0; + + if (!user_disable_tcm) { + spine_mem_pool_tcm_info tcm_info; + if (spine_mem_pool_tcm_init(&tcm_info)) { + use_tcm = tcm_info.available; + tcm_blk_size = tcm_info.blk_size; + GGML_LOG_DEBUG("CPU_RISCV64_SPACEMIT: tcm is available, blk_size: %zu, blk_num: %zu, is_fake_tcm: %d\n", + tcm_info.blk_size, tcm_info.blk_num, tcm_info.is_fake_tcm); + + for (auto & core_info : core_info_list) { + auto core_arch_head = (uint16_t) (core_info.arch_id) >> 12; + if (core_arch_head != 0xA) { + aicpu_id_offset++; + } else { + break; + } + } + } + } + + GGML_LOG_DEBUG( + "CPU_RISCV64_SPACEMIT: num_cores: %d, num_perfer_cores: %d, perfer_core_arch_id: %x, exclude_main_thread: %d, " + "use_ime1: %d, use_ime2: %d, mem_backend: %s, cpu_mask: %lx, aicpu_id_offset: %d\n", + num_cores, num_perfer_cores, (uint16_t) perfer_core_arch_id, exclude_main_thread, use_ime1, use_ime2, + spine_mem_pool_backend_to_string(mem_backend), cpu_mask, aicpu_id_offset); + + const size_t init_barrier_size = sizeof(spine_barrier_t) * spine_init_barrier_count; + init_barrier = + static_cast(spine_mem_pool_shared_mem_alloc(init_barrier_size, alignof(spine_barrier_t))); + if (init_barrier != nullptr) { + init_barrier_is_shared_mem = true; + } else { + GGML_LOG_WARN("CPU_RISCV64_SPACEMIT: failed to allocate init_barrier from shared mem, falling back to heap\n", + __func__); + init_barrier = new spine_barrier_t[spine_init_barrier_count]; + } + + spine_barrier_init(init_barrier, spine_init_barrier_count, 2); +} + +spine_env_info::~spine_env_info() { + if (init_barrier_is_shared_mem) { + spine_mem_pool_shared_mem_free(init_barrier); + } else { + delete[] init_barrier; + } + + init_barrier = nullptr; + init_barrier_is_shared_mem = false; +} + +spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.h b/ggml/src/ggml-cpu/spacemit/ime_env.h new file mode 100644 index 00000000000..a6ca06d26a4 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.h @@ -0,0 +1,55 @@ +#pragma once + +#include "spine_barrier.h" +#include "spine_mem_pool.h" + +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +constexpr uint64_t spine_invalid_core_id = 0xFFFFFFFF; +constexpr size_t spine_init_barrier_count = 16; + +enum class spine_core_arch_id : uint16_t { + core_arch_none = 0, + core_arch_x60 = 0x503C, + core_arch_x100 = 0x5064, + core_arch_x200 = 0x50C8, + core_arch_a60 = 0xA03C, + core_arch_a100 = 0xA064, + core_arch_a200 = 0xA0C8, +}; + +struct spine_core_info { + uint64_t core_id{ spine_invalid_core_id }; + spine_core_arch_id arch_id{ spine_core_arch_id::core_arch_none }; + + static bool get_spine_core_info(std::vector & result); +}; + +struct spine_env_info { + std::vector core_info_list; + std::vector perfer_core_ids; + int aicpu_id_offset{ 0 }; + int num_cores{ 0 }; + int num_perfer_cores{ 0 }; + spine_core_arch_id perfer_core_arch_id{ spine_core_arch_id::core_arch_none }; + bool exclude_main_thread{ false }; + bool use_ime2{ false }; + bool use_ime1{ false }; + bool use_tcm{ false }; + spine_mem_pool_backend mem_backend{ spine_mem_pool_backend::transparent_hugepage }; + uint64_t tcm_blk_size{ 0 }; + uint64_t cpu_mask{ 0 }; + spine_barrier_t * init_barrier{ nullptr }; + bool init_barrier_is_shared_mem{ false }; + + spine_env_info(); + ~spine_env_info(); +}; + +extern spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h index 75706341505..0a1fafffb25 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,26 +1,189 @@ #pragma once +#include #include +#include + +namespace spacemit_kernels { + +#define BLOCK_QNK_LEN 256 + +template struct nrow_block_q2_k { + // [4bit scale + 4bit zp] * N * 16 + uint8_t scales[N * BLOCK_QNK_LEN / 16]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; + uint16_t zeros16[N]; +}; + +template struct nrow_block_q3_k { + // [8bit scale] * N * 16 + int8_t scales[N * 16]; + // [b0, b1, b2, b3, b4, b5, b6, b7] ... [b248, b249, b250, b251, b252, b253, b254, b255] + uint8_t hmask[N * BLOCK_QNK_LEN / 8]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; +}; + +template struct nrow_block_mxfp4 { + uint8_t e[N]; + uint8_t qh[4 * N]; + uint8_t qs[16 * N]; +}; + +template struct __attribute__((packed)) nrow_block_q5_1 { + uint16_t scales16[N]; + uint8_t zp[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_1<1>) == sizeof(uint8_t) + 22, "wrong nrow_block_q5_1 block size/padding"); + +template struct __attribute__((packed)) nrow_block_q5_0 { + uint16_t scales16[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_0<1>) == 22, "wrong nrow_block_q5_0 block size/padding"); + +using gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t *, const uint8_t *, const uint8_t *, float *, size_t, size_t, size_t, size_t)>; + +using moe_gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t **, const uint8_t *, const uint8_t *, float **, size_t, size_t, size_t, size_t)>; -namespace sqnbitgemm_spacemit_ime { namespace ime1 { -size_t gemm_kernel_i8i4(size_t blk_len, - const std::byte * quant_a_ptr, - const std::byte * quant_b_data, - const float * quant_b_scale, - const std::byte * quant_b_zp, - float * c_ptr, - size_t count_m, - size_t count_n, - size_t count_k, - size_t block_count_k, - size_t ldc, - const float * bias, - const size_t scale_stride); - -void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); - -void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime + +namespace ime2 { +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/repack.cpp b/ggml/src/ggml-cpu/spacemit/repack.cpp new file mode 100644 index 00000000000..3c879c4b7a0 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.cpp @@ -0,0 +1,1795 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "repack.h" + +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ime_kernels.h" + +#include +#include +#include +#include + +// clang-format off +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#else +#include +#endif + +#if !defined(__riscv_zfh) +#error "riscv zfh extension not enabled" +#endif + +#else +#error "riscv not enabled in this build" +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +// clang-format on + +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> size/padding"); + +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<4,32> size/padding"); +static_assert(sizeof(block_with_zp<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16 + 32 * sizeof(uint8_t), + "wrong block_with_zp<4,32> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +using block_q4_0x32 = block<4, 32>; +using block_q4_1x32 = block_with_zp<4, 32>; +using block_q8_0x32 = block<8, 32>; + +struct block_q4_0x32x256 { + block_q4_0x32 blocks[8]; // [f16 * 32 | i4 * 32 * 32] * 8 +}; + +struct block_q4_1x32x256 { + block_q4_0x32 blocks[8]; + uint8_t zps[32 * 8]; +}; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_q4_0x32 make_block_q4_0x32(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x32 out; + assert(QK4_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b17] ......... [b30 b31] + out.qs[i * QK4_0 / 2 + QK4_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q4_1x32 make_block_q4_1x32(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x32 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[i * QK4_1 / 2 + QK4_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q8_0x32 make_block_q8_0x32(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x32 out; + GGML_ASSERT(QK8_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * QK8_0, in[i].qs, QK8_0); + } + + return out; +} + +static int repack_q2_k_to_q2_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const block_q2_K * src = (const block_q2_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q2_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint8_t qs_aux[256] = { 0 }; + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q2_K * src_block = &src[(b + i) * nblocks + x]; + + // scale for [16, N] + for (int j = 0; j < 16; j++) { + auto zp_aux = (dst->scales[j * nrows_interleaved + i]) & 0xF0; + + dst->scales[j * nrows_interleaved + i] = (src_block->scales[j] & 0x0F) | zp_aux; + } + + // zp for [N, 16] + for (int j = 0; j < 16; j++) { + auto scale_aux = (dst->scales[16 * i + j]) & 0x0F; + + dst->scales[16 * i + j] = (src_block->scales[j] & 0xF0) | scale_aux; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + dst->scales16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + dst->zeros16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + dst++; + } + } + + return 0; +} + +static int repack_q3_k_to_q3_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q3_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * src = (const block_q3_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q3_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q3_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint32_t b_scale_aux[4] = { 0 }; + uint8_t qs_aux[256] = { 0 }; + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q3_K * src_block = &src[(b + i) * nblocks + x]; + + uint32_t * auxs = b_scale_aux; + int8_t * scale = (int8_t *) auxs; + memcpy(auxs, src_block->scales, 12); + + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + for (int j = 0; j < 16; j++) { + dst->scales[j * nrows_interleaved + i] = scale[j] - 32; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + //memcpy(dst->hmask + i * 32, src_block->hmask, 32); + + // from nrows_interleaved * [32byte] + // to 16 * [nrows_interleaved * uint16_t] + uint16_t * dst_mask = ((uint16_t *) dst->hmask) + i; + for (int j = 0; j < 16; j++, dst_mask += nrows_interleaved) { + uint8_t b_shift = j / 2; + uint8_t * b_mask_col = (uint8_t *) (src_block->hmask + (j % 2) * 16); + // b0 - b15 + uint16_t msk_out_0 = 0; + + for (int k = 0; k < 8; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + for (int k = 8; k < 16; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + + dst_mask[0] = msk_out_0; + } + + dst->scales16[i] = src_block->d; + } + + dst++; + } + } + + return 0; +} + +static int repack_q4_0_to_q4_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32x256 * dst = (block_q4_0x32x256 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + dst->blocks[j] = make_block_q4_0x32(dst_tmp, interleave_block); + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_1_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32x256 * dst = (block_q4_1x32x256 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + + block_q4_0x32 * dst_block = &dst->blocks[j]; + uint8_t * dst_zp = dst->zps + j * nrows_interleaved; + + for (int i = 0; i < nrows_interleaved; i++) { + float d = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + + dst_block->d[i] = GGML_FP32_TO_FP16(d); + dst_zp[i] = static_cast(mid); + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + k] = + (dst_tmp[i].qs[k * 2] & 0x0F) | ((dst_tmp[i].qs[k * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + QK4_1 / 4 + k] = + ((dst_tmp[i].qs[k * 2] & 0xF0) >> 4) | (dst_tmp[i].qs[k * 2 + 1] & 0xF0); + } + } + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_0_to_q4_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack. +static int repack_q4_0_to_q4_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_0 / 2; // 16 + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + // d is at offset 0 of each block_q4_0, stride between rows = row_stride + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + // For each row i: + // src qs[16]: [b0|b16] [b1|b17] ... [b15|b31] (lo nibble = b_j, hi nibble = b_{j+16}) + // dst qs low 8B: (qs[2j] & 0x0F) | ((qs[2j+1] & 0x0F) << 4) for j=0..7 + // dst qs high 8B: ((qs[2j] >> 4)) | (qs[2j+1] & 0xF0) for j=0..7 + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); // qs[0], qs[2], ..., qs[14] + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); // qs[1], qs[3], ..., qs[15] + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_1_to_q4_1_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack + zp computation. +static int repack_q4_1_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_1 / 2; // 16 + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_1); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_1 * col_src = src + x; + + // --- 1) Gather d and m, compute zp = clamp(nearbyint(-m/d), 0, 15) --- + // block_q4_1 layout: [d(f16), m(f16), qs[16]] + // d is at byte offset 0, m is at byte offset 2 from each block start + { + const uint8_t * dm_base = (const uint8_t *) &col_src->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + ggml_half * d_dst = dst->d; + uint8_t * zp_dst = dst->zp; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + + // stride load d (f16) from each row + vuint16m1_t vd_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd_raw, vl); + + // stride load m (f16) from each row (offset +2 bytes from d) + vuint16m1_t vm_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + 2 + offset * row_stride), row_stride, vl); + + // convert to f32 for zp computation: zp = nearbyint(-m / d) + vfloat16m1_t vd_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vd_raw); + vfloat16m1_t vm_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vm_raw); + + // -m / d in f16 directly (SpaceMIT X60 supports f16 arithmetic) + vfloat16m1_t v_neg_m = __riscv_vfneg_v_f16m1(vm_f16, vl); + vfloat16m1_t v_ratio = __riscv_vfdiv_vv_f16m1(v_neg_m, vd_f16, vl); + + // Convert to f32 for nearbyint, then clamp + vfloat32m2_t v_ratio_f32 = __riscv_vfwcvt_f_f_v_f32m2(v_ratio, vl); + + // Use integer rounding: convert f32 -> int (rounds to nearest) + vint32m2_t v_zp_i32 = __riscv_vfcvt_x_f_v_i32m2(v_ratio_f32, vl); + + // clamp to [0, 15] + v_zp_i32 = __riscv_vmax_vx_i32m2(v_zp_i32, 0, vl); + v_zp_i32 = __riscv_vmin_vx_i32m2(v_zp_i32, 15, vl); + + // narrow i32 -> u8 + vint16m1_t v_zp_i16 = __riscv_vncvt_x_x_w_i16m1(v_zp_i32, vl); + vint8mf2_t v_zp_i8 = __riscv_vncvt_x_x_w_i8mf2(v_zp_i16, vl); + vuint8mf2_t v_zp_u8 = __riscv_vreinterpret_v_i8mf2_u8mf2(v_zp_i8); + __riscv_vse8_v_u8mf2(zp_dst + offset, v_zp_u8, vl); + + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_k_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q6_k_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q8_0 dst_tmp[32]; + int8_t aux8[QK4_1]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrow_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + int i = 0; + for (; i < nrow_real; i++) { + const uint8_t * q4 = src[x + i * nblocks].ql; + const uint8_t * qh = src[x + i * nblocks].qh; + const int8_t * scales = src[x + i * nblocks].scales; + float d = GGML_FP16_TO_FP32(src[x + i * nblocks].d); + + q4 += 64 * (bi / 4); + qh += 32 * (bi / 4); + int8_t * GGML_RESTRICT a = aux8; + + int8_t bi_idx = bi % 4; + + if (bi_idx == 0) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + } + } else if (bi_idx == 1) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + } + } else if (bi_idx == 2) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + } + } else if (bi_idx == 3) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + } + a = aux8; + + float a_max_abs = 0.0f; + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + for (int l = 0; l < 16; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_0)); + } + + for (int l = 16; l < 32; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_1)); + } + + float reflect_scale = a_max_abs / ((1 << 7) - 1); + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + for (int l = 0; l < 16; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_0), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + for (int l = 16; l < 32; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_1), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + dst_tmp[i].d = GGML_FP32_TO_FP16(reflect_scale); + + memcpy(dst_tmp[i].qs, a, 32 * sizeof(int8_t)); + } + + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q6_k_to_q8_0_32_bl +// Vectorizes the Q6_K dequant -> requant pipeline using RVV intrinsics. +// For each sub-block (bi), dequant 32 Q6_K values to int6 -> apply two sub-block scales -> +// find max abs -> compute reflect_scale -> requant to int8 -> gather d with stride load. +static int repack_q6_k_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q6_K); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + // --- 1) Gather 32 d values with stride load --- + // We need to compute reflect_scale per row first, so gather d later. + // Process each row: dequant Q6_K sub-block -> requant to Q8_0 + for (int i = 0; i < nrows_interleaved; i++) { + const block_q6_K * src_blk = &src[x + i * nblocks]; + const uint8_t * q4 = src_blk->ql + 64 * (bi / 4); + const uint8_t * qh = src_blk->qh + 32 * (bi / 4); + const int8_t * scales = src_blk->scales; + float d = GGML_FP16_TO_FP32(src_blk->d); + + int8_t bi_idx = bi % 4; + + // --- Dequant 32 Q6_K values to int6 (range [-32, 31]) using RVV --- + // vl = 32 for e8m2 (VLEN=256) or loop for smaller VLEN + const size_t vl16 = __riscv_vsetvl_e8m1(16); + + vint8m1_t va_lo, va_hi; // 16 elements each + + if (bi_idx == 0) { + // a[l] = (q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_lo, 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_hi, 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 1) { + // a[l] = (q4[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 2, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 2, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 2) { + // a[l] = (q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 4, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 4, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else { // bi_idx == 3 + // a[l] = (q4[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 6, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 6, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } + + // --- Widen to i16 for scaled abs computation --- + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + + // Widen i8 -> i16 -> f32 for abs*scale computation + vint16m2_t va_lo_w = __riscv_vsext_vf2_i16m2(va_lo, vl16); + vint16m2_t va_hi_w = __riscv_vsext_vf2_i16m2(va_hi, vl16); + + // Compute |a[l] * scale_0| for lo half, |a[l] * scale_1| for hi half + vfloat32m4_t vf_lo = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_lo_w, vl16), vl16); + vfloat32m4_t vf_hi = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_hi_w, vl16), vl16); + + vfloat32m4_t vabs_lo = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_lo, scale_0, vl16), vl16); + vfloat32m4_t vabs_hi = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_hi, scale_1, vl16), vl16); + + // Find max abs across both halves + vfloat32m4_t vabs_max = __riscv_vfmax_vv_f32m4(vabs_lo, vabs_hi, vl16); + + // Reduce to scalar max + vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m1_t vmax_red = __riscv_vfredmax_vs_f32m4_f32m1(vabs_max, vzero, vl16); + float a_max_abs = __riscv_vfmv_f_s_f32m1_f32(vmax_red); + + float reflect_scale = a_max_abs / 127.0f; + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + // --- Requant: a[l] = clamp(nearbyint(a[l] * reflect_scale_x), -128, 127) --- + vfloat32m4_t vscaled_lo = __riscv_vfmul_vf_f32m4(vf_lo, reflect_scale_0, vl16); + vfloat32m4_t vscaled_hi = __riscv_vfmul_vf_f32m4(vf_hi, reflect_scale_1, vl16); + + // fcvt.x rounds to nearest (using current rounding mode) + vint32m4_t vi_lo = __riscv_vfcvt_x_f_v_i32m4(vscaled_lo, vl16); + vint32m4_t vi_hi = __riscv_vfcvt_x_f_v_i32m4(vscaled_hi, vl16); + + // Clamp to [-128, 127] + vi_lo = __riscv_vmax_vx_i32m4(vi_lo, -128, vl16); + vi_lo = __riscv_vmin_vx_i32m4(vi_lo, 127, vl16); + vi_hi = __riscv_vmax_vx_i32m4(vi_hi, -128, vl16); + vi_hi = __riscv_vmin_vx_i32m4(vi_hi, 127, vl16); + + // Narrow i32 -> i16 -> i8 + vint16m2_t vi16_lo = __riscv_vncvt_x_x_w_i16m2(vi_lo, vl16); + vint16m2_t vi16_hi = __riscv_vncvt_x_x_w_i16m2(vi_hi, vl16); + vint8m1_t vi8_lo = __riscv_vncvt_x_x_w_i8m1(vi16_lo, vl16); + vint8m1_t vi8_hi = __riscv_vncvt_x_x_w_i8m1(vi16_hi, vl16); + + // Store d and qs directly into dst block + dst->d[i] = GGML_FP32_TO_FP16(reflect_scale); + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + __riscv_vse8_v_i8m1(dq, vi8_lo, vl16); + __riscv_vse8_v_i8m1(dq + 16, vi8_hi, vl16); + } + dst++; + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q8_0_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[0] % QK8_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrows_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + int i = 0; + for (; i < nrows_real; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q8_0_to_q8_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes scale gather + qs copy. +static int repack_q8_0_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK8_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q8_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q8_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Copy qs for each of the 32 rows (32 bytes per row) --- + { + for (int i = 0; i < 32; i++) { + const int8_t * sq = col_src[i * nblocks].qs; + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + + size_t len = QK8_0; + size_t idx = 0; + while (len > 0) { + size_t vl = __riscv_vsetvl_e8m2(len); + vint8m2_t vs = __riscv_vle8_v_i8m2(sq + idx, vl); + __riscv_vse8_v_i8m2(dq + idx, vs, vl); + idx += vl; + len -= vl; + } + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static void convert_mxfp4_to_5bit(const block_mxfp4 & src, spacemit_kernels::nrow_block_mxfp4<1> & dst) { + dst.e[0] = src.e; + + // Decode all 32 mxfp4 values to signed integers via kvalues_mxfp4 + int8_t vals[32]; + for (int j = 0; j < QK_MXFP4 / 2; j++) { + vals[j] = kvalues_mxfp4[src.qs[j] & 0xF]; + vals[j + QK_MXFP4 / 2] = kvalues_mxfp4[src.qs[j] >> 4]; + } + + // vals [b0, b1, b2, b3, ..., b30, b31] + // Pack abs into qs with reorder: [b0,b1]..[b14,b15]..[b30,b31] + for (int j = 0; j < QK_MXFP4 / 2; j++) { + uint8_t lo0 = static_cast(std::abs(vals[j * 2])); + uint8_t lo1 = static_cast(std::abs(vals[j * 2 + 1])); + dst.qs[j] = (lo0 & 0x0F) | ((lo1 & 0x0F) << 4); + } + + // Pack sign bits into qh[4] (32 bits total, 1 bit per weight) + // reorder: [0,1,2,...,15,16,17,...,31] after the qs reorder above + uint32_t sign_bits = 0; + for (int j = 0; j < 32; j++) { + if (vals[j] < 0) { + sign_bits |= (1u << j); + } + } + memcpy(dst.qh, &sign_bits, 4); +} + +static spacemit_kernels::nrow_block_mxfp4<32> make_block_mxfp4x32(spacemit_kernels::nrow_block_mxfp4<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_mxfp4<32> out; + GGML_ASSERT(QK_MXFP4 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.e[i] = in[i].e[0]; + } + + // qs: copy per-row 16 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * 16, in[i].qs, 16); + } + + // qh: copy per-row 4 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qh + i * 4, in[i].qh, 4); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_mxfp4<32> * dst = (spacemit_kernels::nrow_block_mxfp4<32> *) t->data; + const block_mxfp4 * src = (const block_mxfp4 *) data; + spacemit_kernels::nrow_block_mxfp4<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_MXFP4 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + convert_mxfp4_to_5bit(src[x + i * nblocks], dst_tmp[i]); + } + *dst++ = make_block_mxfp4x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static spacemit_kernels::nrow_block_q5_1<32> make_block_q5_1x32(spacemit_kernels::nrow_block_q5_1<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_1<32> out; + GGML_ASSERT(QK5_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + out.zp[i] = in[i].zp[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + QK5_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static spacemit_kernels::nrow_block_q5_0<32> make_block_q5_0x32(spacemit_kernels::nrow_block_q5_0<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_0<32> out; + GGML_ASSERT(QK5_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + QK5_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static int repack_q5_0_to_q5_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_0<32> * dst = (spacemit_kernels::nrow_block_q5_0<32> *) t->data; + const block_q5_0 * src = (const block_q5_0 *) data; + spacemit_kernels::nrow_block_q5_0<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_0 & s = src[x + i * nblocks]; + + dst_tmp[i].scales16[0] = s.d; + memcpy(dst_tmp[i].qs, s.qs, sizeof(dst_tmp[i].qs)); + memcpy(dst_tmp[i].qh, s.qh, sizeof(dst_tmp[i].qh)); + } + *dst++ = make_block_q5_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_1_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_1 * src = (const block_q5_1 *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_1 & s = src[x + i * nblocks]; + + float d = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + + if (d == 0.0f) { + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(std::fabs(m)); + dst_tmp[i].zp[0] = m < 0.0f ? 1 : 0; + memset(dst_tmp[i].qh, 0, sizeof(dst_tmp[i].qh)); + memset(dst_tmp[i].qs, m > 0.0f ? 0x11 : 0x00, sizeof(dst_tmp[i].qs)); + continue; + } + + float mid = std::nearbyintf(-m / d); + mid = std::min(31.0f, std::max(0.0f, mid)); + + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d); + dst_tmp[i].zp[0] = static_cast(mid); + + // qs: copy low 4 bits directly (same nibble packing) + memcpy(dst_tmp[i].qs, s.qs, QK5_1 / 2); + + // qh: copy 5th bit directly + memcpy(dst_tmp[i].qh, s.qh, 4); + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_k_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK5_1 == 8); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + + float d1 = d * sc; + float m1 = min * m; + + float mid = std::nearbyintf(m1 / d1); + mid = std::min(31.0f, std::max(0.0f, mid)); + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d1); + dst_tmp[i].zp[0] = static_cast(mid); + + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK5_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + + // Extract the 5th bit (qh) for this sub-block + // block_q5_K.qh[32]: for sub-block j, the 5th bit is at bit position j in qh[l] + // qs was reordered: dst_qs maps to src weights [0,16,1,17,...,15,31] + // So qh must follow the same reorder to stay aligned with qs + // dst qh[4] = 32 bits for 32 weights in the reordered layout: + // byte 0: weights 0..7 (from src_qh[0..7]) + // byte 1: weights 8..15 (from src_qh[8..15]) + // byte 2: weights 16..23 (from src_qh[16..23]) + // byte 3: weights 24..31 (from src_qh[24..31]) + const uint8_t * src_qh = src[x + i * nblocks].qh; + for (int bi = 0; bi < 4; bi++) { + uint8_t qh_byte = 0; + for (int k = 0; k < 8; k++) { + int src_idx = bi * 8 + k; + qh_byte |= ((src_qh[src_idx] >> j) & 1) << k; + } + dst_tmp[i].qh[bi] = qh_byte; + } + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +namespace ggml::cpu::riscv64_spacemit { + +template int repack(ggml_tensor *, const void *, size_t); + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_k_to_q2_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_k_to_q3_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_0_to_q4_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_0_to_q4_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_0_256_32_bl_ref(t, 32, data, data_size); +#else + //return repack_q4_0_to_q4_0_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_1_to_q4_1_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_1_256_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q6_k_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q6_k_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q8_0_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q8_0_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_0_to_q5_0_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_1_to_q5_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_k_to_q5_1_32_bl(t, 32, data, data_size); +} + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/repack.h b/ggml/src/ggml-cpu/spacemit/repack.h new file mode 100644 index 00000000000..950cbde7593 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml-common.h" +#include "ggml.h" + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(ggml_tensor * t, const void * data, size_t data_size); + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp new file mode 100644 index 00000000000..d2f89743622 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp @@ -0,0 +1,3178 @@ +#include "rvv_kernels.h" + +#include "common.h" +#include "ggml.h" +#include "ops.h" +#include "string.h" + +#include +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels::rvv { + +namespace { + +auto align_up(size_t value, size_t alignment) { + return (value + alignment - 1) / alignment * alignment; +} + +static inline bool flash_attn_ext_supported_d_vlen1024_vf16(int64_t d) { + return d > 0 && d <= 128; +} + +static inline bool flash_attn_ext_supported_shape_vlen1024_vf16(int64_t DK, int64_t DV) { + return flash_attn_ext_supported_d_vlen1024_vf16(DK) && flash_attn_ext_supported_d_vlen1024_vf16(DV); +} + +static inline float reduce_sum_f32m4_vlen1024(vfloat32m4_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m4_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +static inline float reduce_sum_f32m2_vlen1024(vfloat32m2_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m2_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +// Adapted from ggml_v_expf_m2 in vec.h. This is accurate enough for softmax. +static inline vfloat32m2_t rvv_expf_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = + __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), u, vl), + u, vl); + + if (!__riscv_vcpop_m_b16(c, vl)) { + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + } + + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = + __riscv_vmerge_vvm_f32m2(__riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), c, vl); + return __riscv_vmerge_vvm_f32m2(r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), vl); +} + +static inline vfloat32m2_t rvv_tanh_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t abs_x = __riscv_vfabs_v_f32m2(x, vl); + const vfloat32m2_t neg_2_abs = __riscv_vfmul_vf_f32m2(abs_x, -2.0f, vl); + const vfloat32m2_t exp_term = rvv_expf_approx_f32m2(neg_2_abs, vl); + const vfloat32m2_t numerator = __riscv_vfsub_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t denominator = __riscv_vfadd_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t tanh_abs = __riscv_vfneg_v_f32m2(__riscv_vfdiv_vv_f32m2(numerator, denominator, vl), vl); + const vbool16_t neg_mask = __riscv_vmflt_vf_f32m2_b16(x, 0.0f, vl); + const vfloat32m2_t tanh_neg = __riscv_vfneg_v_f32m2(tanh_abs, vl); + return __riscv_vmerge_vvm_f32m2(tanh_abs, tanh_neg, neg_mask, vl); +} + +static void rvv_softcap_tanh_inplace_f32(float * dst, int64_t dst_stride, int64_t tile_rows, int64_t n, float softcap) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride) { + float * dst_row = dst; + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m2(remaining); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst_row, vl); + v = rvv_tanh_approx_f32m2(v, vl); + v = __riscv_vfmul_vf_f32m2(v, softcap, vl); + __riscv_vse32_v_f32m2(dst_row, v, vl); + dst_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_softmax_exp_inplace_f32(float * dst, int64_t n, float max_value) { + float row_sum = 0.0f; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst, vl); + v = __riscv_vfsub_vf_f32m2(v, max_value, vl); + v = rvv_expf_approx_f32m2(v, vl); + __riscv_vse32_v_f32m2(dst, v, vl); + row_sum += reduce_sum_f32m2_vlen1024(v, vl); + dst += vl; + n -= vl; + } + return row_sum; +} + +static inline float rvv_add_max_inplace_f32(float * dst, const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline float rvv_softcap_add_max_inplace_f32(float * dst, const float * src, int64_t n, float softcap) { + if (softcap == 0.0f) { + return rvv_add_max_inplace_f32(dst, src, n); + } + + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t vdst = __riscv_vle32_v_f32m2(dst, vl); + vfloat32m2_t vsrc = __riscv_vle32_v_f32m2(src, vl); + vdst = rvv_tanh_approx_f32m2(vdst, vl); + vdst = __riscv_vfmul_vf_f32m2(vdst, softcap, vl); + vdst = __riscv_vfadd_vv_f32m2(vdst, vsrc, vl); + __riscv_vse32_v_f32m2(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m2_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline void rvv_zero_f32(float * dst, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t z = __riscv_vfmv_v_f_f32m4(0.0f, vl); + __riscv_vse32_v_f32m4(dst, z, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_scale_f32(float * dst, float scale, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t v = __riscv_vle32_v_f32m4(dst, vl); + v = __riscv_vfmul_vf_f32m4(v, scale, vl); + __riscv_vse32_v_f32m4(dst, v, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_add_inplace_f32(float * dst, + int64_t dst_stride, + const float * src, + int64_t src_stride, + int64_t tile_rows, + int64_t n) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride, src += src_stride) { + int64_t remaining = n; + float * dst_row = dst; + const float * src_row = src; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst_row, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src_row, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst_row, vdst, vl); + dst_row += vl; + src_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_max_f32(const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t v = __riscv_vle32_v_f32m4(src, vl); + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(v, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + src += vl; + n -= vl; + } + return max_val; +} + +static void rvv_pack_f32_as_scaled_f16(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + _Float16 * dst_row_ptr = (_Float16 *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + const vfloat16m2_t v16 = __riscv_vfncvt_f_f_w_f16m2(v32, vl); + __riscv_vse16_v_f16m2(dst_row_ptr, v16, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f16_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const _Float16 * row_ptr = (const _Float16 *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e16m2(remaining); + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(row_ptr, vl); + vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f32_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float * scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale[tq], vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static inline void rvv_transposed_s32_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e32, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vle32.v v1, (s2) \n\t" + "sh2add s2, t0, s2 \n\t" + "vle32.v v2, (s3) \n\t" + "sh2add s3, t0, s3 \n\t" + "vle32.v v3, (s4) \n\t" + "sh2add s4, t0, s4 \n\t" + "vle32.v v4, (s5) \n\t" + "sh2add s5, t0, s5 \n\t" + "vle32.v v5, (s6) \n\t" + "sh2add s6, t0, s6 \n\t" + "vle32.v v6, (s7) \n\t" + "sh2add s7, t0, s7 \n\t" + "vle32.v v7, (s8) \n\t" + "sh2add s8, t0, s8 \n\t" + "vssseg8e32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 32 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vsse32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 4 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_transposed_s16_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e16, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vle16.v v1, (s2) \n\t" + "sh1add s2, t0, s2 \n\t" + "vle16.v v2, (s3) \n\t" + "sh1add s3, t0, s3 \n\t" + "vle16.v v3, (s4) \n\t" + "sh1add s4, t0, s4 \n\t" + "vle16.v v4, (s5) \n\t" + "sh1add s5, t0, s5 \n\t" + "vle16.v v5, (s6) \n\t" + "sh1add s6, t0, s6 \n\t" + "vle16.v v6, (s7) \n\t" + "sh1add s7, t0, s7 \n\t" + "vle16.v v7, (s8) \n\t" + "sh1add s8, t0, s8 \n\t" + "vssseg8e16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 16 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vsse16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 2 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_qk_dot_tile_f16_x1(float * dst, + const _Float16 * q_row, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc = __riscv_vfwmacc_vf_f32m2(acc, q_row[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst, acc, vl); +} + +static inline void rvv_qk_dot_tile_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const _Float16 * q0, + const _Float16 * q1, + const _Float16 * q2, + const _Float16 * q3, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc0 = __riscv_vfwmacc_vf_f32m2(acc0, q0[d], k_vec, vl); + acc1 = __riscv_vfwmacc_vf_f32m2(acc1, q1[d], k_vec, vl); + acc2 = __riscv_vfwmacc_vf_f32m2(acc2, q2[d], k_vec, vl); + acc3 = __riscv_vfwmacc_vf_f32m2(acc3, q3[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst0, acc0, vl); + __riscv_vse32_v_f32m2(dst1, acc1, vl); + __riscv_vse32_v_f32m2(dst2, acc2, vl); + __riscv_vse32_v_f32m2(dst3, acc3, vl); +} + +static inline void rvv_pv_accumulate_f16_x1(float * dst, + const float * prob, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_pv_accumulate_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const float * prob0, + const float * prob1, + const float * prob2, + const float * prob3, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc0 = __riscv_vle32_v_f32m4(dst0 + d_off, vl); + vfloat32m4_t acc1 = __riscv_vle32_v_f32m4(dst1 + d_off, vl); + vfloat32m4_t acc2 = __riscv_vle32_v_f32m4(dst2 + d_off, vl); + vfloat32m4_t acc3 = __riscv_vle32_v_f32m4(dst3 + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc0 = __riscv_vfmacc_vf_f32m4(acc0, prob0[tk], v32, vl); + acc1 = __riscv_vfmacc_vf_f32m4(acc1, prob1[tk], v32, vl); + acc2 = __riscv_vfmacc_vf_f32m4(acc2, prob2[tk], v32, vl); + acc3 = __riscv_vfmacc_vf_f32m4(acc3, prob3[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst0 + d_off, acc0, vl); + __riscv_vse32_v_f32m4(dst1 + d_off, acc1, vl); + __riscv_vse32_v_f32m4(dst2 + d_off, acc2, vl); + __riscv_vse32_v_f32m4(dst3 + d_off, acc3, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_qk_dot_tile(float * dst, + const float * q_row, + const float * k_pack, + int64_t dk, + int64_t kv_tile, + float scale) { + const size_t vl = __riscv_vsetvl_e32m4(kv_tile); + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat32m4_t k_vec = __riscv_vle32_v_f32m4(k_pack + d * kv_tile, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, q_row[d] * scale, k_vec, vl); + } + + __riscv_vse32_v_f32m4(dst, acc, vl); +} + +static inline void rvv_pv_accumulate(float * dst, + const float * prob, + const float * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e32m4(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat32m4_t v_vec = __riscv_vle32_v_f32m4(v_pack + tk * dv + d_off, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v_vec, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static void permute_transpose_impl(const ggml_tensor * src0, + ggml_tensor * dst, + int64_t batch, + int64_t m, + int64_t n, + int64_t batch_stride, + int64_t m_src_stride, + int64_t n_src_stride, + int64_t n_dst_stride, + int ith, + int nth) { + GGML_ASSERT(n_src_stride == sizeof(int32_t) || n_src_stride == sizeof(int16_t)); + + if (n_src_stride == sizeof(int32_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else if (n_src_stride == sizeof(int16_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else { + GGML_ABORT("not implemented"); + } +} + +template +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow(float ** pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + float ** sinks, + float ** dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK, + void * tcm_buffer, + size_t tcm_buffer_size) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + float S[QLEN] = { 0.0f }; // sum + float M[QLEN] = { -INFINITY }; // maximum KQ value + + _Float16 * kq16_buffer = (_Float16 *) tcm_buffer; + _Float16 * qv_buffer = kq16_buffer + QLEN * DV; + const size_t qkv_temp_buffer_size = (QLEN * DV + QLEN * DK) * sizeof(_Float16); + char * kv_tile_buffer = (char *) (qv_buffer + QLEN * DK); + + { + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + for (int64_t i = 0; i < QLEN; ++i) { + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq[i], DK), DK); + __riscv_vse16_v_f16m2(qv_buffer + i * DK, Q_q_v, DK); + } + } + + const uintptr_t scratch_addr = reinterpret_cast(kv_tile_buffer); + const size_t scratch_size = tcm_buffer_size > qkv_temp_buffer_size ? tcm_buffer_size - qkv_temp_buffer_size : 0; + const uintptr_t kq_tile_addr = align_up(scratch_addr, alignof(float)); + const size_t scratch_prefix = kq_tile_addr - scratch_addr; + const size_t packed_tile_size = + QLEN * sizeof(float) + DK * sizeof(_Float16) + DV * sizeof(_Float16) + sizeof(float); + const int64_t max_ic_tile_step = ((int64_t) __riscv_vsetvlmax_e16m1()) & ~((int64_t) 7); + const int64_t max_fit_by_tcm = + scratch_size > scratch_prefix ? (int64_t) ((scratch_size - scratch_prefix) / packed_tile_size) : 0; + const int64_t ic_tile_step = std::min(max_ic_tile_step, max_fit_by_tcm) & ~((int64_t) 7); + + const uintptr_t k_tile_addr = kq_tile_addr + QLEN * ic_tile_step * sizeof(float); + const uintptr_t v_tile_addr = k_tile_addr + DK * ic_tile_step * sizeof(_Float16); + const uintptr_t mv_tile_addr = v_tile_addr + ic_tile_step * DV * sizeof(_Float16); + + if (ic_tile_step >= 8) { + float * kq_tile_buffer = reinterpret_cast(kq_tile_addr); + _Float16 * k_tile_pack = reinterpret_cast<_Float16 *>(k_tile_addr); + _Float16 * v_tile_pack = reinterpret_cast<_Float16 *>(v_tile_addr); + float * mv_tile_pack = reinterpret_cast(mv_tile_addr); + + const int64_t k_tile_byte_stride = ic_tile_step * (int64_t) sizeof(_Float16); + + int64_t ic_step = 0; + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + if (mv != -INFINITY) { + const _Float16 * k_data = (const _Float16 *) (k_data_row + ic * nbk1); + const _Float16 * v_data = (const _Float16 *) (v_data_row + ic * nbv1); + + const vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2(k_data, DK); + const vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2(v_data, DV); + __riscv_vsse16_v_f16m2(k_tile_pack + ic_step, k_tile_byte_stride, k_data_v, DK); + __riscv_vse16_v_f16m2(v_tile_pack + ic_step * DV, v_data_v, DV); + mv_tile_pack[ic_step] = mv; + ic_step++; + } + + if (ic_step > 0 && (ic_step == ic_tile_step || ic == (nek1 - 1))) { + if constexpr (QLEN == 4) { + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc2 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc3 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + qk_acc2 = __riscv_vfwmacc_vf_f32m2(qk_acc2, qv_buffer[2 * DK + d], k_vec, qk_vl); + qk_acc3 = __riscv_vfwmacc_vf_f32m2(qk_acc3, qv_buffer[3 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + qk_acc2 = __riscv_vfmul_vf_f32m2(qk_acc2, scale, qk_vl); + qk_acc3 = __riscv_vfmul_vf_f32m2(qk_acc3, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 2 * ic_tile_step, qk_acc2, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 3 * ic_tile_step, qk_acc3, qk_vl); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + } + + for (int i = 0; i < QLEN; ++i) { + float * row_ptr = kq_tile_buffer + i * ic_tile_step; + const float tile_max = + rvv_softcap_add_max_inplace_f32(row_ptr, mv_tile_pack, ic_step, logit_softcap); + + const float Mold = M[i]; + + if (tile_max > Mold) { + const float ms = expf(Mold - tile_max); + M[i] = tile_max; + S[i] *= ms; + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, (_Float16) ms, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + } + + S[i] += rvv_softmax_exp_inplace_f32(row_ptr, ic_step, M[i]); + } + + if constexpr (QLEN == 4) { + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + vfloat16m2_t pv_acc2 = __riscv_vle16_v_f16m2(kq16_buffer + 2 * DV, DV); + vfloat16m2_t pv_acc3 = __riscv_vle16_v_f16m2(kq16_buffer + 3 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + pv_acc2 = + __riscv_vfmacc_vf_f16m2(pv_acc2, (_Float16) kq_tile_buffer[2 * ic_tile_step + tk], v16, DV); + pv_acc3 = + __riscv_vfmacc_vf_f16m2(pv_acc3, (_Float16) kq_tile_buffer[3 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 2 * DV, pv_acc2, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 3 * DV, pv_acc3, DV); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + } + + ic_step = 0; + } + } + } else { + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + const char * k_data = k_data_row + ic * nbk1; + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t k_data_v; + vfloat16m2_t v_data_v; + + if (mv != -INFINITY) { + k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + } else { + continue; + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t Q_q_v = __riscv_vle16_v_f16m2(qv_buffer + i * DK, DK); + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + s = s * scale; + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + s += mv; + + const float Mold = M[i]; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + if (s > M[i]) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M[i] = s; + ms = expf(Mold - M[i]); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M[i]); + } + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + S[i] = S[i] * ms + vs; // scale and increment sum with partial sum + } + } + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks[i]) { + const float s = *(sinks[i]); + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[i]) { + ms = expf(M[i] - s); + M[i] = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M[i]); + } + + S[i] = S[i] * ms + vs; + } + + // V /= S + const float S_inv = S[i] == 0.0f ? 0.0f : 1.0f / S[i]; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst[i], VKQ32_v, DV); + } +} + +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1(const float * pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + const float * sinks, + float * dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq, DK), DK); + + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + if (mv == -INFINITY) { + continue; + } + + const char * k_data = k_data_row + ic * nbk1; + + vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + + s = s * scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + + s += mv; // apply mask + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + + S = S * ms + vs; // scale and increment sum with partial sum + } + + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks) { + const float s = *sinks; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + M = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f / S; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst, VKQ32_v, DV); +} + +} // namespace + +void memcpy1d(void * dst, const void * src, int64_t size) { + size_t byte_size_all = size; + size_t vlen = __riscv_vlenb() * 8; + if (vlen == 256) { + // 1024 bytes + __asm__ volatile( + // + "srli t0, %[size], 10 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v16, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v24, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v16, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v24, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 1023 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1"); + } else if (vlen == 1024) { + // 2048 bytes + __asm__ volatile( + // + "srli t0, %[size], 11 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "addi t2, %[s], 1024 \n\t" + "addi t3, %[d], 1024 \n\t" + "li t5, 2048 \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t5 \n\t" + "vle8.v v8, (t2) \n\t" + "add t2, t2, t5 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t5 \n\t" + "vse8.v v8, (t3) \n\t" + "add t3, t3, t5 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 2047 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m2, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t3", "t5"); + } else { + __asm__ volatile( + // + "add t1, %[size], zero \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t4", "t3"); + } +} + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size) { + for (int64_t i = 0; i < tile_rows; ++i) { + memcpy1d((char *) dst + i * dst_stride, (const char *) src + i * src_stride, size); + } +} + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + // parallelize by q rows using ggml_vec_dot_f32 + + float scale = *((float *) dst->op_params + 0); + float max_bias = *((float *) dst->op_params + 1); + float logit_softcap = *((float *) dst->op_params + 2); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int KV_row_size = DK * sizeof(_Float16) + DV * sizeof(_Float16); + + int ith = params->ith; + int ir_step = 1; + for (int ir = ir0; ir < ir1; ir += ir_step) { + // q indices + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + const int iq3_1 = (ir + 1) / (neq2 * neq1); + const int iq2_1 = (ir + 1 - iq3_1 * neq2 * neq1) / neq1; + const int iq1_1 = (ir + 1 - iq3_1 * neq2 * neq1 - iq2_1 * neq1); + + const int iq3_2 = (ir + 2) / (neq2 * neq1); + const int iq2_2 = (ir + 2 - iq3_2 * neq2 * neq1) / neq1; + const int iq1_2 = (ir + 2 - iq3_2 * neq2 * neq1 - iq2_2 * neq1); + + const int iq3_3 = (ir + 3) / (neq2 * neq1); + const int iq2_3 = (ir + 3 - iq3_3 * neq2 * neq1) / neq1; + const int iq1_3 = (ir + 3 - iq3_3 * neq2 * neq1 - iq2_3 * neq1); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + const ggml_fp16_t * mp = + mask ? (ggml_fp16_t *) ((char *) mask->data + iq1 * mask->nb[1] + (iq2 % mask->ne[2]) * mask->nb[2] + + (iq3 % mask->ne[3]) * mask->nb[3]) : + NULL; + + const bool mp_equal_2 = iq1_1 == iq1 && (iq2 % mask->ne[2]) == (iq2_1 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_1 % mask->ne[3]); + + const bool mp_equal_4 = mp_equal_2 && iq1_2 == iq1 && (iq2 % mask->ne[2]) == (iq2_2 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_2 % mask->ne[3]) && iq1_3 == iq1 && + (iq2 % mask->ne[2]) == (iq2_3 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_3 % mask->ne[3]); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + const int ik3_1 = iq3_1 / rk3; + const int ik2_1 = iq2_1 / rk2; + + const int ik3_2 = iq3_2 / rk3; + const int ik2_2 = iq2_2 / rk2; + + const int ik3_3 = iq3_3 / rk3; + const int ik2_3 = iq2_3 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const int iv3_1 = iq3_1 / rv3; + const int iv2_1 = iq2_1 / rv2; + + const int iv3_2 = iq3_2 / rv3; + const int iv2_2 = iq2_2 / rv2; + + const int iv3_3 = iq3_3 / rv3; + const int iv2_3 = iq2_3 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + + std::array pq_buffer; + std::array sinks_buffer; + std::array dst_buffer; + + if (tcm_buffer != nullptr && 4 * KV_row_size < tcm_buffer_size && ir < (ir1 - 3) && mp_equal_4 && + ik3_3 == ik3 && ik2_3 == ik2 && iv3_3 == iv3 && iv2_3 == iv2 && ik3_2 == ik3 && ik2_2 == ik2 && + iv3_2 == iv3 && iv2_2 == iv2 && ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 4; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + pq_buffer[2] = (float *) ((char *) q->data + (iq1_2 * nbq1 + iq2_2 * nbq2 + iq3_2 * nbq3)); + pq_buffer[3] = (float *) ((char *) q->data + (iq1_3 * nbq1 + iq2_3 * nbq2 + iq3_3 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + sinks_buffer[2] = sinks ? ((float *) ((char *) sinks->data)) + iq2_2 : nullptr; + sinks_buffer[3] = sinks ? ((float *) ((char *) sinks->data)) + iq2_3 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + dst_buffer[2] = (float *) ((char *) dst->data + (iq3_2 * ne2 * ne1 + iq2_2 + iq1_2 * ne1) * nb1); + dst_buffer[3] = (float *) ((char *) dst->data + (iq3_3 * ne2 * ne1 + iq2_3 + iq1_3 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<4>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else if (tcm_buffer != nullptr && 2 * KV_row_size < tcm_buffer_size && ir < (ir1 - 1) && mp_equal_2 && + ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 2; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<2>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else { + ir_step = 1; + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1( // + pq, // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks ? ((float *) ((char *) sinks->data)) + h : nullptr, // + (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK); + } + } +} + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + float * param_list = (float *) dst->op_params; + float scale = param_list[0]; + float max_bias = param_list[1]; + float logit_softcap = param_list[2]; + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + // Per-thread scratch layout: + // Q_f32: Q_TILE_SZ * DK + // KQ: Q_TILE_SZ * KV_TILE_SZ + // mask32: Q_TILE_SZ * KV_TILE_SZ + // VKQ32: Q_TILE_SZ * DV + // V32: KV_TILE_SZ * DV + // K_f32: DK * KV_TILE_SZ (transposed K tile) + float * base = (float *) params->wdata + ith * (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + + KV_TILE_SZ * DV + KV_TILE_SZ * DK + CACHE_LINE_SIZE_F32); + const size_t base_size = + (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + KV_TILE_SZ * DV + KV_TILE_SZ * DK) * + sizeof(float) + + CACHE_LINE_SIZE_F32; + + if (base_size <= tcm_buffer_size && tcm_buffer != nullptr) { + base = (float *) tcm_buffer; + } + + float S_M_Buf[Q_TILE_SZ * 2]; // buffer to hold S, M, bias for one tile to reduce register pressure in main loop + float * S = S_M_Buf; + float * M = S_M_Buf + Q_TILE_SZ; + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int) (ir1 - ir), (int) (neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + for (int i = 0; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + float * Q_f32 = base; + float * KQ = (float *) ((char *) base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; + _Float16 * Q_f16 = (_Float16 *) Q_f32; + _Float16 * V_f16 = (_Float16 *) V32; + _Float16 * K_f16 = (_Float16 *) K_f32; + + rvv_zero_f32(VKQ32, Q_TILE_SZ * DV); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + if (kv_type == GGML_TYPE_F16) { + rvv_pack_f32_as_scaled_f16((uint8_t *) Q_f16, DK * sizeof(_Float16), (uint8_t *) pq, nbq1, tile_rows, DK, + scale); + } else { + memcpy2d(Q_f32, DK * sizeof(float), pq, nbq1, tile_rows, DK * sizeof(float)); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int) std::min((int64_t) KV_TILE_SZ, nek1 - ic); + + rvv_zero_f32(K_f32, DK * KV_TILE_SZ); + rvv_zero_f32(V32, KV_TILE_SZ * DV); + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + const ggml_fp16_t * mp_row = + (const ggml_fp16_t *) ((const char *) mask->data + iq1 * mask->nb[1] + + (iq2 % mask->ne[2]) * mask->nb[2] + (iq3 % mask->ne[3]) * mask->nb[3]); + rvv_pack_scaled_f16_as_f32(mask32, KV_TILE_SZ * sizeof(float), mp_row + ic, mask->nb[1], tile_rows, + kv_tile, slope); + + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = 0; tk < kv_tile; tk++) { + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + + if (can_skip) { + continue; + } + } + + if (kv_type == GGML_TYPE_F16) { + rvv_transposed_s16_mn_to_nm((int8_t *) K_f16, KV_TILE_SZ * sizeof(_Float16), + (int8_t *) k->data + ic * nbk1 + ik2 * nbk2 + ik3 * nbk3, nbk1, kv_tile, + DK); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + rvv_qk_dot_tile_f16_x4(KQ + (tq + 0) * KV_TILE_SZ, KQ + (tq + 1) * KV_TILE_SZ, + KQ + (tq + 2) * KV_TILE_SZ, KQ + (tq + 3) * KV_TILE_SZ, + Q_f16 + (tq + 0) * DK, Q_f16 + (tq + 1) * DK, Q_f16 + (tq + 2) * DK, + Q_f16 + (tq + 3) * DK, K_f16, DK, kv_tile); + } + for (; tq < tile_rows; ++tq) { + rvv_qk_dot_tile_f16_x1(KQ + tq * KV_TILE_SZ, Q_f16 + tq * DK, K_f16, DK, kv_tile); + } + } else { + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *) k->data + (ic + tk) * nbk1 + ik2 * nbk2 + ik3 * nbk3; + float * k_col = K_f32 + tk; + const float * k_src = (const float *) k_data; + for (int64_t dk = 0; dk < DK; ++dk) { + k_col[dk * KV_TILE_SZ] = k_src[dk]; + } + } + + for (int tq = 0; tq < tile_rows; ++tq) { + rvv_qk_dot_tile(KQ + tq * KV_TILE_SZ, Q_f32 + tq * DK, K_f32, DK, KV_TILE_SZ, scale); + } + } + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + } + + if (logit_softcap != 0.0f) { + rvv_softcap_tanh_inplace_f32(KQ, KV_TILE_SZ, tile_rows, KV_TILE_SZ, logit_softcap); + } + + if (mask) { + rvv_add_inplace_f32(KQ, KV_TILE_SZ, mask32, KV_TILE_SZ, tile_rows, KV_TILE_SZ); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < tile_rows; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + const float tile_max = rvv_max_f32(kq_row, KV_TILE_SZ); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + S[tq] *= ms; + } + M[tq] = Mnew; + + S[tq] += rvv_softmax_exp_inplace_f32(kq_row, KV_TILE_SZ, Mnew); + } + + // Pack V as contiguous [KV_TILE_SZ][DV]. + if (kv_type == GGML_TYPE_F16) { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V_f16, DV * sizeof(_Float16), v_data, nbv1, kv_tile, DV * sizeof(_Float16)); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + if (skip[tq + 0] || skip[tq + 1] || skip[tq + 2] || skip[tq + 3]) { + for (int i = 0; i < 4; ++i) { + if (!skip[tq + i]) { + rvv_pv_accumulate_f16_x1(VKQ32 + (tq + i) * DV, KQ + (tq + i) * KV_TILE_SZ, V_f16, + KV_TILE_SZ, DV); + } + } + continue; + } + + rvv_pv_accumulate_f16_x4(VKQ32 + (tq + 0) * DV, VKQ32 + (tq + 1) * DV, VKQ32 + (tq + 2) * DV, + VKQ32 + (tq + 3) * DV, KQ + (tq + 0) * KV_TILE_SZ, + KQ + (tq + 1) * KV_TILE_SZ, KQ + (tq + 2) * KV_TILE_SZ, + KQ + (tq + 3) * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + for (; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate_f16_x1(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + } + } else { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V32, DV * sizeof(float), v_data, nbv1, kv_tile, DV * sizeof(float)); + + for (int tq = 0; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V32, KV_TILE_SZ, DV); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *) ((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + } else { + vs = expf(s - M[tq]); + } + + float S_temp = S[tq] * ms + vs; + S[tq] = S_temp == 0.0f ? 0.0f : 1.0f / S_temp; + } + } else { + for (int tq = 0; tq < tile_rows; tq++) { + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + S[tq] = S_inv; + } + } + + float * dst_ptr = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + (iq1) *ne1) * nb1); + rvv_pack_scaled_f32_as_f32(dst_ptr, nb1 * ne1, VKQ32, DV * sizeof(float), tile_rows, DV, S); + + ir += tile_rows; + } +} + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template +void quantize_a_nrow_i8_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk])); + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + int16_t a_sum = 0; + for (size_t bk = 0; bk < blk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row] = -a_sum; + } + } +} + +template +void quantize_a_nrow_i8_hp_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + const size_t subblk_count = blk_len / k_subblk_len; + + GGML_ASSERT(blk_len == 256); + + float scale_temp[8] = { 0.0f }; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * MB_ROWS; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + + float scale_avg = 0.0f; + for (size_t kk = 0; kk < subblk_count; kk++) { + float max_abs_a = 0.0f; + for (size_t row = 0; row < MB_ROWS; row++) { + for (size_t bk = 0; bk < k_subblk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk + kk * k_subblk_len])); + } + } + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + float scale_factor = 1.0f / scale_avg; + + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + scale_avg_ptr[0] = scale_avg; + + for (size_t kk = 0; kk < subblk_count; kk++) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * MB_ROWS); + + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + const float rep_scale_a = 1.0f / scale_temp[kk]; + + for (size_t row = 0; row < MB_ROWS; row++) { + int16_t a_sum = 0; + for (size_t bk = 0; bk < k_subblk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk + kk * k_subblk_len] * rep_scale_a), + -128.0f, 127.0f)); + quant_a_blk[row * k_subblk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * subblk_count + kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + } + } + } +} + +template +void quantize_a_nrow_i8k_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_sum_size = 256 / 16; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * a_sum_size * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_a = 0.0f; + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + float ax = std::abs(a_ptr[row * count_k + k + bk]); + if (ax > max_abs_a) { + max_abs_a = ax; + max_a = a_ptr[row * count_k + k + bk]; + } + } + + if (!max_abs_a) { + scale_a_ptr[row] = 0; + for (size_t bki = 0; bki < a_sum_size; bki++) { + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + quant_a_blk[row * blk_len + bk] = 0; + } + a_sum_ptr[row * a_sum_size + bki] = 0; + } + continue; + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + for (size_t bki = 0; bki < a_sum_size; bki++) { + int16_t a_sum = 0; + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * a_sum_size + bki] = -a_sum; + } + } + } +} + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } else { + quantize_a_nrow_i8_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + mi * count_k + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false); + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * 4; + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + vfloat32m1_t v_a0_abs = __riscv_vfabs_v_f32m1(v_a0, vl); + vfloat32m1_t v_a1_abs = __riscv_vfabs_v_f32m1(v_a1, vl); + vfloat32m1_t v_a2_abs = __riscv_vfabs_v_f32m1(v_a2, vl); + vfloat32m1_t v_a3_abs = __riscv_vfabs_v_f32m1(v_a3, vl); + + vfloat32m1_t v_max_abs = __riscv_vfmax_vv_f32m1(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a0_scale = __riscv_vfmul_vf_f32m1(v_a0, rep_scale_a, vl); + vfloat32m1_t v_a1_scale = __riscv_vfmul_vf_f32m1(v_a1, rep_scale_a, vl); + vfloat32m1_t v_a2_scale = __riscv_vfmul_vf_f32m1(v_a2, rep_scale_a, vl); + vfloat32m1_t v_a3_scale = __riscv_vfmul_vf_f32m1(v_a3, rep_scale_a, vl); + vint16mf2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a0_scale, vl); + vint16mf2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a1_scale, vl); + vint16mf2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a2_scale, vl); + vint16mf2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a3_scale, vl); + vint8mf4_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a0_quant, vl); + vint8mf4_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a1_quant, vl); + vint8mf4_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a2_quant, vl); + vint8mf4_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + vfloat32m4_t v_a0_abs = __riscv_vfabs_v_f32m4(v_a0, vl); + vfloat32m4_t v_a1_abs = __riscv_vfabs_v_f32m4(v_a1, vl); + vfloat32m4_t v_a2_abs = __riscv_vfabs_v_f32m4(v_a2, vl); + vfloat32m4_t v_a3_abs = __riscv_vfabs_v_f32m4(v_a3, vl); + + vfloat32m4_t v_max_abs = __riscv_vfmax_vv_f32m4(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a0_scale = __riscv_vfmul_vf_f32m4(v_a0, rep_scale_a, vl); + vfloat32m4_t v_a1_scale = __riscv_vfmul_vf_f32m4(v_a1, rep_scale_a, vl); + vfloat32m4_t v_a2_scale = __riscv_vfmul_vf_f32m4(v_a2, rep_scale_a, vl); + vfloat32m4_t v_a3_scale = __riscv_vfmul_vf_f32m4(v_a3, rep_scale_a, vl); + vint16m2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16m2(v_a0_scale, vl); + vint16m2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16m2(v_a1_scale, vl); + vint16m2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16m2(v_a2_scale, vl); + vint16m2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16m2(v_a3_scale, vl); + vint8m1_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a0_quant, vl); + vint8m1_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a1_quant, vl); + vint8m1_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a2_quant, vl); + vint8m1_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits, can process 32 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + // vlen = 256 bits, can process 8 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8k_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_nrow_block_stride = a_blk_stride * 4; + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else if (vlenb == 32) { + // vlen = 256 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else { + quantize_a_nrow_i8k_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = src0->ne[2] * src0->ne[3]; + int64_t m = src0->ne[1]; + int64_t n = src0->ne[0]; + + int64_t batch_stride = src0->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = n_src_stride * m; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = dst->ne[2] * dst->ne[3]; + int64_t n = dst->ne[1]; + int64_t m = dst->ne[0]; + + int64_t batch_stride = dst->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = dst->nb[1]; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + auto src0_rows = ggml_nrows(src0); + auto src1_rows = ggml_nrows(src1); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(T)); + GGML_ASSERT(nb00 == sizeof(T)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + auto compute_func_vv = [&](int64_t blk_len, int64_t r, T * src0_ptr, T * src1_ptr, T * dst_ptr) { + int64_t idx = 0; + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfadd_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfadd_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfsub_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfsub_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfmul_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfmul_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfdiv_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfdiv_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + }; + + if (src0_rows == src1_rows && src0_rows == 1 && ne00 == ne10) { + int64_t task_per_thread = (ne00 + nth - 1) / nth; + int64_t task_begin = ith * task_per_thread; + int64_t task_end = std::min((ith + 1) * task_per_thread, ne00); + + T * dst_ptr = ((T *) dst->data) + task_begin; + T * src0_ptr = ((T *) src0->data) + task_begin; + T * src1_ptr = ((T *) src1->data) + task_begin; + + compute_func_vv(task_end - task_begin, 0, src0_ptr, src1_ptr, dst_ptr); + } else if (ne10 > 1) { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + // src1 is broadcastable across src0 and dst in i1, i2, i3 + for (int64_t r = 0; r < ne00; r += ne10) { + compute_func_vv(ne10, r, src0_ptr, src1_ptr, dst_ptr); + } + } + } else { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + T rhs_scalar = src1_ptr[0]; + int64_t blk_len = ne00; + int64_t r = 0; + + for (size_t vl; blk_len > 0; blk_len -= vl, r += vl) { + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfadd_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfadd_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfsub_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfsub_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfmul_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfmul_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfdiv_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfdiv_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + int64_t n_task = ne01 * ne02 * ne03; + int64_t task_per_thread = (n_task + nth - 1) / nth; + int64_t ir_start = ith * task_per_thread; + int64_t ir_end = std::min(ir_start + task_per_thread, n_task); + + for (int64_t ir = ir_start; ir < ir_end; ir++) { + const int64_t i3 = ir / (ne02 * ne01); + const int64_t i2 = (ir - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (ir - i3 * ne02 * ne01 - i2 * ne01); + + T * src_row = (T *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + T * dst_row = (T *) ((char *) op->data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + float row_sum = 0; + + if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const float * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t vec = __riscv_vle32_v_f32m4(p_data, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e16m2(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const _Float16 * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e16m2(length); + vfloat16m2_t vec_f16 = __riscv_vle16_v_f16m2(p_data, vl); + vfloat32m4_t vec_f32 = __riscv_vfwcvt_f_f_v_f32m4(vec_f16, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec_f32, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else { + GGML_ABORT("fatal error"); + } + + dst_row[0] = row_sum; + } +} + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t nrows = ggml_nrows(src0); + int64_t nrows_per_thread = (nrows + nth - 1) / nth; + int64_t ir_start = ith * nrows_per_thread; + int64_t ir_end = std::min(ir_start + nrows_per_thread, nrows); + + if (src0->ne[0] == 1) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + T src_scalar = src_row[0]; + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vmv_v_x_i32m4(src_scalar, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vmv_v_x_i16m4(src_scalar, vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else if (src0->ne[0] == dst->ne[0]) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_row + idx, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_row + idx), vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else { + GGML_ABORT("fatal error"); + } +} + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int64_t total_batches = ne2 * ne3; + const int64_t batches_per_thread = (total_batches + nth - 1) / nth; + const int64_t batch_start = ith * batches_per_thread; + const int64_t batch_end = std::min(batch_start + batches_per_thread, total_batches); + + for (int64_t b = batch_start; b < batch_end; b++) { + const int64_t i3 = b / ne2; + const int64_t i2 = b % ne2; + + T * src_base = (T *) ((char *) src0->data + i2 * src0->nb[2] + i3 * src0->nb[3]); + T * dst_batch = (T *) ((char *) dst->data + i2 * dst->nb[2] + i3 * dst->nb[3]); + + for (int64_t i1 = 0; i1 < ne1; i1++) { + T * dst_ptr = (T *) ((char *) dst_batch + i1 * dst->nb[1]); + int64_t length = ne0; + int64_t idx = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_base + idx, vl); + __riscv_vse32_v_i32m4(dst_ptr + idx, vec, vl); + idx += vl; + length -= vl; + } else if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_base + idx), vl); + __riscv_vse16_v_i16m4((dst_ptr + idx), vec, vl); + idx += vl; + length -= vl; + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(op) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + // rows per thread + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i / (ne11 * ne10); + const int64_t i11 = (i - i12 * ne11 * ne10) / ne10; + const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + memcpy1d(((char *) dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3) + cr0 * sizeof(T), + ((char *) src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03) + cr0 * sizeof(T), + (cr1 - cr0) * sizeof(T)); + } +} + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim == 0 && nb0 == sizeof(float) && nb1 == sizeof(float) * (ne00 + ne10)); + + const int64_t nr = ggml_nrows(dst); + const int64_t nc = ne0; + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + int64_t o[4] = { 0, 0, 0, 0 }; + o[dim] = src0->ne[dim]; + const float * x; + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i3 = i / (ne02 * ne01); + const int64_t i2 = (i - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (i - i3 * ne02 * ne01 - i2 * ne01); + + for (int i0 = cr0; i0 < cr1; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *) src0->data + (i0) *nb00 + (i1) *nb01 + (i2) *nb02 + (i3) *nb03); + } else { + x = (const float *) ((const char *) src1->data + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); + } + + float * y = (float *) ((char *) dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + + *y = *x; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows<_Float16>(const ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +} // namespace spacemit_kernels::rvv diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.h b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h new file mode 100644 index 00000000000..edddf957c21 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h @@ -0,0 +1,95 @@ +#pragma once + +#include "ggml-cpu-impl.h" + +#include +#include +#include +#include + +namespace spacemit_kernels { + +constexpr auto div_round_up(auto up, auto down) { + return (up + down - 1) / down; +} + +// Q8 Blk [f32] [s16] [int8 * blk_len] +// Q8 Blk N [f32 * N] [s16 * N] [int8 * blk_len * N] +constexpr size_t q8_blk_size(size_t blk_len, bool with_blk_sum = false) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + (with_blk_sum ? sizeof(int16_t) : 0); + return blk_size; +} + +// Q8 HP row block: K is split into K32 subblocks. +// Each subblock stores [f32 scale] [int8 * 32], with an optional fp16 sum trailer per subblock. +constexpr size_t q8_hp_blk_size(size_t blk_len, bool with_blk_sum = false, bool with_blk_scale = false) { + const size_t subblk_count = div_round_up(blk_len, size_t(32)); + const size_t blk_size = blk_len * sizeof(int8_t) + subblk_count * sizeof(_Float16) + + (with_blk_sum ? subblk_count * sizeof(_Float16) : 0) + + (with_blk_scale ? sizeof(_Float16) : 0); + return blk_size; +} + +// Q8K Blk [f32] [s16 * (blk_len / 16)] [int8 * blk_len] +// Q8K Blk N [f32 * N] [s16 * (blk_len / 16) * N] [int8 * blk_len * N] +constexpr size_t q8k_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + sizeof(int16_t) * blk_len / 16; + return blk_size; +} + +using quantize_a_row_def = std::function; + +namespace rvv { +void memcpy1d(void * dst, const void * src, int64_t size); + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size); + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op); + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op); + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +} // namespace rvv + +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/spine_barrier.h b/ggml/src/ggml-cpu/spacemit/spine_barrier.h new file mode 100644 index 00000000000..f897dad4b8a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_barrier.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#define SPINE_CACHE_LINE 64 +#define SPINE_CACHE_ALIGN __attribute__((aligned(SPINE_CACHE_LINE))) + +struct spine_barrier_t { + SPINE_CACHE_ALIGN std::atomic pending_; + SPINE_CACHE_ALIGN std::atomic rounds_; + SPINE_CACHE_ALIGN int64_t total_; +}; + +inline void spine_barrier_wait(spine_barrier_t * b) { + auto cur_round = b->rounds_.load(std::memory_order_acquire); + auto cnt = --b->pending_; + if (cnt == 0) { + b->pending_.store(b->total_); + b->rounds_.store(cur_round + 1); + } else { + while (cur_round == b->rounds_.load(std::memory_order_relaxed)) { + __asm__ volatile("pause " ::: "memory"); + } + } +} + +inline void spine_barrier_init(spine_barrier_t * b, int num_barriers, uint64_t thread_count) { + for (int i = 0; i < num_barriers; i++) { + b[i].total_ = thread_count; + b[i].pending_.store(thread_count); + b[i].rounds_.store(0); + } +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp new file mode 100644 index 00000000000..1409423b145 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp @@ -0,0 +1,760 @@ +#include "spine_mem_pool.h" + +#include "common.h" +#include "ime_env.h" +#include "spine_tcm.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +namespace { + +constexpr size_t SPINE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull * 1024ull; +constexpr size_t SPINE_SHARE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull; +constexpr size_t SPINE_MEM_POOL_1G_REGION_SIZE = 1ull << 30; +constexpr uint64_t HUGETLB_1G_FLAG_REQUIRE_PUD = 1ull << 0; +constexpr char SPINE_MEM_POOL_HUGETLB_1G_DEV[] = "/dev/hugetlb_1g"; +constexpr char SPINE_MEM_POOL_TCM_SYNC_MEM_DEV[] = "/dev/tcm_sync_mem"; + +struct hugetlb_1g_region { + uint64_t size{ 0 }; + uint64_t dma_addr{ 0 }; + uint64_t flags{ 0 }; + uint64_t reserved{ 0 }; +}; + +#define HUGETLB_1G_IOC_MAGIC 'M' +#define HUGETLB_1G_IOC_ALLOC _IOWR(HUGETLB_1G_IOC_MAGIC, 0x00, struct hugetlb_1g_region) +#define HUGETLB_1G_IOC_FREE _IO(HUGETLB_1G_IOC_MAGIC, 0x01) + +struct free_block { + size_t offset{ 0 }; + size_t size{ 0 }; +}; + +struct pool_chunk { + uint8_t * base{ nullptr }; + size_t size{ 0 }; + int fd{ -1 }; + std::vector free_blocks; +}; + +struct pool_allocation { + void * chunk_base{ nullptr }; + size_t chunk_size{ 0 }; + void * base{ nullptr }; + size_t size{ 0 }; +}; + +bool is_power_of_two(size_t value) { + return value != 0 && (value & (value - 1)) == 0; +} + +bool align_up(size_t value, size_t alignment, size_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const size_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const size_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +bool align_up_uintptr(uintptr_t value, size_t alignment, uintptr_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const uintptr_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const uintptr_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +class spine_mem_pool_manager { + public: + explicit spine_mem_pool_manager(size_t default_chunk_size) : default_chunk_size_(default_chunk_size) {} + + virtual ~spine_mem_pool_manager() = default; + + void * alloc(size_t size, size_t alignment) { + if (size == 0 || !is_power_of_two(alignment)) { + return nullptr; + } + + size_t aligned_size = 0; + if (!align_up(size, alignment, &aligned_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: align_up failed for size %zu alignment %zu\n", __func__, size, + alignment); + return nullptr; + } + + pool_allocation allocation; + + std::lock_guard lock(mutex_); + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + if (!add_chunk_locked(aligned_size, alignment)) { + return nullptr; + } + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation retry failed for size %zu alignment %zu\n", + __func__, aligned_size, alignment); + return nullptr; + } + } + + try { + const auto [allocation_it, inserted] = allocations_.emplace(allocation.base, allocation); + if (!inserted) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: duplicate allocation key %p\n", __func__, allocation.base); + rollback_allocation_locked(allocation); + return nullptr; + } + } catch (const std::bad_alloc &) { + rollback_allocation_locked(allocation); + throw; + } + + return allocation.base; + } + + void free(void * base) { + if (base == nullptr) { + return; + } + + std::lock_guard lock(mutex_); + + auto allocation_it = allocations_.find(base); + if (allocation_it == allocations_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown allocation %p\n", __func__, base); + return; + } + + pool_allocation allocation = allocation_it->second; + allocations_.erase(allocation_it); + + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown chunk for allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p out of chunk range %p..%p\n", __func__, + allocation.base, chunk_base, chunk_base + chunk_it->size); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p size %zu exceeds chunk size %zu\n", __func__, + allocation.base, allocation.size, chunk_it->size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + protected: + void release_chunks() { + std::lock_guard lock(mutex_); + + allocations_.clear(); + for (auto & chunk : chunks_) { + dealloc_chunk(&chunk); + } + chunks_.clear(); + } + + size_t default_chunk_size() const { return default_chunk_size_; } + + static void clear_chunk(pool_chunk * chunk) { + chunk->base = nullptr; + chunk->size = 0; + chunk->fd = -1; + chunk->free_blocks.clear(); + } + + virtual bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) = 0; + virtual void dealloc_chunk(pool_chunk * chunk) = 0; + + private: + struct alloc_candidate { + size_t chunk_index{ 0 }; + size_t block_index{ 0 }; + size_t aligned_offset{ 0 }; + uintptr_t address{ std::numeric_limits::max() }; + bool valid{ false }; + }; + + std::vector::iterator find_chunk_locked(const pool_allocation & allocation) { + return std::find_if(chunks_.begin(), chunks_.end(), [&](const pool_chunk & chunk) { + return chunk.base == allocation.chunk_base && chunk.size == allocation.chunk_size; + }); + } + + bool add_chunk_locked(size_t min_size, size_t alignment) { + pool_chunk chunk; + const size_t chunk_request = default_chunk_size_ == 0 ? min_size : std::max(min_size, default_chunk_size_); + void * hint_addr = nullptr; + + for (const auto & existing_chunk : chunks_) { + auto * chunk_end = existing_chunk.base + existing_chunk.size; + if (hint_addr == nullptr || chunk_end > hint_addr) { + hint_addr = chunk_end; + } + } + + if (!alloc_chunk(chunk_request, alignment, hint_addr, &chunk)) { + return false; + } + + if (chunk.base == nullptr || chunk.size < min_size) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: invalid chunk returned for request size %zu, chunk_base=%p chunk_size=%zu\n", + __func__, min_size, chunk.base, chunk.size); + dealloc_chunk(&chunk); + return false; + } + + try { + chunk.free_blocks.push_back({ 0, chunk.size }); + chunks_.push_back(std::move(chunk)); + } catch (const std::bad_alloc &) { + dealloc_chunk(&chunk); + throw; + } + + return true; + } + + void rollback_allocation_locked(const pool_allocation & allocation) { + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, owning chunk not found\n", + __func__, allocation.base); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, chunk range is invalid\n", + __func__, allocation.base); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + bool try_alloc_locked(size_t size, size_t alignment, pool_allocation * allocation) { + alloc_candidate best; + + for (size_t chunk_index = 0; chunk_index < chunks_.size(); ++chunk_index) { + const auto & chunk = chunks_[chunk_index]; + for (size_t block_index = 0; block_index < chunk.free_blocks.size(); ++block_index) { + const auto & block = chunk.free_blocks[block_index]; + + uintptr_t aligned_addr = 0; + const auto block_addr = reinterpret_cast(chunk.base + block.offset); + if (!align_up_uintptr(block_addr, alignment, &aligned_addr)) { + continue; + } + + if (aligned_addr < block_addr) { + continue; + } + + const size_t aligned_offset = block.offset + static_cast(aligned_addr - block_addr); + const size_t padding = aligned_offset - block.offset; + if (padding > block.size || size > block.size - padding) { + continue; + } + + if (!best.valid || aligned_addr < best.address) { + best.chunk_index = chunk_index; + best.block_index = block_index; + best.aligned_offset = aligned_offset; + best.address = aligned_addr; + best.valid = true; + } + } + } + + if (!best.valid) { + return false; + } + + auto & chunk = chunks_[best.chunk_index]; + const free_block block = chunk.free_blocks[best.block_index]; + const size_t padding = best.aligned_offset - block.offset; + const size_t alloc_end = best.aligned_offset + size; + const size_t block_end = block.offset + block.size; + + chunk.free_blocks.erase(chunk.free_blocks.begin() + best.block_index); + auto insert_it = chunk.free_blocks.begin() + best.block_index; + if (padding != 0) { + insert_it = chunk.free_blocks.insert(insert_it, { block.offset, padding }); + ++insert_it; + } + if (alloc_end < block_end) { + chunk.free_blocks.insert(insert_it, { alloc_end, block_end - alloc_end }); + } + + allocation->chunk_base = chunk.base; + allocation->chunk_size = chunk.size; + allocation->base = chunk.base + best.aligned_offset; + allocation->size = size; + return true; + } + + void maybe_release_empty_chunk_locked(std::vector::iterator chunk_it) { + if (chunk_it->free_blocks.size() != 1) { + return; + } + + const auto & block = chunk_it->free_blocks.front(); + if (block.offset != 0 || block.size != chunk_it->size) { + return; + } + + dealloc_chunk(&*chunk_it); + chunks_.erase(chunk_it); + } + + void insert_free_block_locked(pool_chunk & chunk, free_block block) { + auto it = chunk.free_blocks.begin(); + while (it != chunk.free_blocks.end() && it->offset < block.offset) { + ++it; + } + + if (it != chunk.free_blocks.begin()) { + const auto & prev = *(it - 1); + if (prev.offset + prev.size > block.offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + } + + if (it != chunk.free_blocks.end() && block.offset + block.size > it->offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping next free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + + it = chunk.free_blocks.insert(it, block); + + if (it != chunk.free_blocks.begin()) { + auto prev = it - 1; + if (prev->offset + prev->size == it->offset) { + it->offset = prev->offset; + it->size += prev->size; + it = chunk.free_blocks.erase(prev); + } + } + + if (it + 1 != chunk.free_blocks.end() && it->offset + it->size == (it + 1)->offset) { + it->size += (it + 1)->size; + chunk.free_blocks.erase(it + 1); + } + } + + std::mutex mutex_; + std::vector chunks_; + std::unordered_map allocations_; + size_t default_chunk_size_{ 0 }; +}; + +class spine_mem_pool_posix final : public spine_mem_pool_manager { + public: + spine_mem_pool_posix() : spine_mem_pool_manager(0) {} + + ~spine_mem_pool_posix() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) hint_addr; + + const size_t alloc_alignment = std::max(alignment, sizeof(void *)); + void * base = nullptr; + const int rc = posix_memalign(&base, alloc_alignment, min_size); + if (rc != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: posix_memalign failed for size %zu alignment %zu, rc=%d\n", + __func__, min_size, alloc_alignment, rc); + return false; + } + + chunk->base = static_cast(base); + chunk->size = min_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + std::free(chunk->base); + clear_chunk(chunk); + } +}; + +class spine_mem_pool_transparent_hugepage final : public spine_mem_pool_manager { + public: + spine_mem_pool_transparent_hugepage() : spine_mem_pool_manager(SPINE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_transparent_hugepage() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + size_t chunk_size = 0; + if (!align_up(min_size, default_chunk_size(), &chunk_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round chunk size for %zu\n", __func__, min_size); + return false; + } + + void * map_addr = mmap(hint_addr, chunk_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for chunk size %zu, errno=%d\n", __func__, chunk_size, + errno); + return false; + } + + if (madvise(map_addr, chunk_size, MADV_HUGEPAGE) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: madvise(MADV_HUGEPAGE) failed for chunk size %zu, errno=%d\n", + __func__, chunk_size, errno); + munmap(map_addr, chunk_size); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = chunk_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for chunk %p size %zu, errno=%d\n", __func__, + chunk->base, chunk->size, errno); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_hugetlb_1g final : public spine_mem_pool_manager { + public: + spine_mem_pool_hugetlb_1g() : spine_mem_pool_manager(SPINE_MEM_POOL_1G_REGION_SIZE) {} + + ~spine_mem_pool_hugetlb_1g() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + (void) hint_addr; + + size_t region_size = 0; + if (!align_up(min_size, SPINE_MEM_POOL_1G_REGION_SIZE, ®ion_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round hugetlb_1g size for %zu\n", __func__, min_size); + return false; + } + + const int fd = open(SPINE_MEM_POOL_HUGETLB_1G_DEV, O_RDWR); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_HUGETLB_1G_DEV, errno); + return false; + } + + hugetlb_1g_region region; + region.size = region_size; + region.flags = HUGETLB_1G_FLAG_REQUIRE_PUD; + if (ioctl(fd, HUGETLB_1G_IOC_ALLOC, ®ion) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_ALLOC failed for size %zu, errno=%d\n", __func__, + region_size, errno); + close(fd); + return false; + } + + void * map_addr = mmap(nullptr, region.size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for hugetlb_1g size %llu, errno=%d\n", __func__, + static_cast(region.size), errno); + ioctl(fd, HUGETLB_1G_IOC_FREE); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = region.size; + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for hugetlb_1g chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + if (ioctl(chunk->fd, HUGETLB_1G_IOC_FREE) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_FREE failed for chunk %p, errno=%d\n", + __func__, chunk->base, errno); + } + + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_shared_mem final : public spine_mem_pool_manager { + public: + spine_mem_pool_shared_mem() : spine_mem_pool_manager(SPINE_SHARE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_shared_mem() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + if (hint_addr != nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem does not support multiple active chunks\n", __func__); + return false; + } + + if (min_size > default_chunk_size()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem request %zu exceeds chunk size %zu\n", __func__, + min_size, default_chunk_size()); + return false; + } + + const int fd = open(SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, O_RDWR | O_SYNC); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, errno); + return false; + } + + void * map_addr = mmap(nullptr, default_chunk_size(), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for %s size %zu, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, default_chunk_size(), errno); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = default_chunk_size(); + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for shared_mem chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +spine_mem_pool_manager & get_spine_mem_pool_manager() { + static std::once_flag pool_once; + static std::unique_ptr selected_pool; + static spine_mem_pool_backend selected_backend = spine_mem_pool_backend::none; + + spine_mem_pool_backend backend = global_spine_env_info.mem_backend; + if (backend == spine_mem_pool_backend::none) { + backend = spine_mem_pool_backend::transparent_hugepage; + } + + std::call_once(pool_once, [&]() { + selected_backend = backend; + + switch (selected_backend) { + case spine_mem_pool_backend::posix_memalign: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::transparent_hugepage: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::hugetlb_1g: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::none: + selected_backend = spine_mem_pool_backend::transparent_hugepage; + selected_pool = std::make_unique(); + break; + } + }); + + if (backend != selected_backend) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: mem pool backend is process-global and mutually exclusive, requested=%d but " + "selected=%d\n", + __func__, static_cast(backend), static_cast(selected_backend)); + } + + if (selected_pool) { + return *selected_pool; + } + + throw std::bad_alloc(); +} + +spine_mem_pool_manager & get_spine_mem_pool_shared_mem_manager() { + static std::once_flag shared_mem_pool_once; + static std::unique_ptr shared_mem_pool; + + std::call_once(shared_mem_pool_once, [&]() { shared_mem_pool = std::make_unique(); }); + + if (shared_mem_pool) { + return *shared_mem_pool; + } + + throw std::bad_alloc(); +} + +} // namespace + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept { + if (info == nullptr) { + return false; + } + + *info = {}; + + if (spine_tcm_open_handle(NULL) != 0 || !spine_tcm_is_available()) { + return false; + } + + spine_tcm_mem_info_t mem_info; + if (spine_tcm_mem_info(&mem_info) != 0) { + return false; + } + + info->available = true; + info->blk_size = mem_info.blk_size; + info->blk_num = mem_info.blk_num; + info->is_fake_tcm = mem_info.is_fake_tcm != 0; + return true; +} + +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept { + return spine_tcm_mem_get(cpu_id); +} + +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept { + return spine_tcm_mem_try_wait(cpu_id, 1000 * 1000); +} + +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept { + return spine_tcm_mem_release(cpu_id); +} + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating size %zu\n", __func__, size); + return nullptr; + } +} + +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_shared_mem_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating shared memory size %zu\n", __func__, size); + return nullptr; + } +} + +void spine_mem_pool_free(void * base) noexcept { + try { + get_spine_mem_pool_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing allocation %p\n", __func__, base); + } +} + +void spine_mem_pool_shared_mem_free(void * base) noexcept { + try { + get_spine_mem_pool_shared_mem_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing shared allocation %p\n", __func__, base); + } +} + +} // namespace ggml::cpu::riscv64_spacemit + +extern "C" { +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment) { + void * result = ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_alloc(size, alignment); + if (result == nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to allocate shared memory size %zu alignment %zu\n", __func__, + size, alignment); + } + return result; +} + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr) { + ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_free(ptr); +} +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h new file mode 100644 index 00000000000..8740d2c99ef --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +enum class spine_mem_pool_backend : uint8_t { + none, + posix_memalign, + transparent_hugepage, + hugetlb_1g, +}; + +struct spine_mem_pool_tcm_info { + bool available{ false }; + size_t blk_size{ 0 }; + size_t blk_num{ 0 }; + bool is_fake_tcm{ false }; +}; + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept; +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept; +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept; +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept; + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept; +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept; +void spine_mem_pool_free(void * base) noexcept; +void spine_mem_pool_shared_mem_free(void * base) noexcept; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/spine_tcm.h b/ggml/src/ggml-cpu/spacemit/spine_tcm.h new file mode 100644 index 00000000000..f300d7d5c04 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_tcm.h @@ -0,0 +1,409 @@ +#ifndef SPINE_TCM_PUBLIC_H_ +#define SPINE_TCM_PUBLIC_H_ + +/* + * spine_tcm public API + * + * Usage: + * 1. Direct link mode + * Define SPINE_TCM_DIRECT_LINK and link against libspine_tcm.so. + * + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + * + * 2. Header-only loader mode + * Include this header without linking libspine_tcm.so. The loader first + * tries to reuse a process-global spine_tcm instance and falls back to + * dlopen("libspine_tcm.so") when needed. + * + * spine_tcm_open_handle(NULL); // optional pre-bind + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + */ + +#include +#include +#include + +#if !defined(SPINE_TCM_BUILD_SHARED) && !defined(SPINE_TCM_DIRECT_LINK) +# include +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_WIN32) +# if defined(SPINE_TCM_BUILD_SHARED) +# define SPINE_TCM_API __declspec(dllexport) +# else +# define SPINE_TCM_API __declspec(dllimport) +# endif +#else +# define SPINE_TCM_API __attribute__((visibility("default"))) +#endif + +typedef struct spine_tcm_mem_info { + size_t blk_size; + size_t blk_num; + int is_fake_tcm; +} spine_tcm_mem_info_t; + +typedef struct spine_tcm_block_info { + int id; + void * va; + size_t size; + uint64_t phys_addr; + uint64_t cpu_affinity_mask; + int owner_tid; + int is_acquired; +} spine_tcm_block_info_t; + +/* Shared-library runtime ABI exported by libspine_tcm.so. */ +SPINE_TCM_API const char * spine_tcm_runtime_version(void); +SPINE_TCM_API int spine_tcm_runtime_is_available(void); +SPINE_TCM_API int spine_tcm_runtime_layout_info(spine_tcm_mem_info_t * info); +SPINE_TCM_API int spine_tcm_runtime_mem_info(int id, spine_tcm_block_info_t * info); +SPINE_TCM_API void * spine_tcm_runtime_mem_get(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_free(int id); +SPINE_TCM_API void * spine_tcm_runtime_mem_try_wait(int id, size_t timeout_us); +SPINE_TCM_API int spine_tcm_runtime_mem_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_force_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_query(int id); + +#if defined(SPINE_TCM_DIRECT_LINK) +/* Optional no-op in direct-link mode. */ +static inline int spine_tcm_open_handle(const char * so_path) { + (void) so_path; + return 0; +} + +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + return spine_tcm_runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + return spine_tcm_runtime_layout_info(info); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + return spine_tcm_runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + return spine_tcm_runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + return spine_tcm_runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + return spine_tcm_runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + return spine_tcm_runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + return spine_tcm_runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + return spine_tcm_runtime_mem_query(id); +} +#elif !defined(SPINE_TCM_BUILD_SHARED) +typedef struct spine_tcm_handle { + void * module_handle; + int use_global_scope; + int owns_module_handle; + const char * (*runtime_version)(void); + int (*runtime_is_available)(void); + int (*runtime_layout_info)(spine_tcm_mem_info_t * info); + int (*runtime_mem_info)(int id, spine_tcm_block_info_t * info); + void * (*runtime_mem_get)(int id); + int (*runtime_mem_free)(int id); + void * (*runtime_mem_try_wait)(int id, size_t over_time); + int (*runtime_mem_release)(int id); + int (*runtime_mem_force_release)(int id); + int (*runtime_mem_query)(int id); +} spine_tcm_handle_t; + +static inline spine_tcm_handle_t * spine_tcm_default_handle(void) { + static spine_tcm_handle_t handle = { 0 }; + return &handle; +} + +static inline void spine_tcm_handle_reset(spine_tcm_handle_t * handle) { + if (handle != NULL) { + memset(handle, 0, sizeof(*handle)); + } +} + +static inline int spine_tcm_handle_bind(spine_tcm_handle_t * handle) { + void * symbol_scope = handle->use_global_scope ? RTLD_DEFAULT : handle->module_handle; + + handle->runtime_version = (const char * (*) (void) ) dlsym(symbol_scope, "spine_tcm_runtime_version"); + handle->runtime_is_available = (int (*)(void)) dlsym(symbol_scope, "spine_tcm_runtime_is_available"); + handle->runtime_layout_info = + (int (*)(spine_tcm_mem_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_layout_info"); + handle->runtime_mem_info = + (int (*)(int, spine_tcm_block_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_mem_info"); + handle->runtime_mem_get = (void * (*) (int) ) dlsym(symbol_scope, "spine_tcm_runtime_mem_get"); + handle->runtime_mem_free = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_free"); + handle->runtime_mem_try_wait = (void * (*) (int, size_t)) dlsym(symbol_scope, "spine_tcm_runtime_mem_try_wait"); + handle->runtime_mem_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_release"); + handle->runtime_mem_force_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_force_release"); + handle->runtime_mem_query = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_query"); + + return handle->runtime_version != NULL && handle->runtime_is_available != NULL && + handle->runtime_layout_info != NULL && handle->runtime_mem_info != NULL && + handle->runtime_mem_get != NULL && handle->runtime_mem_free != NULL && + handle->runtime_mem_try_wait != NULL && handle->runtime_mem_release != NULL && + handle->runtime_mem_force_release != NULL && handle->runtime_mem_query != NULL ? + 0 : + -1; +} + +/* + * Try to bind against an already-loaded process-global spine_tcm instance. + * The shared library exports spine_tcm_runtime_marker only for this probe. + */ +static inline int spine_tcm_try_bind_global(spine_tcm_handle_t * handle) { + if (dlsym(RTLD_DEFAULT, "spine_tcm_runtime_marker") == NULL) { + return -1; + } + + handle->use_global_scope = 1; + return spine_tcm_handle_bind(handle); +} + +/* + * Optional pre-bind entry point. + * + * Behavior: + * - Reuses an already-loaded global spine_tcm instance when available. + * - Otherwise loads the shared library from so_path or the default soname. + * - Repeated calls are safe and return 0 after the first successful bind. + */ +static inline int spine_tcm_open_handle(const char * so_path) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + const char * library = (so_path != NULL && so_path[0] != '\0') ? so_path : "libspine_tcm.so"; + + if (resolved->module_handle != NULL || resolved->use_global_scope) { + return 0; + } + + if (spine_tcm_try_bind_global(resolved) == 0) { + return 0; + } + + spine_tcm_handle_reset(resolved); + + resolved->module_handle = dlopen(library, RTLD_LAZY | RTLD_GLOBAL); + resolved->owns_module_handle = resolved->module_handle != NULL ? 1 : 0; + + if (resolved->module_handle == NULL) { + spine_tcm_handle_reset(resolved); + return -1; + } + + if (spine_tcm_handle_bind(resolved) != 0) { + if (resolved->owns_module_handle) { + dlclose(resolved->module_handle); + } + spine_tcm_handle_reset(resolved); + return -1; + } + + return 0; +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_is_available == NULL) { + return 0; + } + + return resolved->runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_layout_info == NULL) { + return -1; + } + + return resolved->runtime_layout_info(info); +} + +static inline const char * spine_tcm_version(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_version == NULL) { + return "unknown"; + } + + return resolved->runtime_version(); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_info == NULL) { + return -1; + } + + return resolved->runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_get == NULL) { + return NULL; + } + + return resolved->runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_free == NULL) { + return -1; + } + + return resolved->runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_try_wait == NULL) { + return NULL; + } + + return resolved->runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_release == NULL) { + return -1; + } + + return resolved->runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || + resolved->runtime_mem_force_release == NULL) { + return -1; + } + + return resolved->runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_query == NULL) { + return -1; + } + + return resolved->runtime_mem_query(id); +} +#else +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} +#endif + +#define SPINE_TCM_VERSION (spine_tcm_version()) + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index b54d4a6b107..d3953eee962 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND) # 80 == Ampere, asynchronous data loading, faster tensor core instructions # 86 == RTX 3000, needs CUDA v11.1 # 89 == RTX 4000, needs CUDA v11.8 + # 90 == Hopper H100/200, needs CUDA v11.8 # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores # # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run @@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND) list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) + list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual) endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index adb4d5f0cb9..c25f42b32bb 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -2,6 +2,9 @@ #include #include +template +using type_for_index = T; + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const int s12, const int s13, src1_ptrs... src1s) { + ggml_cuda_pdl_lc(); const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); @@ -72,6 +76,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; + ggml_cuda_pdl_sync(); for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); @@ -141,6 +146,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); + ggml_cuda_pdl_sync(); float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); @@ -282,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast_unravel<<>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast_unravel...>, launch_params, src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, /*s0,*/ s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast_unravel - <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, - ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /*s0,*/ s1, s2, s3, - s00, s01, s02, s03, - s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast<<>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast...>, launch_params, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, /*s0,*/ s1, s2, s3, - s00 ,s01, s02, s03, - s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast<<>>( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /*s0,*/ s1, s2, s3, s00, s01, s02, s03, - s10, s11, s12, s13); + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } } } @@ -333,6 +328,7 @@ static __global__ void k_repeat_back( } T sum = 0; + ggml_cuda_pdl_sync(); for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 10817505d9f..e54ecb29308 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -5,6 +5,7 @@ #include "ggml-cuda.h" #include +#include #include #if defined(GGML_USE_HIP) @@ -27,6 +28,7 @@ #include #include #include +#include #include #if defined(GGML_USE_HIP) @@ -50,6 +52,7 @@ #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_HOPPER 900 // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms #define GGML_CUDA_CC_BLACKWELL 1200 @@ -107,6 +110,24 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA. +// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 +# define GGML_CUDA_USE_PDL +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 + +static __device__ __forceinline__ void ggml_cuda_pdl_sync() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaGridDependencySynchronize(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + +static __device__ __forceinline__ void ggml_cuda_pdl_lc() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaTriggerProgrammaticLaunchCompletion(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + #ifdef __CUDA_ARCH_LIST__ constexpr bool ggml_cuda_has_arch_impl(int) { return false; @@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) + #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) static const char * cublas_get_error_str(const cublasStatus_t err) { return cublasGetStatusString(err); @@ -1487,3 +1509,68 @@ struct ggml_cuda_mm_fusion_args_device { const void * gate_bias = nullptr; ggml_glu_op glu_op; }; + +struct ggml_cuda_kernel_launch_params { + dim3 block_nums; + dim3 block_dims; + size_t shmem; + cudaStream_t stream; + + // size_t shmem + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {} + + // Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings. + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {} +}; + +#if defined(GGML_CUDA_USE_PDL) +struct ggml_cuda_pdl_config { + cudaLaunchAttribute attr; + cudaLaunchConfig_t cfg; + + ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) { + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = 1; + + cfg = {}; + cfg.gridDim = params.block_nums; + cfg.blockDim = params.block_dims; + cfg.dynamicSmemBytes = params.shmem; + cfg.stream = params.stream; + cfg.attrs = &attr; + cfg.numAttrs = 1; + } + + // Delete due to &attr + ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; + +}; +#endif //defined(GGML_CUDA_USE_PDL) + + +template +static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { +#if defined(GGML_CUDA_USE_PDL) + + static const bool env_pdl_enabled = []() { + const char * env = getenv("GGML_CUDA_PDL"); + return env == nullptr || std::atoi(env) != 0; + }(); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + if (env_pdl_enabled && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_HOPPER) { + auto pdl_cfg = ggml_cuda_pdl_config(launch_params); + + CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward(args)... )); + return; + } +#endif //defined(GGML_CUDA_USE_PDL) + + kernel<<>>(std::forward(args)... ); + CUDA_CHECK(cudaGetLastError()); +} + diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index 102f944f924..adba4d522a4 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont const int64_t n = ne0 * ne1 * ne2; + ggml_cuda_pdl_sync(); for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) { if constexpr (dim == 0) { const int64_t row = i / ne0; @@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x, const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; if (dim == 0) { - concat_f32_cont<0> - <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index d208acf2d5f..121472ec228 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13) { + ggml_cuda_pdl_lc(); const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + ggml_cuda_pdl_sync(); cpy_1(cx + x_offset, cdst + dst_offset); } @@ -59,6 +61,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; int cur_tile_buf = 0; + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { @@ -142,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -168,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -182,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const const src_t * x = (const src_t *) cx; dst_t * dst = (dst_t *) cdst; + ggml_cuda_pdl_sync(); dst[i] = ggml_cuda_cast(x[i]); } @@ -192,8 +198,8 @@ cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar_contiguous<<>> - (cx, cdst, ne); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_contiguous, launch_params, cx, cdst, ne); } template @@ -223,13 +229,15 @@ static void ggml_cpy_scalar_cuda( GGML_ASSERT(grid_z < USHRT_MAX); dim3 dimGrid(grid_x, grid_y, grid_z); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_scalar_transpose<<>> - (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_transpose, launch_params, + cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar>, launch_params, + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index beeb5238946..debcb6e5447 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -636,6 +636,7 @@ static __global__ void flash_attn_mask_to_KV_max( if (tid < WARP_SIZE) { buf_iw[tid] = 1; } + ggml_cuda_pdl_sync(); __syncthreads(); int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; @@ -687,6 +688,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform( const uint3 fd_iter_j_z, const uint3 fd_iter_j) { constexpr int ncols = ncols1*ncols2; + ggml_cuda_pdl_lc(); const int tile_idx = blockIdx.x; // One block per output tile. const int j = blockIdx.y; @@ -718,6 +720,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform( dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + ggml_cuda_pdl_sync(); // Load the partial result that needs a fixup float dst_val = *dst; float max_val; @@ -809,6 +812,7 @@ static __global__ void flash_attn_stream_k_fixup_general( float dst_val = 0.0f; float max_val = 0.0f; float rowsum = 0.0f; + ggml_cuda_pdl_sync(); { dst_val = *dst; @@ -867,6 +871,7 @@ static __global__ void flash_attn_combine_results( const float2 * __restrict__ VKQ_meta, float * __restrict__ dst, const int parallel_blocks) { + ggml_cuda_pdl_lc(); // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -890,6 +895,7 @@ static __global__ void flash_attn_combine_results( __builtin_assume(tid < D); extern __shared__ float2 meta[]; + ggml_cuda_pdl_sync(); for (int i = tid; i < 2*parallel_blocks; i += D) { ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; } @@ -1146,7 +1152,9 @@ void launch_fattn( const uint3 ne01 = init_fastdiv_values(Q->ne[1]); GGML_ASSERT(block_dim.x % warp_size == 0); - fattn_kernel<<>>( + + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, @@ -1176,9 +1184,9 @@ void launch_fattn( const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; - flash_attn_stream_k_fixup_uniform - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_uniform, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk, gqa_ratio, bpt, fd0, fd1, fd2); } else if (ntiles_dst % blocks_num.x != 0) { @@ -1193,9 +1201,9 @@ void launch_fattn( const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup_general - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_general, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], gqa_ratio, total_work, fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k); } @@ -1204,9 +1212,9 @@ void launch_fattn( const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream); + ggml_cuda_kernel_launch(flash_attn_combine_results, launch_params, + dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 43e22c5e5ee..4871b90df86 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -125,61 +125,107 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) { - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 1, true); - // TODO tune specifically for RDNA - return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 64, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 64, 96, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 128, 2, 32, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 128, 2, 32, 160, 128, 128, 1, true); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) { - // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async). - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true); - - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 1, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 4, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40, 40, 40, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48, 48, 48, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56, 56, 56, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64, 64, 64, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 512, 1, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 512, 1, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 64, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 64, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 64, 160, 128, 128, 1, true); - // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy - // compile-time static_asserts even though the kernel guard prevents runtime execution. - // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. - return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { @@ -510,7 +556,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; @@ -712,6 +758,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; + + // The mask is stored as 16 bit half values, loading them as 32 bit half2 values is preferred in terms of speed. + // However, this is not possible for RDNA3 where 2 consecutive l indices are not consecutive in the mask memory layout. +#ifdef RDNA3 +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_j(l); + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l)) / ncols2; + + KQ_C[i00/(np*T_C_KQ::J)].x[l] += __half2float(tile_mask[j*(nbatch_fa + 8) + i]); + } +#else #pragma unroll for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { const int i = (i0 + T_C_KQ::get_j(l0)) / 2; @@ -721,6 +779,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } +#endif // RDNA3 } } @@ -827,13 +886,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - const half2 KQ_max_scale_h2 = make_half2( - KQ_max_scale[0], KQ_max_scale[0]); + if constexpr (std::is_same_v) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll - for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } } } #else // Volta @@ -901,9 +970,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += T_A_VKQ::I) { static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { @@ -912,15 +980,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if constexpr (T_B_KQ::I == 8) { - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); } else { // Wide version of VKQ_C is column-major. #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) // AMD matrix C is column-major. - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); #else // swap A and B for CUDA. - mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], B[k00/(np*T_A_VKQ::J)], A); #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } @@ -953,11 +1021,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) -template struct mma_tile_sizes { +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -965,7 +1033,7 @@ template struct mma_tile_sizes { using T_B_VKQ = tile<16, 8, half2>; // column-major using T_C_VKQ = tile<16, 8, half2>; // column-major }; -template<> struct mma_tile_sizes<8> { +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile< 8, 8, half2>; // column-major using T_C_KQ = tile<16, 8, float>; // row-major @@ -973,8 +1041,60 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; -#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) -template struct mma_tile_sizes { +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#else +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED>; // column-major +}; +template struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -983,7 +1103,7 @@ template struct mma_tile_sizes { using T_C_VKQ = tile<16, 8, half2>; // column-major }; #else // Volta -template struct mma_tile_sizes { +template struct mma_tile_sizes { using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major @@ -1018,17 +1138,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; - using T_A_KQ = typename mma_tile_sizes::T_A_KQ; - using T_B_KQ = typename mma_tile_sizes::T_B_KQ; - using T_C_KQ = typename mma_tile_sizes::T_C_KQ; - using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; - using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; - using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); @@ -1061,6 +1181,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + T_C_VKQ VKQ_C[DV % 32 != 0 ? DV/T_C_VKQ::J : DV/(2*T_C_VKQ::J)]; #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta @@ -1269,12 +1391,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); + if constexpr (std::is_same_v) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll - for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } } } #else // Volta @@ -1433,6 +1566,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { if constexpr (cols_per_warp == 8) { + static_assert(std::is_same_v, "bad VKQ type"); const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { @@ -1447,14 +1581,45 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } else { const int j0 = threadIdx.y*cols_per_warp; + if constexpr (std::is_same_v) { + if constexpr (T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR) { #pragma unroll - for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - const int j = j0 + T_C_VKQ::get_i(l); - const int k = k1 + T_C_VKQ::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + } + } + } else { + static_assert(T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR_SCRAMBLED, "bad T_C_VKQ data layout"); + using T_C_VKQ_us = tile; // us == unscrambled +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + const T_C_VKQ_us VKQ_C_us = unscramble(VKQ_C[(k00 + k1)/T_C_VKQ::J]); +#pragma unroll + for (int l = 0; l < T_C_VKQ_us::ne; ++l) { + const int j = j0 + T_C_VKQ_us::get_i(l); + const int k = k1 + T_C_VKQ_us::get_j(l); + + tile_Q[j*tile_stride + k] = VKQ_C_us.x[l]; + } + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); + half * tile_Q_h = (half *) tile_Q; +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J/2) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = 2*k1 + T_C_VKQ::get_j(l); + + tile_Q_h[j*(2*tile_stride) + k] = VKQ_C[(k00 + k1)/(T_C_VKQ::J/2)].x[l]; + } } } } @@ -1532,7 +1697,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } template @@ -1559,7 +1724,8 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) + ggml_cuda_pdl_sync(); // TODO optimize placement +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { @@ -1585,14 +1751,14 @@ static __global__ void flash_attn_ext_f16( #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING #if defined(AMD_WMMA_AVAILABLE) - if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) { + if (ncols1*ncols2 < 16 || ncols2 == 1 || DKQ > 128) { NO_DEVICE_CODE; return; } #endif // defined(AMD_WMMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE) - if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) { + if (ncols1*ncols2 < 16 || DKQ > 256) { NO_DEVICE_CODE; return; } @@ -1715,7 +1881,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) } template diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 7b0a5e5cf49..fac76f13593 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -894,6 +894,8 @@ static __global__ void flash_attn_tile( } float KQ_sum[cpw] = {0.0f}; + ggml_cuda_pdl_sync(); + // Load Q data, convert to FP16 if fast: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index f0bd42a5761..b0a6cf67f1a 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -40,6 +40,7 @@ static __global__ void flash_attn_ext_vec( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { + ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: @@ -136,6 +137,8 @@ static __global__ void flash_attn_ext_vec( #endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + + ggml_cuda_pdl_sync(); if constexpr (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index f19defbff93..4b6f6501094 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -86,6 +86,7 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + ggml_cuda_pdl_sync(); const int sequence = blockIdx.z / ne02; const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index e045b04f727..1c7777e8a71 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -19,13 +19,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } if constexpr (ncols2 <= 16) { - if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { + if (Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { + if (Q->ne[1] <= 32/ncols2 || (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) || + (GGML_CUDA_CC_IS_AMD(cc) && DKQ > 256)) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -477,12 +478,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } + const int ncols2_max = Q->ne[0] == 320 ? 32 : ((Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8); + int gqa_ratio_eff = 1; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { - int gqa_ratio_eff = 1; - const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; - } if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -500,41 +502,22 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_WMMA_F16; } - if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { - if (can_use_vector_kernel) { - if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { - if (Q->ne[1] == 1) { - if (!gqa_opt_applies) { - return BEST_FATTN_KERNEL_VEC; - } - } - } else { - if (Q->ne[1] <= 2) { - return BEST_FATTN_KERNEL_VEC; - } - } + // AMD MFMA needs a certain minimum batch size to outscale the tile kernel for large head sizes. + if ((amd_mfma_available(cc) && Q->ne[0] <= 256) && Q->ne[0] != 40 && Q->ne[0] != 72) { + if ((Q->ne[0] <= 64 && Q->ne[1] * gqa_ratio_eff > 8)) { + return BEST_FATTN_KERNEL_MMA_F16; } - int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; + if ((Q->ne[0] <= 128 && Q->ne[1] * gqa_ratio_eff > 16)) { + return BEST_FATTN_KERNEL_MMA_F16; } - if (Q->ne[1] * gqa_ratio_eff <= 8) { - return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized. + if ((Q->ne[0] <= 256 && Q->ne[1] * gqa_ratio_eff > 64)) { + return BEST_FATTN_KERNEL_MMA_F16; } - return BEST_FATTN_KERNEL_MMA_F16; } - // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { - const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); - // MMA vs tile crossover benchmarked on MI300X @ d32768: - // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) - // hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%) - if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) { - return BEST_FATTN_KERNEL_MMA_F16; - } - // Fall through to tile kernel for small effective batch sizes. + // AMD WMMA is always faster than the tile kernel if the full tile width of 16 can be utilized. + if ((amd_wmma_available(cc) && gqa_opt_applies && Q->ne[0] <= 128) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[1] * gqa_ratio_eff > 8) { + return BEST_FATTN_KERNEL_MMA_F16; } // If there are no tensor cores available, use the generic tile kernel: diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec7317..018d5d37d47 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,7 @@ #include "gated_delta_net.cuh" +#include "ggml-cuda/common.cuh" -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -23,7 +24,8 @@ gated_delta_net_cuda(const float * q, int64_t sb3, const uint3 neqk1_magic, const uint3 rq3_magic, - float scale) { + float scale, + int K) { const uint32_t h_idx = blockIdx.x; const uint32_t sequence = blockIdx.y; // each warp owns one column, using warp-level primitives to reduce across rows @@ -37,9 +39,13 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset + col * S_v; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -48,12 +54,17 @@ gated_delta_net_cuda(const float * q, float s_shard[rows_per_lane]; // state is stored transposed: M[col][i] = S[i][col], row col is contiguous + ggml_cuda_pdl_sync(); #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -135,17 +146,30 @@ gated_delta_net_cuda(const float * q, } attn_data += S_v * H; + + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } } - // Write state back to global memory (transposed layout) + if constexpr (!keep_rs_t) { #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, @@ -155,7 +179,7 @@ static void launch_gated_delta_net( int64_t sv1, int64_t sv2, int64_t sv3, int64_t sb1, int64_t sb2, int64_t sb3, int64_t neqk1, int64_t rq3, - float scale, cudaStream_t stream) { + float scale, int K, cudaStream_t stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const int num_warps = 4; @@ -167,31 +191,32 @@ static void launch_gated_delta_net( int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } default: @@ -261,13 +286,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 36b840e8148..457b695eb2a 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -11,6 +11,7 @@ static __global__ void k_get_rows( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. @@ -48,6 +49,8 @@ static __global__ void k_get_rows_float( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + ggml_cuda_pdl_lc(); + ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. @@ -83,6 +86,7 @@ static __global__ void k_get_rows_back_float( float sum = 0.0f; + ggml_cuda_pdl_sync(); for (int64_t i = 0; i < nrows_grad; ++i) { if (rows[i] != dst_row) { continue; @@ -156,7 +160,8 @@ static void get_rows_cuda_float( GGML_ASSERT(ne11 <= std::numeric_limits::max() / ne12); const uint3 ne12_fdv = init_fastdiv_values(ne12); - k_get_rows_float<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{block_nums, block_dims, 0, stream}; + ggml_cuda_kernel_launch(k_get_rows_float, launch_params, src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 49af5389957..a8f6046e46d 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -67,9 +67,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132 if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 79bb2934c5f..8d7c69dc3e8 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -80,6 +80,7 @@ namespace ggml_cuda_mma { DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. DATA_LAYOUT_J_MAJOR_MIRRORED = 30, + DATA_LAYOUT_I_MAJOR_SCRAMBLED = 40, // Scrambled matrix C for faster transposition (RDNA4/CDNA), convert to float to unscramble. }; // Implemented mma combinations are: // - (I_MAJOR, I_MAJOR) -> I_MAJOR @@ -312,13 +313,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -327,7 +334,15 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return ne * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { +#ifdef RDNA3 + return l*2 + (threadIdx.x / 16); +#else + return (threadIdx.x / 16) * ne + l; +#endif // RDNA3 + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -338,13 +353,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -353,7 +374,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return ne * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -516,12 +541,15 @@ namespace ggml_cuda_mma { if (I == 16 && J == 16) return true; if (I == 16 && J == 8) return true; if (I == 16 && J == 4) return true; + if (I == 32 && J == 8) return true; return false; } - static __device__ __forceinline__ int get_i(const int /*l*/) { - if constexpr (supported()) { + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16) { return threadIdx.x % 16; + } else if constexpr (I == 32) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -529,8 +557,10 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (supported()) { + if constexpr (I == 16) { return l; + } else if constexpr (I == 32) { + return l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -644,6 +674,40 @@ namespace ggml_cuda_mma { } }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_SCRAMBLED; + + static constexpr int ne = I * J / ggml_cuda_get_physical_warp_size(); + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + }; + + static __device__ __forceinline__ tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> unscramble(const tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & t) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> ret; +#pragma unroll + for (int l0 = 0; l0 < t.ne/2; ++l0) { + ret.x[2*l0 + 0] = __lows2half2(t.x[l0], t.x[l0 + t.ne/2]); + ret.x[2*l0 + 1] = __highs2half2(t.x[l0], t.x[l0 + t.ne/2]); + } + return ret; +#else + NO_DEVICE_CODE; + GGML_UNUSED(t); +#endif // defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + } + #if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { @@ -660,6 +724,21 @@ namespace ggml_cuda_mma { ret.x[0] = ggml_cuda_movmatrix(t.x[0]); ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + return ret; + } +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + static __device__ __forceinline__ tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> get_half2( + const tile<16, 16, float, DATA_LAYOUT_I_MAJOR> & tile_float) { + tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> ret; +#pragma unroll + for (int l = 0; l < tile_float.ne; ++l) { + float tmp[2]; + int i = threadIdx.x / 16; + tmp[i] = tile_float.x[l]; + i ^= 1; + tmp[i] = __shfl_xor_sync(0xFFFFFFFF, tile_float.x[l], 16, WARP_SIZE); + ret.x[l] = make_half2(tmp[0], tmp[1]); + } return ret; } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) @@ -802,21 +881,35 @@ namespace ggml_cuda_mma { #endif // defined(VOLTA_MMA_AVAILABLE) } - template + template static __device__ __forceinline__ void load_ldmatrix_trans( - tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { + tile & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE + static_assert(I == 16, "bad tile width"); + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); #elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - half * xh = (half *) t.x; + static_assert(dl == DATA_LAYOUT_I_MAJOR || dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + if constexpr (I == 32) { #pragma unroll - for (int l = 0; l < t.ne; ++l) { - xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; - xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + for (int l0 = 0; l0 < t.ne/2; ++l0) { + const half2 tmp0 = xs0[(2*t.get_j(l0) + 0)*stride + t.get_i(l0)/2]; + const half2 tmp1 = xs0[(2*t.get_j(l0) + 1)*stride + t.get_i(l0)/2]; + + t.x[l0] = __lows2half2(tmp0, tmp1); + t.x[l0 + t.ne/2] = __highs2half2(tmp0, tmp1); + } + } else { + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } } #else GGML_UNUSED_VARS(t, xs0, stride); @@ -972,6 +1065,20 @@ namespace ggml_cuda_mma { #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR> & B) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 8, half2> * D16 = (tile<16, 8, half2> *) &D; + const tile<16, 8, half2> * A16 = (const tile<16, 8, half2> *) &A; + mma(D16[0], A16[0], B); + mma(D16[1], A16[1], B); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) && defined(RDNA4) + } + template static __device__ __forceinline__ void mma( tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) { @@ -1296,6 +1403,22 @@ namespace ggml_cuda_mma { #endif // defined(VOLTA_MMA_AVAILABLE) } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { +#if defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + using halfx16_t = __attribute__((ext_vector_type(16))) _Float16; + halfx16_t * xD = (halfx16_t *) D.x; + const halfx16_t * xA = (const halfx16_t *) A.x; + const halfx16_t * xB = (const halfx16_t *) B.x; + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[0], xB[0], xD[0], /*opsel =*/ 0); + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[1], xB[0], xD[0], /*opsel =*/ 1); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index d9147202429..09d95f309b4 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -21,6 +21,7 @@ static __global__ void mul_mat_vec_f( int channel_y; int sample_dst; + ggml_cuda_pdl_sync(); if constexpr (is_multi_token_id) { // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case token_idx = blockIdx.z; @@ -298,6 +299,7 @@ static __global__ void mul_mat_vec_f( static_assert(std::is_same_v, "unsupported type"); } + ggml_cuda_pdl_lc(); #pragma unroll for (int j = 0; j < ncols_dst; ++j) { sumf[j] = warp_reduce_sum(sumf[j]); @@ -382,11 +384,13 @@ static void mul_mat_vec_f_switch_fusion( const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, nbytes_shared, stream}; + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; @@ -395,8 +399,8 @@ static void mul_mat_vec_f_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index da48f313a38..13b8b855282 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -359,7 +359,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: + return 8; case GGML_TYPE_Q6_K: + return 2; case GGML_TYPE_IQ4_NL: return 8; default: @@ -422,6 +424,7 @@ static __global__ void mul_mat_vec_q( uint32_t channel_y; uint32_t sample_dst; + ggml_cuda_pdl_sync(); channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; sample_dst = blockIdx.z; @@ -681,8 +684,9 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; @@ -691,8 +695,9 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index ef98f675aa7..09d9f3a7d62 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -18,6 +18,7 @@ static __global__ void norm_f32( float2 mean_var = make_float2(0.0f, 0.0f); + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; mean_var.x += xi; @@ -46,6 +47,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int j = start; j < end; j += block_size) { tmp += x[j]; } @@ -95,6 +97,7 @@ static __global__ void rms_norm_f32(const float * x, const uint3 add_nrows_packed = make_uint3(0, 0, 0), const uint3 add_nchannels_packed = make_uint3(0, 0, 0), const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { + ggml_cuda_pdl_lc(); const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -124,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; @@ -163,6 +167,7 @@ static __global__ void rms_norm_back_f32( float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xfi = xf[col]; sum_xx += xfi * xfi; @@ -253,6 +258,7 @@ static __global__ void l2_norm_f32( float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; @@ -261,6 +267,7 @@ static __global__ void l2_norm_f32( // sum up partial sums extern __shared__ float s_sum[]; tmp = block_reduce(tmp, s_sum); + ggml_cuda_pdl_lc(); // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html const float scale = rsqrtf(fmaxf(tmp, eps * eps)); @@ -300,10 +307,19 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = {blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, false>, launch_params, + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, false>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } @@ -346,14 +362,20 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } else { const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); @@ -367,14 +389,16 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims,block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, @@ -399,10 +423,12 @@ static void l2_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - l2_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - l2_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32<1024>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 52f664719ae..49516965cad 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -6,6 +6,7 @@ static __global__ void quantize_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { + ggml_cuda_pdl_lc(); const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { @@ -28,6 +29,7 @@ static __global__ void quantize_q8_1( const int64_t ib = i_cont / QK8_1; // block index const int64_t iqs = i_cont % QK8_1; // quant index + ggml_cuda_pdl_sync(); const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -196,6 +198,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -288,6 +291,7 @@ static __global__ void quantize_mmq_q8_1( const int64_t i3 = blockIdx.z / ne2; const int64_t i00 = i0; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -378,7 +382,8 @@ void quantize_row_q8_1_cuda( const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, block_size, 0, stream); + ggml_cuda_kernel_launch(quantize_q8_1, launch_params, x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); } diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index de240fd4413..5895d3bf8e5 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -10,6 +10,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r const int num_unroll = 8; float temp[num_unroll]; float sum_temp[num_unroll] = { 0.0f }; + + ggml_cuda_pdl_sync(); for (int i = col; i < ncols;) { for (int j = 0; j < num_unroll; ++j) { if (i < ncols) { diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 45a49a5dc2a..e20a5cb6bed 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -134,6 +134,7 @@ static __global__ void rope_neox(const T * x, const float * freq_factors, const int64_t * row_indices, const int set_rows_stride) { + ggml_cuda_pdl_lc(); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne00) { @@ -148,6 +149,7 @@ static __global__ void rope_neox(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. @@ -216,6 +218,7 @@ static __global__ void rope_multi(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; @@ -300,6 +303,7 @@ static __global__ void rope_vision(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); const int sect_dims = sections.v[0] + sections.v[1]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; @@ -399,13 +403,14 @@ static void rope_neox_cuda(const T * x, const dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, 0, stream}; if (freq_factors == nullptr) { - rope_neox<<>>( + ggml_cuda_kernel_launch(rope_neox, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { - rope_neox<<>>( + ggml_cuda_kernel_launch(rope_neox, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } @@ -443,11 +448,13 @@ static void rope_multi_cuda(const T * x, const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { - rope_multi<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { - rope_multi<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 0ddeff6a175..7b2e59a4383 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -3,9 +3,11 @@ #define MAX_GRIDDIM_X 0x7FFFFFFF static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + ggml_cuda_pdl_lc(); int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + ggml_cuda_pdl_sync(); for (int64_t i = tid; i < nelements; i += stride) { dst[i] = scale * x[i] + bias; } @@ -13,7 +15,8 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, bias, nelements); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(scale_f32, launch_params, x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 631de7e8fa5..e14f96b824c 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -53,6 +53,7 @@ static __global__ void k_set_rows_quant(const float * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; @@ -157,7 +158,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + ggml_cuda_pdl_lc(); const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; @@ -203,9 +206,11 @@ static void set_rows_cuda( const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); - k_set_rows<<>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, - s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, - ne11_fd, ne12_fd); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_size, block_size, 0, stream); + ggml_cuda_kernel_launch(k_set_rows, launch_params, + src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, + s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, + ne11_fd, ne12_fd); } } diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu index 40dfe45d65c..9f0fa1051cf 100644 --- a/ggml/src/ggml-cuda/softcap.cu +++ b/ggml/src/ggml-cuda/softcap.cu @@ -1,18 +1,21 @@ #include "softcap.cuh" static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = tanhf(scale * x[i]) * softcap; } static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE; - softcap_f32<<>>(x, dst, scale, softcap, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(softcap_f32, launch_params, x, dst, scale, softcap, k); } // fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4841389fbc8..48787b4b890 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "ssm-conv.cuh" #include "unary.cuh" @@ -7,6 +8,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { + ggml_cuda_pdl_lc(); GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; @@ -23,6 +25,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float float x[d_conv] = { 0.0f }; float w[d_conv] = { 0.0f }; + ggml_cuda_pdl_sync(); #pragma unroll for (size_t j = 0; j < d_conv; j++) { w[j] = w_block[tid * stride_w + j]; @@ -128,8 +131,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, n_t); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_conv_f32, launch_params, src0, src1, bias, src0_nb0, src0_nb1, + src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); @@ -140,11 +144,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa }; switch (nc) { - case 3: launch_kernel(std::integral_constant{}); break; - case 4: launch_kernel(std::integral_constant{}); break; - case 5: launch_kernel(std::integral_constant{}); break; - case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + case 15: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now."); } } diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index c1d4e2bc8df..412980376ac 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -26,6 +26,7 @@ __global__ void __launch_bounds__(splitD, 1) const int64_t s_off, const int64_t d_inner, const int64_t L_param) { const size_t L = L_template == 0 ? L_param : L_template; + ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float)); const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); @@ -135,6 +136,7 @@ __global__ void __launch_bounds__(d_state, 1) const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + ggml_cuda_pdl_sync(); // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); @@ -206,7 +208,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<128/WARP_SIZE, 128>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -215,7 +218,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<256/WARP_SIZE, 256>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -231,58 +235,59 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream); switch (n_tok) { case 1: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 2: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 3: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 4: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 5: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 6: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 7: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 8: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; default: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 4025771aadb..0003658ca95 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -7,10 +7,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int const dim3 block_nums(nrows, 1, 1); if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, x, dst, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, x, dst, ncols); } } @@ -34,10 +36,12 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if ((nrows / nsm) < 2) { // Increase num threads to 512 for small nrows to better hide the latency const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } else { // Enough active SMs to hide latency, use smaller blocks to allow better scheduling const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 59ce36fb1c9..db1d39e2dc7 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -5,6 +5,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) # define CUB_TOP_K_AVAILABLE +# include using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 #endif // GGML_CUDA_USE_CUB diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 3020e5c7433..da20c9aab7c 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -105,6 +105,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * wt[i] = -INFINITY; } + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < n_experts; i += WARP_SIZE) { const int expert = i + threadIdx.x; @@ -161,6 +162,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * output_weights[i] = 0.f; } + ggml_cuda_pdl_lc(); for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; int max_expert = threadIdx.x; @@ -271,51 +273,52 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); cudaStream_t stream = ctx.stream(); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (n_expert) { case 1: - topk_moe_cuda<1, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<1, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 2: - topk_moe_cuda<2, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<2, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 4: - topk_moe_cuda<4, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<4, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 8: - topk_moe_cuda<8, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<8, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 16: - topk_moe_cuda<16, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<16, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 32: - topk_moe_cuda<32, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<32, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 64: - topk_moe_cuda<64, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<64, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 128: - topk_moe_cuda<128, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<128, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 256: - topk_moe_cuda<256, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 512: - topk_moe_cuda<512, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 576: - topk_moe_cuda<576, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<576, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; default: GGML_ASSERT(false && "fatal error"); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 2aeba26f414..4cb805fa601 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -116,19 +116,22 @@ static __device__ __forceinline__ float op_trunc(float x) { template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = (T)op((float)x[i]); } template static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - unary_op_kernel<<>>(x, dst, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_op_kernel, launch_params, x, dst, k); } template @@ -258,6 +261,7 @@ void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { template static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { + ggml_cuda_pdl_lc(); const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { @@ -268,13 +272,15 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t j0 = (i / n) * o0 + (i % n); const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + ggml_cuda_pdl_sync(); dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } template static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_gated_op_kernel, launch_params, x, g, dst, k, n, o0, o1); } template diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d1c9da8329..9db99cb0f3a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2661,7 +2661,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess int mode = op_params[2]; - if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { + if (mode == GGML_ROPE_TYPE_VISION) { return false; } if (mode & 1) { @@ -2735,12 +2735,25 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { return false; } + if (src0->nb[0] != sizeof(float) || src1->nb[0] != sizeof(float) || dst->nb[0] != sizeof(float)) { + return false; + } + if (src0->nb[1] != src0->ne[0] * sizeof(float) || src1->nb[1] != src1->ne[0] * sizeof(float)) { + return false; + } - // TODO: add support for non-contiguous tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return true; +} + +static bool ggml_hexagon_supported_pad(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { return false; } + GGML_UNUSED(sess); return true; } @@ -2816,6 +2829,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { return false; } + if (dst->type != GGML_TYPE_F32) { return false; } + if (!ggml_are_same_shape(src0, dst)) { return false; } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; } + + return true; + + GGML_UNUSED(sess); +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2843,6 +2871,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; case GGML_OP_SCALE: return HTP_OP_SCALE; @@ -2857,6 +2886,9 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_TRI: return HTP_OP_TRI; + case GGML_OP_PAD: return HTP_OP_PAD; + case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3308,10 +3340,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_add_id(sess, op); break; + case GGML_OP_NORM: case GGML_OP_L2_NORM: - supp = ggml_hexagon_supported_unary(sess, op); - break; - case GGML_OP_RMS_NORM: case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); @@ -3416,6 +3446,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_solve_tri(sess, op); break; + case GGML_OP_TRI: + supp = ggml_hexagon_supported_tri(sess, op); + break; + + case GGML_OP_PAD: + supp = ggml_hexagon_supported_pad(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index bcadac11f95..36f923243cd 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -38,6 +38,7 @@ add_library(${HTP_LIB} SHARED diag-ops.c solve-tri-ops.c gated-delta-net-ops.c + pad-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index e5b9d350fd7..5c040a32224 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -88,6 +88,29 @@ static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + // Fast path: when both src0 and dst are contiguous in memory + // Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice. + const bool src0_contig = (nb00 == ct->src0_type_size) && + (nb01 == ne00 * nb00) && + (nb02 == ne01 * nb01) && + (nb03 == ne02 * nb02); + const bool dst_contig = (nb0 == ct->dst_type_size) && + (nb1 == ne0 * nb0) && + (nb2 == ne1 * nb1) && + (nb3 == ne2 * nb2); + + if (src0_contig && dst_contig) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size; + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size); + } + } + return; + } + // dst counters int64_t k10 = 0; int64_t i11 = 0; diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 4a4ff0b331d..9e1b778b01f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -852,9 +852,10 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. - // vror brings the target lane to lane 0, then extract + re-splat. - HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2))); - HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2))); + // vror brings the target lane to lane 0, then vdelta replicates it + // across all lanes — stays in the vector domain (no store/reload). + HVX_Vector v_m_prev0 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)); + HVX_Vector v_m_prev1 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)); // HVX max — both operands are splats, so result is splat of m_new. HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index e05ccfd5fc7..3ef0bcdb26d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -201,11 +201,10 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32 // Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using // full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. -// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes. -static inline void dequantize_x4x2_q4_0_x4groups_hvx( +// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. +static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt, - HVX_Vector out[4]) { + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { // Load all 128 packed bytes (4 contiguous 32-byte groups) HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); @@ -221,8 +220,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b - volatile HVX_Vector vscale = hvx_vmemu(scales_4); - + HVX_Vector vscale = hvx_vmemu(scales_4); HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); @@ -230,8 +228,9 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - out[0] = v_lo; // group0 already in [0:63] - out[1] = v_hi; // group2 already in [0:63] + HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */ + v_hi /* group2 already in [0:63] */ }; + return r; } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. @@ -292,12 +291,11 @@ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed } // Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). -static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, +static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, bool upper_nibbles, int sub_blk_base, const HVX_Vector vlut_cvt, - mxfp4_scales_t scales, - HVX_Vector out[4]) { + mxfp4_scales_t scales) { HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; @@ -318,10 +316,8 @@ static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_12 v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - out[0] = v_lo; - out[1] = Q6_V_vror_VR(v_lo, 64); - out[2] = v_hi; - out[3] = Q6_V_vror_VR(v_hi, 64); + HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; + return r; } // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. @@ -372,18 +368,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - HVX_Vector v0[2]; const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -415,21 +411,21 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - HVX_Vector v0[4], v1[4]; - dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0); + HVX_Vector_x4 dv0, dv1; + dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); if (row1 < n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1); + dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); } else { - v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); } for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -612,11 +608,13 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; @@ -832,10 +830,6 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); } -// - -#define FALLBACK_TO_STANDARD 1 - // C += AB static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, @@ -861,314 +855,80 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); } - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } -} - -static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, - float *restrict out, const float *restrict x, const uint8_t *restrict w, - int m, int k, int n, int weight_type) { - // assume k % 32 == 0 && n % 32 == 0 - const size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - const size_t vtcm_budget = ctx->vtcm_size; - - const size_t K_BLOCK_SIZE = 1024; - - // Fallback: if k doesn't need K-blocking, out-stationary has no advantage - const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; - if (k_iters_check <= 1) { - FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); - return FALLBACK_TO_STANDARD; - } - - // Dynamic M,N search via hmx_compute_chunks - const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); - const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) - + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) - const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) - + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) - const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - - // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; - // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin - const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; - const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment - - size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; - // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. - // From profiling: wt_dequant per element ≈ 1.5× activation load per element. - // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). - // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). - const size_t m_block_cost = (size_t) n * 3; - const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - m_block_cost, n_block_cost, &M_BLOCK_SIZE, - &N_BLOCK_SIZE, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } - - // Compute precise buffer sizes from searched M,N and fixed K - const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); - const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - - const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; - if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, - vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); - return -1; - } - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); - uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); - uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); - __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, - M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - // initialize eye tile (32x32 identity matrix) - { - HVX_Vector v; - v = Q6_V_vzero(); - v = Q6_Vw_vinsert_VwR(v, 0x3c000000); - v = Q6_V_vror_VR(v, VLEN - 4); - v = Q6_Vw_vinsert_VwR(v, 0x00003c00); - for (int i = 0; i < 16; ++i) { - ((HVX_Vector *) vtcm_eye_tile)[i] = v; - v = Q6_V_vror_VR(v, VLEN - 8); - } - } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - TIMER_DEFINE(fetch); - TIMER_DEFINE(act_load); - TIMER_DEFINE(wt_dequant); - TIMER_DEFINE(core); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { - size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); - for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { - size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); - - const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); - - for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); - - TIMER_START(fetch); - // fetch activation block into VTCM - { - const float *activation_block = x + mr * k + kk; - - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_scratch1, activation_block), - k_blk_sz * sizeof(float), - k * sizeof(float), - k_blk_sz * sizeof(float), - m_blk_sz); - } - - // fetch weight block into VTCM (x4x2 sub-block: quants + scales) - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); - { - const int blk_start = kk / QK_Q4_0x4x2; - const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; - uint8_t *dst = vtcm_scratch0; - const uint8_t *src = w + nc * row_stride; - const size_t n_rows = n_blk_sz; - const size_t src_stride = row_stride; - const size_t dst_stride = sub_row_stride; - const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); - const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); - const size_t scale_off = full_qrow + blk_start * scale_blk_size; - const size_t scale_width = nb_sub * scale_blk_size; - - // 2D DMA: quants sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); - // 2D DMA: scales sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); - } - TIMER_STOP(fetch); - - TIMER_START(act_load); - // load activation block - { - dma_queue_pop(ctx->dma[0]); // wait for act DNA - transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); - } - TIMER_STOP(act_load); - - TIMER_START(wt_dequant); - // dequantize weight block - { - dma_queue_pop(ctx->dma[0]); - dma_queue_pop(ctx->dma[0]); - // vtcm_scratch0 is used to store the qweight chunk - // worker_pool_run_func already returned, so fetch is done - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, - n_blk_sz, k_blk_sz, sub_row_stride, weight_type); - } - TIMER_STOP(wt_dequant); - - // core mma - TIMER_START(core); - { - core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, - n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); - } - TIMER_STOP(core); + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } - // store output block - { - float *output_block = out + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); - } + Q6_mxmem_AR_after_hf(accum_tile, 0); } } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", - TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); -#endif - return 0; } -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, int weight_type) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { return -1; } - // for large m, k (e.g. prefill FFN Down), use out-stationary version - if (m >= 128 && k > n && n > 1024) { - int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); - if (rc != FALLBACK_TO_STANDARD) { - return rc; // 0 success, -1 error - } - FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); - // fall through to standard path - } - size_t row_stride = get_x4x2_row_stride(weight_type, k); if (row_stride == 0) { return -1; } - FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - // Only pays off when the chunker yields >=2 n-chunks, so the main loop can - // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for - // double-buffered output and the worker-dispatch overhead are pure loss. - // Try pipeline costs first; fall back to sequential if the layout collapses - // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. - const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) - const size_t seq_per_mn = sizeof(__fp16); // O x 1 - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - bool use_pipeline = false; - - if (m >= 128) { - size_t mc = 0, nc = 0, used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && - hmx_ceil_div((size_t) n, nc) >= 2) { - m_chunk_n_rows = mc; - n_chunk_n_cols = nc; - vtcm_used = used; - use_pipeline = true; - } - } + const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - if (!use_pipeline) { - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + return -1; } - // Compute precise buffer sizes per execution path - const size_t weight_area_size = hex_align_up( - n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up( - m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); size_t scratch0_size, scratch1_size, scratch2_size; - if (use_pipeline) { - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 - } else { - scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 - scratch1_size = scratch0_size; // x4x2 DMA buf 1 - scratch2_size = 0; // unused - } + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, weight_type, use_pipeline, - m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); TIMER_DEFINE(weight_load); @@ -1178,184 +938,115 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_DEFINE(total); TIMER_START(total); - FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", - use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) + // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - if (!use_pipeline) { - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + // A --> B: vtcm_qweight, 1 buffer + // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers + // C --> D: vtcm_output0/vtcm_output1, 2 buffers - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } - TIMER_STOP(activation_load); + // Async timeline (C overlaps B+D): + // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] + // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - - const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); - } + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - // Dequant + vscatter writes directly to [K, N] transposed tiles. - // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } else { - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers - - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); - } - - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); - } + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); } + } - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + } - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's + // counterpart — and (i+1)%2 was last used by C_{i-1} which completed + // before C_i was submitted. + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); - } + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); } } - - hmx_queue_suspend(ctx->hmx_queue); } + hmx_queue_suspend(ctx->hmx_queue); + TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline); + FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1370,15 +1061,15 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // -static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; } -static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; } -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { const int r2 = hmx_matmul_batch_r2(params); const int r3 = hmx_matmul_batch_r3(params); @@ -1387,37 +1078,36 @@ static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_ (size_t) (dst_b3 / r3) * params->src0_nb3); } -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (const float *) ((const uint8_t *) params->activation + (size_t) dst_b2 * params->src1_nb2 + (size_t) dst_b3 * params->src1_nb3); } -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (float *) ((uint8_t *) params->dst + (size_t) dst_b2 * params->dst_nb2 + (size_t) dst_b3 * params->dst_nb3); } -static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params) { +static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_f16_f32_batched_params_t *params) { int ret = 0; for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_mat_mul_permuted_w16a32(ctx, - hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); + ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); } } return ret; } -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } if (!params->m || !params->k || !params->n) { return -1; } if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } @@ -1435,7 +1125,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (group_size <= 1) { FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } // Grouped path: reuse interleaved weight across all q_heads sharing a @@ -1464,7 +1154,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu /*m_block_cost=*/(size_t) params->n, /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads @@ -1486,7 +1176,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 @@ -1614,7 +1304,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu // -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index 1c78ffadd1c..f114edb822f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -33,14 +33,14 @@ typedef struct { size_t src1_nb3; size_t dst_nb2; size_t dst_nb3; -} hmx_matmul_w16a32_batched_params_t; +} hmx_matmul_f16_f32_batched_params_t; // HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output // act_stride: activation row stride in elements (= k for contiguous, or // nb[1]/sizeof(float) for permuted tensors like attention Q). // weight_stride: weight row stride in elements (= k for compact weights, or // nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const __fp16 *permuted_weight, @@ -48,13 +48,12 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, int act_stride, int weight_stride); -// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32. +// Batched F16 wrapper over hmx_mat_mul_f16_f32. // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params); +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL) -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, +// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 92f02eac6e3..6fe3e6c7d85 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -107,5 +107,7 @@ int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); +int op_tri(struct htp_ops_context * octx); +int op_pad(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 98db864dd42..9d905a30133 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -86,6 +86,9 @@ enum htp_op_code { HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, HTP_OP_GATED_DELTA_NET, + HTP_OP_TRI, + HTP_OP_PAD, + HTP_OP_NORM, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 883a31d6163..e8619388478 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -87,35 +87,37 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 { - // Power on HMX + // Power on HMX and set HMX clock HAP_power_request_t request; memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX; - request.hmx.power_up = TRUE; - FARF(ALWAYS, "Powering HMX on\n"); - err = HAP_power_set((void *) &ctx, &request); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_power = TRUE; + request.hmx_v2.power_up = TRUE; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "Error setting HMX clock."); return err; } } - -#if __HVX_ARCH__ >= 75 +#else { - // Set HMX clock + // Power on HMX HAP_power_request_t request; memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX_v2; - request.hmx_v2.set_clock = TRUE; - request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; - FARF(ALWAYS, "Setting HMX clock\n"); - err = HAP_power_set((void *) &ctx, &request); + request.type = HAP_power_set_HMX; + request.hmx.power_up = TRUE; + FARF(ALWAYS, "Powering HMX on\n"); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "Error powering on HMX."); return err; } } @@ -534,6 +536,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_ADD_ID: return op_binary(octx); + case HTP_OP_NORM: case HTP_OP_RMS_NORM: case HTP_OP_SCALE: case HTP_OP_SQR: @@ -595,9 +598,15 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_SOLVE_TRI: return op_solve_tri(octx); + case HTP_OP_PAD: + return op_pad(octx); + case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); + case HTP_OP_TRI: + return op_tri(octx); + case HTP_OP_INVALID: break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 2461ae617fa..46fc5602dc9 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -2995,7 +2995,6 @@ int op_matmul(struct htp_ops_context * octx) { // is handled by HMX itself; when M < 32 fall back to HVX. const int m_total = (int) src1->ne[1]; const int m_hmx = m_total & ~31; // 0 when M < 32 - if (m_hmx == 0) { return op_matmul_hvx(octx); } @@ -3020,7 +3019,7 @@ int op_matmul(struct htp_ops_context * octx) { if (src0->type == HTP_TYPE_F16) { if (is_batched) { - hmx_matmul_w16a32_batched_params_t batch_params = { + hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, .permuted_weight = (const __fp16 *) src0->data, @@ -3041,15 +3040,14 @@ int op_matmul(struct htp_ops_context * octx) { .dst_nb2 = dst->nb[2], .dst_nb3 = dst->nb[3], }; - ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params); + ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_mat_mul_permuted_w16a32(octx->ctx, + ret = hmx_matmul_f16_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, m_total, k, n, act_stride, wgt_stride); } } else { - ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, m_total, k, n, (int) src0->type); } diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c new file mode 100644 index 00000000000..3abc3c2ead1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -0,0 +1,545 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +/* Circular wrap: maps any integer x into [0, n) */ +static inline uint32_t wrap_around(int32_t x, uint32_t n) { + return (uint32_t)(((x % (int32_t)n) + (int32_t)n) % (int32_t)n); +} + +/* Decompose a flat dst row index into (i1, i2, i3) */ +static inline void pad_decompose_row(uint32_t ir, uint32_t ne1, uint32_t ne2, + uint32_t *i1, uint32_t *i2, uint32_t *i3) { + *i1 = ir % ne1; + *i2 = (ir / ne1) % ne2; + *i3 = ir / (ne1 * ne2); +} + +/* Return non-zero if row (i1,i2,i3) falls in the non-padded interior */ +static inline int pad_is_interior(uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t rp1, uint32_t ne1, + int32_t lp2, int32_t rp2, uint32_t ne2, + int32_t lp3, int32_t rp3, uint32_t ne3) { + return ((int32_t)i1 >= lp1 && (int32_t)i1 < (int32_t)ne1 - rp1) && + ((int32_t)i2 >= lp2 && (int32_t)i2 < (int32_t)ne2 - rp2) && + ((int32_t)i3 >= lp3 && (int32_t)i3 < (int32_t)ne3 - rp3); +} + +/* Compute the DDR src row pointer for a zero-pad interior row */ +static inline const uint8_t * pad_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + (i1 - (uint32_t)lp1) * src->nb[1] + + (i2 - (uint32_t)lp2) * src->nb[2] + + (i3 - (uint32_t)lp3) * src->nb[3]; +} + +/* Compute the DDR src row pointer for a circular row (wrap-around indexing) */ +static inline const uint8_t * pad_circ_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + wrap_around((int32_t)i1 - lp1, src->ne[1]) * src->nb[1] + + wrap_around((int32_t)i2 - lp2, src->ne[2]) * src->nb[2] + + wrap_around((int32_t)i3 - lp3, src->ne[3]) * src->nb[3]; +} + +struct htp_pad_context { + struct htp_ops_context * octx; + + int32_t lp0, rp0; + int32_t lp1, rp1; + int32_t lp2, rp2; + int32_t lp3, rp3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; + + size_t type_size; + + // Row sizes for DMA kernel (populated when VTCM is available) + size_t src_row_size; + size_t src_row_size_aligned; + size_t dst_row_size; + size_t dst_row_size_aligned; +}; + +#define htp_pad_preamble \ + const struct htp_tensor * src = octx->src[0]; \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne00 = src->ne[0]; \ + const uint32_t nb00 = src->nb[0]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const int32_t lp0 = pctx->lp0, rp0 = pctx->rp0; \ + const int32_t lp1 = pctx->lp1, rp1 = pctx->rp1; \ + const int32_t lp2 = pctx->lp2, rp2 = pctx->rp2; \ + const int32_t lp3 = pctx->lp3, rp3 = pctx->rp3; \ + \ + const size_t type_size = pctx->type_size; \ + \ + const uint32_t row_start = pctx->nrows_per_thread * ith; \ + const uint32_t row_end = MIN(row_start + pctx->nrows_per_thread, pctx->total_dst_rows); + + +#define htp_pad_dma_preamble \ + const size_t src_row_size = pctx->src_row_size; \ + const size_t src_row_size_aligned = pctx->src_row_size_aligned; \ + const size_t dst_row_size = pctx->dst_row_size; \ + const size_t dst_row_size_aligned = pctx->dst_row_size_aligned; \ + \ + uint8_t * src_spad_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; \ + uint8_t * dst_spad_base = octx->dst_spad.data + ith * octx->dst_spad.size_per_thread; \ + \ + dma_queue * dma = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX vectorized PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_u(dst_ptr, 0.0f, ne0); + } else { + const uint8_t * src_ptr = pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (lp0 > 0) { + hvx_splat_f32_u(dst_ptr, 0.0f, (uint32_t)lp0); + } + + uint8_t * dst_row_start = dst_ptr + (size_t)lp0 * type_size; + if (nb00 == type_size) { + hvx_copy_f32_uu(dst_row_start, src_ptr, ne00); + } else { + for (uint32_t i = 0; i < ne00; i++) { + memcpy(dst_row_start + i * type_size, + src_ptr + (size_t)i * nb00, + type_size); + } + } + + if (rp0 > 0) { + hvx_splat_f32_u(dst_ptr + ((size_t)lp0 + ne00) * type_size, 0.0f, (uint32_t)rp0); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline before the main loop begins. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + const uint8_t * src_ptr = interior + ? pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3) : NULL; + + // Interior row: real DMA (1 row) from DDR to VTCM. + // Border row: null DMA (nrows=0) + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + src_ptr ? src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, src_ptr ? 1 : 0); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, compute in VTCM with aligned HVX ops, + // push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + } else { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + + uint8_t * dst_interior = dst_spad_cur + (size_t)lp0 * type_size; + + if ((uintptr_t)dst_interior % VLEN == 0) { + hvx_copy_f32_aa(dst_interior, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_interior, src_spad_cur, ne00); + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t ni1, ni2, ni3; + pad_decompose_row(next_row, ne1, ne2, &ni1, &ni2, &ni3); + const int next_interior = pad_is_interior(ni1, ni2, ni3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + const uint8_t * next_src_ptr = next_interior + ? pad_src_row_ptr(src, ni1, ni2, ni3, lp1, lp2, lp3) : NULL; + + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + next_src_ptr ? next_src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, next_src_ptr ? 1 : 0); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX circular PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + const uint8_t * src_row = pad_circ_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (nb00 == type_size) { + + if (lp0 > 0) { + if ((uint32_t)lp0 < 32) { + memcpy(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (uint32_t)lp0); + } + } + hvx_copy_f32_uu(dst_ptr + (size_t)lp0 * type_size, src_row, ne00); + if (rp0 > 0) { + if ((uint32_t)rp0 < 32) { + memcpy(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (size_t)rp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (uint32_t)rp0); + } + } + } else { + for (uint32_t i = 0; i < (uint32_t)lp0; i++) { + *(float *)(dst_ptr + i * type_size) = + *(const float *)(src_row + (size_t)(ne00 - (uint32_t)lp0 + i) * nb00); + } + for (uint32_t i = 0; i < ne00; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + for (uint32_t i = 0; i < (uint32_t)rp0; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + ne00 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA circular PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline. Every row is a real src DMA (no null DMAs). + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t pi1, pi2, pi3; + pad_decompose_row(ir, ne1, ne2, &pi1, &pi2, &pi3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, pad_circ_src_row_ptr(src, pi1, pi2, pi3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, assemble circular row in VTCM with + // aligned HVX ops, push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + + if (lp0 > 0) { + uint8_t * dst_left = dst_spad_cur; + const uint8_t * src_left = src_spad_cur + (size_t)(ne00 - (uint32_t)lp0) * type_size; + if ((uint32_t)lp0 < 32) { + memcpy(dst_left, src_left, (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_left, src_left, (uint32_t)lp0); + } + } + + { + uint8_t * dst_mid = dst_spad_cur + (size_t)lp0 * type_size; + if ((uintptr_t)dst_mid % VLEN == 0) { + hvx_copy_f32_aa(dst_mid, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_mid, src_spad_cur, ne00); + } + } + + if (rp0 > 0) { + uint8_t * dst_right = dst_spad_cur + ((size_t)lp0 + ne00) * type_size; + if ((uint32_t)rp0 < 32) { + memcpy(dst_right, src_spad_cur, (size_t)rp0 * type_size); + } else { + if ((uintptr_t)dst_right % VLEN == 0) { + hvx_copy_f32_aa(dst_right, src_spad_cur, (uint32_t)rp0); + } else { + hvx_copy_f32_ua(dst_right, src_spad_cur, (uint32_t)rp0); + } + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t nri1, nri2, nri3; + pad_decompose_row(next_row, ne1, ne2, &nri1, &nri2, &nri3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + pad_circ_src_row_ptr(src, nri1, nri2, nri3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_pad(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + // Only F32 supported + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + default: + FARF(ERROR, "pad-hvx: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int32_t lp0 = octx->op_params[0]; + const int32_t rp0 = octx->op_params[1]; + const int32_t lp1 = octx->op_params[2]; + const int32_t rp1 = octx->op_params[3]; + const int32_t lp2 = octx->op_params[4]; + const int32_t rp2 = octx->op_params[5]; + const int32_t lp3 = octx->op_params[6]; + const int32_t rp3 = octx->op_params[7]; + const int32_t circular = octx->op_params[8]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne00 = src0->ne[0]; + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows > 0 ? total_dst_rows : 1); + + const size_t src_row_size = (size_t)ne00 * type_size; + const size_t dst_row_size = (size_t)ne0 * type_size; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // Total VTCM needed: 2 buffers (ping+pong) for src and dst, per thread + const size_t vtcm_needed = (size_t)n_threads * 2 * (src_row_size_aligned + dst_row_size_aligned); + + const int use_dma = (src0->nb[0] == (uint32_t)type_size) && + (ne00 >= 512) && + (octx->ctx->vtcm_base != NULL) && + (octx->ctx->vtcm_size >= vtcm_needed); + + if (use_dma) { + octx->src0_spad.size_per_thread = 2 * src_row_size_aligned; + octx->dst_spad.size_per_thread = 2 * dst_row_size_aligned; + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + } + + struct htp_pad_context pctx = { + .octx = octx, + .lp0 = lp0, .rp0 = rp0, + .lp1 = lp1, .rp1 = rp1, + .lp2 = lp2, .rp2 = rp2, + .lp3 = lp3, .rp3 = rp3, + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + .src_row_size = src_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size = dst_row_size, + .dst_row_size_aligned = dst_row_size_aligned, + }; + + FARF(HIGH, "pad-hvx%s%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) pads=(%d,%d,%d,%d,%d,%d,%d,%d)\n", + circular ? "-circ" : "", + use_dma ? "-dma" : "", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + + if (circular && use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular_dma, &pctx, n_threads); } + else if (circular) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular, &pctx, n_threads); } + else if (use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_dma, &pctx, n_threads); } + else { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx, &pctx, n_threads); } + + return HTP_STATUS_OK; +} + diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 1d8b0796bc9..b398e19f06e 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -18,9 +18,11 @@ #include "htp-ops.h" #include "htp-ops.h" -// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h +// Redefined the rope type constants as we can't include ggml.h #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 +#define HTP_ROPE_TYPE_MROPE 8 +#define HTP_ROPE_TYPE_IMROPE 40 #define HTP_ROPE_SPAD_NROWS 16 #define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2) @@ -82,7 +84,30 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return (1 - MIN(1, MAX(0, y))); } -static void rope_cache_init(const float theta_base, +// Compute one (cos, sin) pair into cache[i0], cache[i0+1] applying YaRN scaling. +static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dims, + uint32_t i0, float ext_factor, float mscale, + float * cache) { + float theta_extrap = theta; + + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta_final = theta_interp; + float mscale_final = mscale; + + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + + cache[i0 + 0] = cosf(theta_final) * mscale_final; + cache[i0 + 1] = sinf(theta_final) * mscale_final; +} + +static __attribute__((noinline)) void rope_cache_init(const float theta_base, const float freq_scale, const float * freq_factors, float * corr_dims, @@ -96,26 +121,62 @@ static void rope_cache_init(const float theta_base, for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - float theta_extrap = theta / ff; - - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta_final = theta_interp; - float mscale_final = mscale; + theta *= theta_scale; + } +} - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; +// pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra). +// sections[4]: number of head dims assigned to each position component. +static __attribute__((noinline)) void mrope_cache_init(const float pos_t, + const float pos_h, + const float pos_w, + const float pos_e, + const int32_t sections[4], + const bool is_imrope, + const float freq_scale, + const float * freq_factors, + float * corr_dims, + const uint32_t ne0, + const float ext_factor, + const float mscale, + float * cache, + const float theta_scale) { + const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + const int sec_w = sections[0] + sections[1]; + const int sec_e = sec_w + sections[2]; + + float theta_t = pos_t; + float theta_h = pos_h; + float theta_w = pos_w; + float theta_e = pos_e; - // Get n-d magnitude scaling corrected for interpolation - mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + const int sector = (i0 / 2) % sect_dims; + + float theta; + if (is_imrope) { + // Interleaved: sector mod 3 selects component + if (sector % 3 == 0 && sector < 3 * sections[0]) { theta = theta_t; } + else if (sector % 3 == 1 && sector < 3 * sections[1]) { theta = theta_h; } + else if (sector % 3 == 2 && sector < 3 * sections[2]) { theta = theta_w; } + else { theta = theta_e; } + } else { + // Contiguous sections + if (sector < sections[0]) { theta = theta_t; } + else if (sector < sec_w) { theta = theta_h; } + else if (sector < sec_e) { theta = theta_w; } + else { theta = theta_e; } } - cache[i0 + 0] = cosf(theta_final) * mscale_final; - cache[i0 + 1] = sinf(theta_final) * mscale_final; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - theta *= theta_scale; + theta_t *= theta_scale; + theta_h *= theta_scale; + theta_w *= theta_scale; + theta_e *= theta_scale; } } @@ -274,7 +335,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { uint64_t tt = HAP_perf_get_qtimer_count(); const int32_t mode = rctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + // MROPE and IMROPE use NEOX-style pairing for the rotation + const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX) || (mode & HTP_ROPE_TYPE_MROPE); // VTCM setup uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); @@ -326,8 +388,25 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { if (i2 != prev_i2) { prev_i2 = i2; - const int32_t p = pos[i2]; - rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); + const bool is_mrope = (rctx->mode & HTP_ROPE_TYPE_MROPE) != 0; + if (is_mrope) { + // src1 holds four position arrays stacked along ne0: + // pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3] + const bool is_imrope = (rctx->mode == HTP_ROPE_TYPE_IMROPE); + mrope_cache_init( + (float) pos[i2], + (float) pos[i2 + ne2], + (float) pos[i2 + ne2 * 2], + (float) pos[i2 + ne2 * 3], + rctx->sections, is_imrope, + rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } else { + rope_cache_init(pos[i2], rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index a28fd03e978..d574da2e2bc 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -20,55 +20,56 @@ #include "htp-ops.h" #include "hvx-utils.h" -#define htp_ssm_conv_tensors_preamble \ - const struct htp_tensor * restrict src0 = octx->src[0]; \ - const struct htp_tensor * restrict src1 = octx->src[1]; \ - const struct htp_tensor * restrict dst = octx->dst; \ - struct htp_spad * restrict src0_spad = &octx->src0_spad; \ - struct htp_spad * restrict src1_spad = &octx->src1_spad; \ - struct htp_spad * restrict dst_spad = &octx->dst_spad; \ - \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_ssm_conv_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_ssm_conv_context { struct htp_ops_context * octx; uint32_t nrows_per_thread; + uint32_t d_inner_tile; uint64_t t_start; }; -#define htp_ssm_conv_preamble \ +#define htp_ssm_conv_preamble \ struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ - struct htp_ops_context * octx = scctx->octx; \ - htp_ssm_conv_tensors_preamble; \ - dma_queue * dma_queue = octx->ctx->dma[ith]; + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; // Scalar FP32 SSM_CONV implementation static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { @@ -128,118 +129,211 @@ static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *da dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension -static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { - htp_ssm_conv_preamble; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - const int nc = src1->ne[0]; // d_conv - const int ncs = src0->ne[0]; // d_conv - 1 + n_t +// In-register 32x32 fp32 transpose using std 5-stage HVX vshuff butterfly. +static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { + HVX_Vector tmp[32]; - const uint32_t d_conv = src1->ne[0]; - const uint32_t d_inner = src0->ne[1]; - const uint32_t n_t = dst->ne[1]; - const uint32_t n_s = dst->ne[2]; + // Stage 0 (R = -4): pair (2i, 2i+1) for i = 0..15. m -> tmp. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[2*i + 1], m[2*i], -4); + tmp[2*i + 0] = Q6_V_lo_W(p); + tmp[2*i + 1] = Q6_V_hi_W(p); + } - const float * src0_data = (const float *) src0->data; - const float * src1_data = (const float *) src1->data; - float * dst_data = (float *) dst->data; + // Stage 1 (R = -8): per block of 4, pair (b+0, b+2) and (b+1, b+3). tmp -> m. + for (int b = 0; b < 32; b += 4) { + HVX_VectorPair p0 = Q6_W_vshuff_VVR(tmp[b + 2], tmp[b + 0], -8); + HVX_VectorPair p1 = Q6_W_vshuff_VVR(tmp[b + 3], tmp[b + 1], -8); + m[b + 0] = Q6_V_lo_W(p0); m[b + 1] = Q6_V_hi_W(p0); + m[b + 2] = Q6_V_lo_W(p1); m[b + 3] = Q6_V_hi_W(p1); + } - // Calculate row range for this thread - const int dr = scctx->nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = MIN(ir0 + dr, d_inner); - const uint32_t ir = ir1 - ir0; + // Stage 2 (R = -16): per block of 8, pair (b+i, b+i+4) for i = 0..3. m -> tmp. + for (int b = 0; b < 32; b += 8) { + for (int i = 0; i < 4; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[b + i + 4], m[b + i], -16); + tmp[b + 2*i + 0] = Q6_V_lo_W(p); + tmp[b + 2*i + 1] = Q6_V_hi_W(p); + } + } - if (ir0 >= ir1) { - return; // No work for this thread + // Stage 3 (R = -32): per block of 16, pair (b+i, b+i+8) for i = 0..7. tmp -> m. + for (int b = 0; b < 32; b += 16) { + for (int i = 0; i < 8; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(tmp[b + i + 8], tmp[b + i], -32); + m[b + 2*i + 0] = Q6_V_lo_W(p); + m[b + 2*i + 1] = Q6_V_hi_W(p); + } } - // src0 and src1 gather offsets - uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 }; - uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 }; + // Stage 4 (R = -64): pair (i, i+16) for i = 0..15. m -> tmp -> m. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[i + 16], m[i], -64); + tmp[2 * i + 0] = Q6_V_lo_W(p); + tmp[2 * i + 1] = Q6_V_hi_W(p); + } - for (uint32_t i = 0; i < VLEN_FP32; ++i) { - src0_offsets[i] = i * (ncs) * sizeof(float); - src1_offsets[i] = i * (d_conv) * sizeof(float); + for (int i = 0; i < 32; ++i) { + m[i] = tmp[i]; } +} - const uint32_t src0_gather_len = VLEN * ncs; - const uint32_t src1_gather_len = VLEN * d_conv; +// HVX FP32 SSM_CONV implementation - channel-vectorized HVX kernel with src0/src1 +// transposed into VTCM. +// +// VTCM layouts (per thread): +// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small). +// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile. +// +// d_inner_tile is chosen so that per-thread VTCM stays under the budget. +// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. +#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread + +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +static inline void transpose_src1(const float * src1_data, + uint32_t src1_stride_inner, + uint32_t i1_off, + uint32_t d_inner_per_thread, + uint32_t d_conv, + float * src1_T) { + for (uint32_t i = 0; i < d_inner_per_thread; ++i) { + const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; + for (uint32_t j = 0; j < d_conv; ++j) { + src1_T[j * d_inner_per_thread + i] = src_row[j]; + } + } +} - // gather scratchpads - HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0); - HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN); +// HVX 32x32 src0 transpose: src0 {ncs, d_inner} (DDR) -> src0_T {d_inner_tile, ncs} (VTCM) +static inline void transpose_src0_block(const float * src0_block, + uint32_t ncs, + uint32_t cb_n, + uint32_t d_inner_tile, + float * src0_T_block_dst, + uint32_t cb /* dst column offset */) { + const uint32_t T_TILE = VLEN_FP32; + + HVX_Vector __attribute__((aligned(VLEN))) sub[32]; + + for (uint32_t t0 = 0; t0 < ncs; t0 += T_TILE) { + const uint32_t t_n = MIN(T_TILE, ncs - t0); + + // Load 32 rows (channels) of T_TILE samples; pad missing channels with zeros. + for (uint32_t r = 0; r < cb_n; ++r) { + const float * src_row = src0_block + r * ncs + t0; + if (t_n == T_TILE) { + sub[r] = *(const HVX_UVector *) src_row; + } else { + HVX_Vector v = hvx_vec_splat_f32(0.0f); + hvx_vec_store_u(&v, t_n * sizeof(float), hvx_vec_splat_f32(0.0f)); + + float __attribute__((aligned(VLEN))) tmp[VLEN_FP32] = { 0 }; + for (uint32_t k = 0; k < t_n; ++k) tmp[k] = src_row[k]; + v = *(const HVX_Vector *) tmp; + sub[r] = v; + } + } + for (uint32_t r = cb_n; r < T_TILE; ++r) { + sub[r] = hvx_vec_splat_f32(0.0f); + } - float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]); - float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]); + hvx_transpose_32x32_f32(sub); - uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; - uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + // Store transposed sub-tile to src0_T at offsets (t0 + j) * d_inner_tile + cb. + // Only write the valid t_n rows of the transposed result. + for (uint32_t r = 0; r < t_n; ++r) { + float * dst = src0_T_block_dst + (t0 + r) * d_inner_tile + cb; + if (cb_n == T_TILE) { + *(HVX_UVector *) dst = sub[r]; + } else { + hvx_vec_store_u(dst, cb_n * sizeof(float), sub[r]); + } + } + } +} - // copy src1 workload to VTCM - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir); +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; - // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir); + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); - for (uint32_t i3 = 0; i3 < n_s; ++i3) { - float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2])); + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + const uint32_t ncs = src0->ne[0]; - // copy src0 workload to VTCM - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir); + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); - // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir); + const uint32_t dr = scctx->nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); - dma_queue_flush(dma_queue); + if (ir0 >= ir1) { + return; + } - for (uint32_t i2 = 0; i2 < n_t; ++i2) { - float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2])); + const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_tile = scctx->d_inner_tile; - const uint32_t nvec = ir / VLEN_FP32; - const uint32_t nloe = ir % VLEN_FP32; - uint32_t i1 = 0; + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; - for (uint32_t vi1 = 0; vi1 < nvec; vi1++) { - HVX_Vector acc_vec = Q6_V_vsplat_R(0); + // Per-thread VTCM regions. + float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); + float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); - uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); - Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. + transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); - acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); - } + const uint32_t C_TILE = VLEN_FP32; - *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec); - i1 += VLEN_FP32; - } + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t tile_off = 0; tile_off < d_inner_per_thread; tile_off += d_inner_tile) { + const uint32_t tile_n = MIN(d_inner_tile, d_inner_per_thread - tile_off); - if (nloe) { - HVX_Vector acc_vec = Q6_V_vsplat_R(0); + // Place src0 chunk into VTCM in {d_inner_tile, ncs} layout. + const float * src0_block = src0_data + i3 * src0_stride_seq + (ir0 + tile_off) * src0_stride_inner; - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); - uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); - Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + transpose_src0_block(src0_block + cb * src0_stride_inner, ncs, cb_n, d_inner_tile, src0_T, cb); + } - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); - acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + for (uint32_t t = 0; t < n_t; ++t) { + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + + HVX_Vector acc = hvx_vec_splat_f32(0.0f); + for (uint32_t j = 0; j < d_conv; ++j) { + HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); + } + HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); + + float * dst_ptr = dst_data + i3 * dst_stride_seq + t * dst_stride_token + (ir0 + tile_off + cb); + if (cb_n == C_TILE) { + *(HVX_UVector *) dst_ptr = res; + } else { + hvx_vec_store_u(dst_ptr, cb_n * sizeof(float), res); + } } - - hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec)); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) tile=%u * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, d_inner_tile, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } @@ -264,46 +358,44 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t use_hvx = 0; - if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) { - int is_aligned = hex_is_aligned((void *) src0->data, VLEN) && - hex_is_aligned((void *) src1->data, VLEN) && - hex_is_aligned((void *) dst->data, VLEN); - - if (is_aligned) { - use_hvx = 1; - } + if (d_inner >= VLEN_FP32 && n_t >= VLEN_FP32) { + use_hvx = 1; } - if (use_hvx) { - scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread - scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even + scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; + scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); - octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256); - octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256); - octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256); + const uint32_t d_inner_per_thread = scctx.nrows_per_thread; + const uint32_t ncs = src0->ne[0]; - octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; + const uint32_t src1_T_size = hex_round_up(d_conv * d_inner_per_thread * sizeof(float), 256); + const uint32_t src0_T_max = HTP_SSM_CONV_VTCM_BUDGET > src1_T_size ? HTP_SSM_CONV_VTCM_BUDGET - src1_T_size : 0; - // Compute gather scratchpad size for src0 and src1 - const size_t gather_spad_size = n_threads * VLEN * 2; + uint32_t d_inner_tile = (src0_T_max / sizeof(float)) / ncs; + d_inner_tile -= (d_inner_tile % VLEN_FP32); + if (d_inner_tile == 0) { + FARF(HIGH, "ssm_conv-f32: inner tile rounds to 0 (ncs=%u), falling back to scalar\n", ncs); + use_hvx = 0; + } else { + scctx.d_inner_tile = d_inner_tile; - octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; + octx->src0_spad.size_per_thread = hex_round_up(d_inner_tile * ncs * sizeof(float), 256); + octx->src1_spad.size_per_thread = src1_T_size; + octx->dst_spad.size_per_thread = 0; - FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", - gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, - octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, - octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data); + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = 0; - const size_t total_spad_size = - gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; - if (total_spad_size > octx->ctx->vtcm_size) { - FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size, - octx->ctx->vtcm_size); + const size_t total_spad = octx->src0_spad.size + octx->src1_spad.size; + if (total_spad > octx->ctx->vtcm_size) { + FARF(HIGH, "ssm_conv-f32: scratchpad %zu exceeds VTCM %zu, falling back to scalar\n", + total_spad, octx->ctx->vtcm_size); use_hvx = 0; } } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index d4ae89ee6f0..40d2d60153a 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -17,7 +17,6 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "htp-ops.h" struct htp_unary_context { struct htp_ops_context * octx; @@ -159,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void hvx_fast_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares and sum of values for full vectors + HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Reduce HVX sums + sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v)); + sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v); + HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v); + HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v)); + HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v); + HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v); + + // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); + HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v3); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v3); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + static void scale_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -197,6 +269,24 @@ static void rms_norm_f32(const float * restrict src, } } +static void norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); + } +} + static void sqr_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -277,6 +367,95 @@ static void sigmoid_f32(const float * restrict src, } } +static void tri_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + const uint32_t ir, + const struct htp_unary_context * uctx) { + + const int32_t ttype = op_params[0]; + const HVX_Vector zero = hvx_vec_splat_f32(0.0f); + const uint32_t nvec = row_elems / VLEN_FP32; + const uint32_t nloe = row_elems % VLEN_FP32; + + const uint32_t ne01 = uctx->octx->src[0]->ne[1]; + + for (uint32_t b = 0; b < num_rows; b++) { + const uint32_t abs_row = ir + b; + const uint32_t i01 = abs_row % ne01; + + const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size); + HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size); + + uint32_t boundary; + int keep_left; + switch (ttype) { + case 0: boundary = i01; keep_left = 0; break; // keep col >= row + case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row + case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row + case 3: boundary = i01; keep_left = 1; break; // keep col < row + default: boundary = 0; keep_left = 0; break; + } + if (boundary > row_elems) boundary = row_elems; + + // Full HVX vectors — each starts at a 128-byte aligned offset + for (uint32_t i = 0; i < nvec; i++) { + const uint32_t vec_start = i * VLEN_FP32; + const uint32_t vec_end = vec_start + VLEN_FP32; + if (keep_left) { + if (vec_end <= boundary) { + v_dst[i] = v_src[i]; + } else if (vec_start >= boundary) { + v_dst[i] = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero); + } + } else { + if (vec_end <= boundary) { + v_dst[i] = zero; + } else if (vec_start >= boundary) { + v_dst[i] = v_src[i]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]); + } + } + } + + // Tail elements (row_elems not a multiple of VLEN_FP32) + if (nloe > 0) { + const uint32_t vec_start = nvec * VLEN_FP32; + const uint32_t vec_end = vec_start + nloe; + HVX_Vector tail_val; + if (keep_left) { + if (vec_end <= boundary) { + tail_val = v_src[nvec]; + } else if (vec_start >= boundary) { + tail_val = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero); + } + } else { + if (vec_end <= boundary) { + tail_val = zero; + } else if (vec_start >= boundary) { + tail_val = v_src[nvec]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]); + } + } + hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val); + } + } +} + static void softplus_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -468,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * // Process block in VTCM switch (htp_op) { + case HTP_OP_NORM: + norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; case HTP_OP_RMS_NORM: rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -498,6 +680,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_L2_NORM: l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_TRI: + tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx); + break; default: break; } @@ -541,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const char * op_type = NULL; switch (octx->op) { + case HTP_OP_NORM: + op_type = "norm-f32"; + break; case HTP_OP_RMS_NORM: op_type = "rmsnorm-f32"; break; @@ -571,6 +759,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_L2_NORM: op_type = "l2norm-f32"; break; + case HTP_OP_TRI: + op_type = "tri-f32"; + break; + default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; @@ -640,6 +832,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { return err; } +int op_tri(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src[0]->type) { + case HTP_TYPE_F32: + err = execute_op_unary_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} + int op_unary(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f0147af84c1..ba006d9b31a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = op->src[5]->ne[1]; const int nsg = op->src[2]->ne[0]/32; @@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( GGML_ASSERT(ne20 % 32 == 0); snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); - snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -1894,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l char base[256]; char name[256]; - snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); + // note: this is slower + //const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0; + const bool is_c4 = false; + + snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1904,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index fab7891c008..780dfe81bb3 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -672,7 +672,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { ![[dev->mtl_device name] containsString:@"M6"] && ![[dev->mtl_device name] containsString:@"A19"] && ![[dev->mtl_device name] containsString:@"A20"]) { - GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); + GGML_LOG_INFO("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); dev->props.has_tensor = false; } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a114391c2e8..206af227a2c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -564,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(1024, ne0); + int nth = std::min(256, ne0); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + if (nth < 256) { + nrptg = std::min((256 + nth - 1) / nth, ne1); + if (nrptg * nth > 256) { + nrptg = 256 / nth; + } + } + + const int nw0 = (ne1 + nrptg - 1) / nrptg; + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1); return 1; } @@ -816,9 +827,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const int nth = MIN(args.ne00, nth_max); - const int nk0 = (args.ne00 + nth - 1)/nth; ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); @@ -1788,7 +1797,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { nk0 = ne10/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne11, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1799,7 +1808,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } @@ -1863,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nk0 = ne00/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne01, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1874,7 +1883,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } @@ -4039,14 +4048,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); - const int nth = std::min(1024, ne0); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const int nth = MIN(args.ne0, nth_max); + const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel! ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2d45de8cce2..e772664ba91 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2548,21 +2549,24 @@ kernel void kernel_gated_delta_net_impl( uint3 ntg[[threads_per_threadgroup]]) { #define S_v FC_gated_delta_net_ne20 #define G FC_gated_delta_net_ne30 +#define K FC_gated_delta_net_K const uint tx = tpitg.x; const uint ty = tpitg.y; - const uint i23 = tgpig.z; // B - const uint i21 = tgpig.y; // H - const uint i20 = tgpig.x*NSG + ty; + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v const uint i01 = i21 % args.ne01; const uint i11 = i21 % args.ne11; const float scale = 1.0f / sqrt((float)S_v); + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2580,6 +2584,17 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = (int)args.ne22 - (int)K; + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; @@ -2627,17 +2642,30 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + if (K > 1) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } + } - FOR_UNROLL (short j = 0; j < NSG; j++) { - const short is = tx*NSG + j; - dst_state[is] = ls[j]; + if (K == 1) { + device float * dst_state = (device float *) (dst) + attn_size + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } } #undef S_v #undef G +#undef K } typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; @@ -5076,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32( for (int64_t sx = x_min; sx < x_max; ++sx) { const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0); const float w = wx * wy; - const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); + device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); sum += (*src_ptr) * w; wsum += w; } @@ -5258,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32( const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx)); const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3; - const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); + device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); sum += (*src_ptr) * wx * wy; } } @@ -5301,42 +5329,46 @@ kernel void kernel_roll_f32( } } -kernel void kernel_pad_f32( +template +kernel void kernel_pad_impl( constant ggml_metal_kargs_pad & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t k0 = tgpig.x/args.ne1; + const int32_t i1 = tgpig.x - k0*args.ne1; - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; + const int32_t i03 = i3; + const int32_t i02 = i2; + const int32_t i01 = i1; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); - device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - if (i0 < args.ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } + for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) { + const int32_t i0 = k0*1024 + tpitg.x + l0; + if (i0 >= args.ne0) { + break; } - return; - } - - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } } } +typedef decltype(kernel_pad_impl) kernel_pad_t; + +template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl; +template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl; + +// TODO: this is slow - optimize kernel void kernel_pad_reflect_1d_f32( constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, @@ -7300,23 +7332,27 @@ kernel void kernel_cpy_t_t( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; break; @@ -7348,23 +7384,27 @@ kernel void kernel_cpy_f32_q( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; + const int32_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); quantize_func(src, dst_data[i00]); @@ -7389,24 +7429,28 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; @@ -7442,7 +7486,11 @@ kernel void kernel_concat( const int i3 = tgpig.z; const int i2 = tgpig.y; - const int i1 = tgpig.x; + const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y; + + if (i1 >= args.ne1) { + return; + } int o[4] = {0, 0, 0, 0}; o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index c6aba608736..f75d089b574 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -110,6 +110,12 @@ set(GGML_OPENCL_KERNELS gemv_moe_q5_0_f32_ns gemm_moe_q5_1_f32_ns gemv_moe_q5_1_f32_ns + gemm_moe_q4_k_f32_ns + gemv_moe_q4_k_f32_ns + gemm_moe_q5_k_f32_ns + gemv_moe_q5_k_f32_ns + gemm_moe_q6_k_f32_ns + gemv_moe_q6_k_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0e511592d53..42286435bc6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -375,6 +375,11 @@ struct ggml_backend_opencl_device_context { ggml_backend_buffer_type buffer_type; cl_context context = nullptr; + + GPU_FAMILY gpu_family = GPU_FAMILY::UNKNOWN; + ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; + + size_t global_mem_size = 0; }; // backend context @@ -384,6 +389,18 @@ struct ggml_backend_opencl_context { cl_device_id device; std::string device_name; + ggml_cl_version platform_version; + ggml_cl_version opencl_c_version; + + // argsort is loaded in supports_op because its availability depends on how + // many workgroups are allowed, which requires kernel compilation. + bool kernels_loaded_argsort = false; + // flash attn is loaded in supports_op because it contains multiple variants + // and takes time to compile, so we want to only compile it when needed. + bool kernels_loaded_flash_attn = false; + // rest of the kernels are currently always loaded in alloc_buffer. + bool kernels_loaded = false; + std::string driver_version; GPU_FAMILY gpu_family; @@ -558,6 +575,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; + cl_kernel kernel_convert_block_q4_k_trans4_ns, kernel_restore_block_q4_k_trans4_ns; + cl_kernel kernel_convert_block_q5_k_trans4_ns, kernel_restore_block_q5_k_trans4_ns; + cl_kernel kernel_convert_block_q6_k_trans4_ns, kernel_restore_block_q6_k_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -619,6 +639,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns; cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns; + cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns; + cl_kernel kernel_gemv_moe_q5_k_f32_ns, kernel_gemm_moe_q5_k_f32_ns; + cl_kernel kernel_gemv_moe_q6_k_f32_ns, kernel_gemm_moe_q6_k_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -638,11 +661,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm; std::vector profiling_info; + std::vector profiling_results; - void write_profiling_info() { - FILE * fperf = fopen("cl_profiling.csv", "w"); - if (!fperf) { - GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + void flush_profiling_batch() { + if (profiling_info.empty()) { return; } @@ -666,6 +688,7 @@ struct ggml_backend_opencl_context { CL_CHECK(clGetEventProfilingInfo( info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL)); CL_CHECK(clReleaseEvent(info.evt)); + info.evt = nullptr; char kernel_name[512]; CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME, @@ -683,10 +706,26 @@ struct ggml_backend_opencl_context { info.cmd_complete_duration_ns = cmd_complete - cmd_end; info.cmd_total_duration_ns = cmd_complete - cmd_queued; } + profiling_results.insert(profiling_results.end(), + std::make_move_iterator(profiling_info.begin()), + std::make_move_iterator(profiling_info.end())); + profiling_info.clear(); + } + + void write_profiling_info() { + if (profiling_results.empty()) { + return; + } // Dump a csv + FILE * fperf = fopen("cl_profiling.csv", "w"); + if (!fperf) { + GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + return; + } + fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", info.op_name.c_str(), info.kernel_name.c_str(), info.cmd_duration_ns/1.e6f, @@ -697,14 +736,14 @@ struct ggml_backend_opencl_context { fclose(fperf); // Dump a simple chrome trace - FILE* ftrace = fopen("cl_trace.json", "w"); + FILE * ftrace = fopen("cl_trace.json", "w"); if (!ftrace) { GGML_LOG_ERROR("Failed to open cl_trace.json\n"); return; } fprintf(ftrace, "[\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_queued/1000); fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", @@ -715,6 +754,7 @@ struct ggml_backend_opencl_context { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_end/1000); } + fprintf(ftrace, "]\n"); fclose(ftrace); } @@ -735,6 +775,9 @@ struct ggml_backend_opencl_context { profiling_info.emplace_back(); populateProfilingInfo(profiling_info.back(), evt, kernel, work_dim, global_work_size, local_work_size, tensor); + if (profiling_info.size() >= 2048) { + flush_profiling_batch(); + } #else GGML_UNUSED(tensor); CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, NULL)); @@ -775,11 +818,13 @@ struct ggml_backend_opencl_context { #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { + clFinish(queue); + ref_count--; if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING write_profiling_info(); - profiling_info.clear(); + profiling_results.clear(); #endif } } @@ -787,6 +832,9 @@ struct ggml_backend_opencl_context { // All registered devices with a default device in the front. static std::vector g_ggml_backend_opencl_devices; +// All device contexts associated with the devices above. +// The devices live as long as the process, so do the contexts. +static std::vector> g_ggml_backend_opencl_dev_ctxs; inline std::string read_file(const std::string &path) { std::ifstream ifs(path); @@ -830,12 +878,120 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } -static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { +static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // argsort + if (!backend_ctx->kernels_loaded_argsort) { + cl_int err; +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "argsort.cl.h" + }; +#else + const std::string kernel_src = read_file("argsort.cl"); +#endif + backend_ctx->program_argsort_f32_i32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + backend_ctx->kernels_loaded_argsort = true; + } +} + +static void load_cl_kernels_flash_attn(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // flash_attn + if (!backend_ctx->kernels_loaded_flash_attn) { + cl_int err; + + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_f16 { + #include "flash_attn_f16.cl.h" + }; + const std::string kernel_src_f32 { + #include "flash_attn_f32.cl.h" + }; + const std::string kernel_src_f32_f16 { + #include "flash_attn_f32_f16.cl.h" + }; + #else + const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); + const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); + const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); + #endif + + if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { + const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { + { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, + {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, + {192, 192, 16, 16}, {256, 256, 16, 16}, + }; + + for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { + const int dk = fa_dims[i].dk; + const int dv = fa_dims[i].dv; + const int bm = fa_dims[i].bm; + const int bn = fa_dims[i].bn; + std::string OPTS = compile_opts + + " -D DK=" + std::to_string(dk) + + " -D DV=" + std::to_string(dv) + + " -D BLOCK_M=" + std::to_string(bm) + + " -D BLOCK_N=" + std::to_string(bn); + + cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); + cl_kernel k_f16, k_f16_q1; + CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); + CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; + backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; + CL_CHECK(clReleaseProgram(prog_f16)); + + cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); + cl_kernel k_f32, k_f32_q1; + CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); + CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; + backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; + CL_CHECK(clReleaseProgram(prog_f32)); + + cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); + cl_kernel k_f32_f16, k_f32_f16_q1; + CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); + CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; + backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; + CL_CHECK(clReleaseProgram(prog_f32_f16)); + + backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; + backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; + } + backend_ctx->kernels_loaded_flash_attn = true; + } + } +} + +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { + if (backend_ctx->kernels_loaded) { + return; + } + cl_int err; // compiler options for general kernels auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); std::string compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable -cl-unsafe-math-optimizations" " -cl-finite-math-only -cl-fast-relaxed-math"; @@ -981,6 +1137,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_k_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -1974,89 +2136,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // flash_attn - { - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_f16 { - #include "flash_attn_f16.cl.h" - }; - const std::string kernel_src_f32 { - #include "flash_attn_f32.cl.h" - }; - const std::string kernel_src_f32_f16 { - #include "flash_attn_f32_f16.cl.h" - }; - #else - const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); - const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); - const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); - #endif - - if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { - const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { - { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, - {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, - {192, 192, 16, 16}, {256, 256, 16, 16}, - }; - - for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { - const int dk = fa_dims[i].dk; - const int dv = fa_dims[i].dv; - const int bm = fa_dims[i].bm; - const int bn = fa_dims[i].bn; - std::string OPTS = compile_opts + - " -D DK=" + std::to_string(dk) + - " -D DV=" + std::to_string(dv) + - " -D BLOCK_M=" + std::to_string(bm) + - " -D BLOCK_N=" + std::to_string(bn); - - cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); - cl_kernel k_f16, k_f16_q1; - CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); - CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; - backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; - CL_CHECK(clReleaseProgram(prog_f16)); - - cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); - cl_kernel k_f32, k_f32_q1; - CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); - CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; - backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; - CL_CHECK(clReleaseProgram(prog_f32)); - - cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); - cl_kernel k_f32_f16, k_f32_f16_q1; - CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); - CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; - backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; - CL_CHECK(clReleaseProgram(prog_f32_f16)); - - backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; - backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; - } - GGML_LOG_CONT("."); - } - } - - // argsort - { -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "argsort.cl.h" - }; -#else - const std::string kernel_src = read_file("argsort.cl"); -#endif - backend_ctx->program_argsort_f32_i32 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); - GGML_LOG_CONT("."); - } - // div { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3071,133 +3150,235 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // gemv_moe_mxfp4_f32_ns + // gemv_moe_q4_k_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "gemv_moe_mxfp4_f32_ns.cl.h" + #include "gemv_moe_q4_k_f32_ns.cl.h" }; #else - const std::string kernel_src = read_file("gemv_moe_mxfp4_f32_ns.cl"); + const std::string kernel_src = read_file("gemv_moe_q4_k_f32_ns.cl"); #endif cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_k_f32_ns", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // gemm_moe_mxfp4_f32_ns + // gemm_moe_q4_k_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "gemm_moe_mxfp4_f32_ns.cl.h" + #include "gemm_moe_q4_k_f32_ns.cl.h" }; #else - const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl"); + const std::string kernel_src = read_file("gemm_moe_q4_k_f32_ns.cl"); #endif cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_k_f32_ns", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // moe_reorder_b + // gemv_moe_q5_k_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "moe_reorder_b.cl.h" + #include "gemv_moe_q5_k_f32_ns.cl.h" }; #else - const std::string kernel_src = read_file("moe_reorder_b.cl"); + const std::string kernel_src = read_file("gemv_moe_q5_k_f32_ns.cl"); #endif cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_k_f32_ns", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // moe_sort_by_expert + // gemm_moe_q5_k_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "moe_sort_by_expert.cl.h" + #include "gemm_moe_q5_k_f32_ns.cl.h" }; #else - const std::string kernel_src = read_file("moe_sort_by_expert.cl"); + const std::string kernel_src = read_file("gemm_moe_q5_k_f32_ns.cl"); #endif cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err)); - CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err)); - CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err)); - CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err)); + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_k_f32_ns", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // gemv_noshuffle_q6_k_f32 + // gemv_moe_q6_k_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "gemv_noshuffle_q6_k_f32.cl.h" + #include "gemv_moe_q6_k_f32_ns.cl.h" }; #else - const std::string kernel_src = read_file("gemv_noshuffle_q6_k_f32.cl"); + const std::string kernel_src = read_file("gemv_moe_q6_k_f32_ns.cl"); #endif - - std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable "; - if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - cl_program prog = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q6_K_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // gemm_noshuffle_q6_k_f32 + // gemm_moe_q6_k_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "gemm_noshuffle_q6_k_f32.cl.h" + #include "gemm_moe_q6_k_f32_ns.cl.h" }; #else - const std::string kernel_src = read_file("gemm_noshuffle_q6_k_f32.cl"); + const std::string kernel_src = read_file("gemm_moe_q6_k_f32_ns.cl"); #endif cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_gemm_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // gemv_noshuffle_q5_k_f32 + // gemv_moe_mxfp4_f32_ns { - std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable "; - if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; - } - #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "gemv_noshuffle_q5_k_f32.cl.h" + #include "gemv_moe_mxfp4_f32_ns.cl.h" }; #else - const std::string kernel_src = read_file("gemv_noshuffle_q5_k_f32.cl"); + const std::string kernel_src = read_file("gemv_moe_mxfp4_f32_ns.cl"); #endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - cl_program prog = build_program_from_source( + CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_reorder_b + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_reorder_b.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_reorder_b.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_sort_by_expert + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_sort_by_expert.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_sort_by_expert.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q6_k_f32.cl"); +#endif + + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q6_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_k_f32.cl"); +#endif + + cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_k_f32", &err), err)); @@ -3221,13 +3402,15 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); + backend_ctx->kernels_loaded = true; } // XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // XXX static bool initialized = false; // XXX static ggml_backend_opencl_context *backend_ctx = nullptr; -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev); +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev); +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev); namespace /* anonymous */ { extern struct ggml_backend_device_i ggml_backend_opencl_device_i; @@ -3440,13 +3623,13 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r /* .context = */ dev_ctx.get(), }); - if (!ggml_cl2_init(&found_devices.back())) { + if (!ggml_opencl_is_device_supported(&found_devices.back())) { found_devices.pop_back(); - GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n"); + GGML_LOG_WARN("ggml_opencl: drop unsupported device '%s'.\n", dev->name); continue; } - dev_ctx.release(); + g_ggml_backend_opencl_dev_ctxs.push_back(std::move(dev_ctx)); } if (found_devices.size()) { @@ -3463,8 +3646,79 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r return found_devices; } +// check if device should be accepted +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (strstr(dev_ctx->device_name.c_str(), "Adreno") || + strstr(dev_ctx->device_name.c_str(), "Qualcomm") || + strstr(dev_ctx->device_version.c_str(), "Adreno")) { + dev_ctx->gpu_family = GPU_FAMILY::ADRENO; + + // Usually device version contains the detailed device name + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); + if (dev_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); + } + } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { + dev_ctx->gpu_family = GPU_FAMILY::INTEL; + } else { + GGML_LOG_WARN("ggml_opencl: unsupported GPU '%s'.\n", dev_ctx->device_name.c_str()); + dev_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + return false; + } + + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); + + // Check device OpenCL version, OpenCL 2.0 or above is required + ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, dev_ctx->device); + if (opencl_c_version.major < 2) { + GGML_LOG_WARN("ggml_opencl: OpenCL 2.0 or above is required\n"); + return false; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (dev_ctx->gpu_family != GPU_FAMILY::ADRENO) { + GGML_LOG_WARN("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " + "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); + return false; + } +#endif + + size_t ext_str_size; + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; + + // Check if ext_buffer contains cl_khr_fp16 + bool fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + if (!fp16_support) { + GGML_LOG_WARN("ggml_opencl: device does not support FP16\n"); + return false; + } + + // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes + // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) + if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && + strstr(ext_buffer, "cl_intel_subgroups") == NULL) { + GGML_LOG_WARN("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " + "(note that subgroups is an optional feature in OpenCL 3.0)\n"); + return false; + } + + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &dev_ctx->global_mem_size, NULL); + return true; +} + // Initialize device if it is supported (returns nullptr if it is not). -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { GGML_ASSERT(dev); GGML_ASSERT(dev->context); @@ -3486,33 +3740,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // when the associated device is initialized backend_ctx->ref_count = 0; - if (strstr(dev_ctx->device_name.c_str(), "Adreno") || - strstr(dev_ctx->device_name.c_str(), "Qualcomm") || - strstr(dev_ctx->device_version.c_str(), "Adreno")) { - backend_ctx->gpu_family = GPU_FAMILY::ADRENO; - // Usually device version contains the detailed device name - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); - if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); - } - + backend_ctx->gpu_family = dev_ctx->gpu_family; + backend_ctx->adreno_gen = dev_ctx->adreno_gen; + if (backend_ctx->gpu_family == GPU_FAMILY::ADRENO) { // Use wave size of 64 for all Adreno GPUs. backend_ctx->adreno_wave_size = 64; - } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { - backend_ctx->gpu_family = GPU_FAMILY::INTEL; - } else { - GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str()); - backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - return nullptr; - } - -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { - GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " - "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); - return nullptr; } -#endif // Populate backend device name backend_ctx->device_name = dev_ctx->device_name; @@ -3521,13 +3754,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { cl_device_id device = backend_ctx->device; ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); - - // Check device OpenCL version, OpenCL 2.0 or above is required ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); - if (opencl_c_version.major < 2) { - GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); - return nullptr; - } + + backend_ctx->platform_version = platform_version; + backend_ctx->opencl_c_version = opencl_c_version; // Check driver version size_t driver_version_str_size; @@ -3550,34 +3780,21 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { char *ext_buffer = (char *)alloca(ext_str_size + 1); clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + // check Adreno large buffer support backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; - // fp16 is required - if (!backend_ctx->fp16_support) { - GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); - return nullptr; - } - - // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes - // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) - if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && - strstr(ext_buffer, "cl_intel_subgroups") == NULL) { - GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " - "(note that subgroups is an optional feature in OpenCL 3.0)\n"); - return nullptr; - } - cl_uint base_align_in_bits; CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); GGML_ASSERT(base_align_in_bits % 8u == 0); backend_ctx->alignment = base_align_in_bits / 8u; GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); - clGetDeviceInfo(device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &backend_ctx->global_mem_size, NULL); + backend_ctx->global_mem_size = dev_ctx->global_mem_size; GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); @@ -3665,8 +3882,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { #endif CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); - // Load kernels - load_cl_kernels(backend_ctx.get(), opencl_c_version); + // delay kernel loading until the first buffer is created + // load_cl_kernels(backend_ctx.get()); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Allocate intermediate buffers and images @@ -3708,22 +3925,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { return dev_ctx->backend_ctx; } -static void ggml_cl2_free(ggml_backend_t backend) { +static void ggml_cl_free(ggml_backend_t backend) { ggml_backend_opencl_context * ctx = (ggml_backend_opencl_context *) backend->context; ctx->free(); - - // The CL context is shared by all backends, release it if all backends have been released - bool should_release_opencl = true; - for (auto device : g_ggml_backend_opencl_devices) { - ggml_backend_opencl_device_context * ctx_dev = (ggml_backend_opencl_device_context *) device.context; - if (ctx_dev->backend_ctx->ref_count > 0) { - should_release_opencl = false; - } - } - - if (should_release_opencl) { - CL_CHECK(clReleaseContext(ctx->context)); - } } #ifdef GGML_OPENCL_USE_ADRENO_KERNELS @@ -4148,6 +4352,8 @@ struct ggml_tensor_extra_cl_iq4_nl { struct ggml_tensor_extra_cl_q4_K { // Quantized values cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; // Scales for each super block. cl_mem s = nullptr; // Scales @@ -4176,12 +4382,18 @@ struct ggml_tensor_extra_cl_q4_K { CL_CHECK(clReleaseMemObject(dm)); dm = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } } }; struct ggml_tensor_extra_cl_q5_K { // Lower 4 bits of quantized weights. cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; // Upper 1 bit of quantized weights. cl_mem qh = nullptr; // Scales for each block. @@ -4222,6 +4434,10 @@ struct ggml_tensor_extra_cl_q5_K { CL_CHECK(clReleaseMemObject(dm)); dm = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } size_q = 0; size_qh = 0; @@ -4234,6 +4450,8 @@ struct ggml_tensor_extra_cl_q5_K { struct ggml_tensor_extra_cl_q6_K { // Lower 4 bits of quantized weights. cl_mem ql = nullptr; + // Lower 4 bits as image1d_buffer_t + cl_mem ql_img = nullptr; // Upper 2 bits of quantized weights. cl_mem qh = nullptr; // Scales for each block. @@ -4267,6 +4485,10 @@ struct ggml_tensor_extra_cl_q6_K { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + if (ql_img != nullptr) { + CL_CHECK(clReleaseMemObject(ql_img)); + ql_img = nullptr; + } size_ql = 0; size_qh = 0; @@ -4289,7 +4511,7 @@ static const char * ggml_backend_opencl_name(ggml_backend_t backend) { } static void ggml_backend_opencl_free(ggml_backend_t backend) { - ggml_cl2_free(backend); + ggml_cl_free(backend); } static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -4328,14 +4550,17 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { // enqueued to it won't start until commands in the other devices have // completed. static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { - if (g_ggml_backend_opencl_devices.size() < 2) - return; // No other devices to synchronize with. + if (g_ggml_backend_opencl_devices.size() < 2) { + return; // No other devices to synchronize with. + } std::vector events; events.reserve(g_ggml_backend_opencl_devices.size()); for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { - auto * other_backend_ctx = ggml_cl2_init(&backend_dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) backend_dev.context; + auto * other_backend_ctx = dev_ctx->backend_ctx; + if (backend_ctx != other_backend_ctx) { cl_event ev; CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); @@ -4488,7 +4713,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { GGML_UNUSED(backend_ctx); int ne01 = tensor->ne[1]; - return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 32 == 0); } inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { @@ -4700,7 +4925,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te // the quantizations here currently do not - they are only supported by Adreno with certain shapes if (op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || - op->src[0]->type == GGML_TYPE_Q5_1) { + op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (op->src[1]->type == GGML_TYPE_F32) { return use_adreno_moe_kernels(backend_ctx, op->src[0]) @@ -4745,6 +4973,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_IM2COL: return true; case GGML_OP_ARGSORT: { + load_cl_kernels_argsort(backend_ctx); + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); @@ -4762,6 +4992,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_FLASH_ATTN_EXT: { + load_cl_kernels_flash_attn(backend_ctx); + const ggml_tensor * q = op->src[0]; const ggml_tensor * k = op->src[1]; const ggml_tensor * v = op->src[2]; @@ -4829,7 +5061,7 @@ static ggml_backend_i ggml_backend_opencl_i = { ggml_backend_t ggml_backend_opencl_init(void) { ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(dev); ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_opencl_guid(), @@ -5208,15 +5440,13 @@ static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) } static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer->buft->device); - return (void *) (uintptr_t) backend_ctx->alignment; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + return (void *) (uintptr_t) dev_ctx->backend_ctx->alignment; } static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_cl2_init(buffer->buft->device); - if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); @@ -5256,7 +5486,8 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff } static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; @@ -6047,32 +6278,75 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - if (use_adreno_kernels(backend_ctx, tensor)) { - kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; - } - #else - cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - #endif +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_k_trans4_ns; - cl_uchar mask_0F = 0x0F; - cl_uchar mask_F0 = 0xF0; + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; + } +#else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; @@ -6157,14 +6431,58 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q5_K_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; - #endif +#endif cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -6232,6 +6550,79 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_buffer_region region; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno MoE Q6_K kernel needs special transposed layout + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + size_t moe_size_ql = (size_t)(ggml_nelements(tensor) / 8) * sizeof(uint32_t); // 4 bits per element + size_t moe_size_qh = (size_t)(ggml_nelements(tensor) / 16) * sizeof(uint32_t); // 2 bits per element + size_t moe_size_s = size_s; + size_t moe_size_d = size_d; + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = moe_size_ql; + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + moe_size_ql, backend_ctx->alignment); + region.size = moe_size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + moe_size_qh, backend_ctx->alignment); + region.size = moe_size_s; + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for d + region.origin = align_to(previous_origin + moe_size_s, backend_ctx->alignment); + region.size = moe_size_d; + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q6_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for ql + cl_image_format img_format_ql = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_ql = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->ql } + }; + extra->ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_ql, &img_desc_ql, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // Subbuffer for ql region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_ql; @@ -6331,7 +6722,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor->extra); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context *backend_ctx = dev_ctx->backend_ctx; cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; @@ -6825,6 +7217,40 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_uchar mask_F0 = 0xF0; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6901,6 +7327,40 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_uchar mask_F0 = 0xF0; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6974,7 +7434,44 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { static ggml_cl_buffer buf_trans_ql; static ggml_cl_buffer buf_trans_qh; @@ -7070,8 +7567,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, } static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_dev_t dev = buffer->buft->device; - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + cl_command_queue queue = backend_ctx->queue; ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; @@ -7111,7 +7609,8 @@ static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer } static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(buffer_type->device); + load_cl_kernels(backend_ctx); // clCreateBuffer returns -61 for size 0 size = std::max(size, (size_t)1); @@ -7134,15 +7633,15 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b } static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - return backend_ctx->alignment; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + return dev_ctx->backend_ctx->alignment; } static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { static size_t max_size = -1; if (max_size == (size_t)-1) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - max_size = backend_ctx->max_alloc_size; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + max_size = dev_ctx->backend_ctx->max_alloc_size; } return max_size; } @@ -7179,14 +7678,13 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; - ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) dev_ctx->backend_ctx; static const size_t opencl_extra_margin = 1024ull*1024ull*1024ull; // OpenCL does not provide reliable currently-free device memory. // Use total/global memory as a best-effort upper bound. // Improved safety: Reduce by a 1GiB extra margin for common --fit - *total = backend_ctx->global_mem_size; + *total = dev_ctx->global_mem_size; *free = *total > opencl_extra_margin ? *total - opencl_extra_margin : 0; } @@ -7210,7 +7708,7 @@ static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct } static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_context * backend_ctx = ggml_cl_init(dev); // Getting a new reference to the backend, increase ref_count backend_ctx->ref_count++; @@ -7247,6 +7745,7 @@ static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_bac } static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + ggml_cl_init(dev); return ggml_opencl_supports_op(dev, op); } @@ -7259,8 +7758,8 @@ static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggm // Check cl_context is the same. clEnqueue* commands may not use // buffers from another cl_context. - ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev); - ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device); + ggml_backend_opencl_context * backend_ctx0 = ggml_cl_init(dev); + ggml_backend_opencl_context * backend_ctx1 = ggml_cl_init(buft->device); return backend_ctx0->context == backend_ctx1->context; } @@ -13733,6 +14232,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif @@ -13741,6 +14243,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, (void)extra0_q4_1; (void)extra0_q5_0; (void)extra0_q5_1; + (void)extra0_q4_K; + (void)extra0_q5_K; + (void)extra0_q6_K; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -13812,7 +14317,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14028,7 +14533,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14204,7 +14709,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14380,7 +14885,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14612,6 +15117,532 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #endif // GGML_OPENCL_SOA_Q break; } + case GGML_TYPE_Q4_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q6_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q6_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q6_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, src0)) { @@ -14633,7 +15664,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 8f06d570587..c25eabdd72b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -220,6 +220,10 @@ kernel void kernel_convert_block_q4_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_0; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -263,6 +267,10 @@ kernel void kernel_restore_block_q4_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_0; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -401,6 +409,10 @@ kernel void kernel_convert_block_q4_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -446,6 +458,10 @@ kernel void kernel_restore_block_q4_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -491,6 +507,10 @@ kernel void kernel_convert_block_q5_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_0; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -536,6 +556,10 @@ kernel void kernel_restore_block_q5_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_0; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -583,6 +607,10 @@ kernel void kernel_convert_block_q5_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -630,6 +658,10 @@ kernel void kernel_restore_block_q5_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -664,6 +696,415 @@ kernel void kernel_restore_block_q5_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +kernel void kernel_convert_block_q4_k_trans4_ns( + __global struct block_q4_K * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_q[base + (p * 4 + 0) * ne01] = v.x; + dst_q[base + (p * 4 + 1) * ne01] = v.y; + dst_q[base + (p * 4 + 2) * ne01] = v.z; + dst_q[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_k_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q4_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_q[base + (p * 4 + 0) * ne01]; + qv[p].y = src_q[base + (p * 4 + 1) * ne01]; + qv[p].z = src_q[base + (p * 4 + 2) * ne01]; + qv[p].w = src_q[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q5_k_trans4_ns( + __global struct block_q5_K * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + for (int k = 0; k < 8; k++) { + uchar b0 = 0, b1 = 0, b2 = 0, b3 = 0; + for (int bit = 0; bit < 8; bit++) { + b0 |= (uchar)(((b->qh[bit] >> k) & 1) << bit); + b1 |= (uchar)(((b->qh[8 + bit] >> k) & 1) << bit); + b2 |= (uchar)(((b->qh[16 + bit] >> k) & 1) << bit); + b3 |= (uchar)(((b->qh[24 + bit] >> k) & 1) << bit); + } + uint packed = (uint)b0 | ((uint)b1 << 8) | ((uint)b2 << 16) | ((uint)b3 << 24); + dst_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01] = packed; + } + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_qs[base + (p * 4 + 0) * ne01] = v.x; + dst_qs[base + (p * 4 + 1) * ne01] = v.y; + dst_qs[base + (p * 4 + 2) * ne01] = v.z; + dst_qs[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_k_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q5_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + for (int j = 0; j < 32; j++) b->qh[j] = 0; + for (int k = 0; k < 8; k++) { + uint packed = src_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01]; + uchar b0 = (uchar)(packed & 0xFF); + uchar b1 = (uchar)((packed >> 8) & 0xFF); + uchar b2 = (uchar)((packed >> 16) & 0xFF); + uchar b3 = (uchar)((packed >> 24) & 0xFF); + for (int bit = 0; bit < 8; bit++) { + b->qh[bit] |= (uchar)(((b0 >> bit) & 1) << k); + b->qh[8 + bit] |= (uchar)(((b1 >> bit) & 1) << k); + b->qh[16 + bit] |= (uchar)(((b2 >> bit) & 1) << k); + b->qh[24 + bit] |= (uchar)(((b3 >> bit) & 1) << k); + } + } + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_qs[base + (p * 4 + 0) * ne01]; + qv[p].y = src_qs[base + (p * 4 + 1) * ne01]; + qv[p].z = src_qs[base + (p * 4 + 2) * ne01]; + qv[p].w = src_qs[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q6_k_trans4_ns( + __global struct block_q6_K * src0, + __global uint * dst_ql, + __global uint * dst_qh, + __global half * dst_d, + __global char * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = src0 + src_blk_offset; + + dst_d[dst_blk_offset] = b->d; + + uint4 qlv[8]; + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->ql[i*64 + 2*j]; + uchar x1 = b->ql[i*64 + 2*j + 1]; + uchar x2 = b->ql[i*64 + 32 + 2*j]; + uchar x3 = b->ql[i*64 + 32 + 2*j + 1]; + qlv_bytes[i*64 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qlv_bytes[i*64 + j + 16] = convert_uchar(x2 & mask_0F) | convert_uchar((x3 & mask_0F) << 4); + qlv_bytes[i*64 + j + 32] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + qlv_bytes[i*64 + j + 48] = convert_uchar((x2 & mask_F0) >> 4) | convert_uchar(x3 & mask_F0); + } + } + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qlv[p]; + dst_ql[ql_base + (p * 4 + 0) * ne01] = v.x; + dst_ql[ql_base + (p * 4 + 1) * ne01] = v.y; + dst_ql[ql_base + (p * 4 + 2) * ne01] = v.z; + dst_ql[ql_base + (p * 4 + 3) * ne01] = v.w; + } + + uint qhv[16] = {0}; + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + uchar h = b->qh[n*32 + l]; + int u = l / 16; + int bit_pos = (l % 16) * 2; + qhv[(n*4 + 0)*2 + u] |= ((uint)((h >> 0) & 0x03)) << bit_pos; + qhv[(n*4 + 1)*2 + u] |= ((uint)((h >> 2) & 0x03)) << bit_pos; + qhv[(n*4 + 2)*2 + u] |= ((uint)((h >> 4) & 0x03)) << bit_pos; + qhv[(n*4 + 3)*2 + u] |= ((uint)((h >> 6) & 0x03)) << bit_pos; + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + + for (int p = 0; p < 16; ++p) { + dst_qh[qh_base + p * ne01] = qhv[p]; + } + + __global char * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + #pragma unroll + for (int i = 0; i < 16; ++i) { + s_dst[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_k_trans4_ns( + __global uint * src_ql, + __global uint * src_qh, + __global half * src_d, + __global char * src_s, + __global struct block_q6_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + uint4 qlv[8]; + for (int p = 0; p < 8; ++p) { + qlv[p].x = src_ql[ql_base + (p * 4 + 0) * ne01]; + qlv[p].y = src_ql[ql_base + (p * 4 + 1) * ne01]; + qlv[p].z = src_ql[ql_base + (p * 4 + 2) * ne01]; + qlv[p].w = src_ql[ql_base + (p * 4 + 3) * ne01]; + } + + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo_02 = qlv_bytes[i*64 + j]; + uchar lo_13 = qlv_bytes[i*64 + j + 16]; + uchar hi_02 = qlv_bytes[i*64 + j + 32]; + uchar hi_13 = qlv_bytes[i*64 + j + 48]; + b->ql[i*64 + 2*j] = convert_uchar((lo_02 & mask_0F) | ((hi_02 & mask_0F) << 4)); + b->ql[i*64 + 2*j + 1] = convert_uchar(((lo_02 & mask_F0) >> 4) | (hi_02 & mask_F0)); + b->ql[i*64 + 32 + 2*j] = convert_uchar((lo_13 & mask_0F) | ((hi_13 & mask_0F) << 4)); + b->ql[i*64 + 32 + 2*j + 1] = convert_uchar(((lo_13 & mask_F0) >> 4) | (hi_13 & mask_F0)); + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + uint qhv[16]; + for (int p = 0; p < 16; ++p) { + qhv[p] = src_qh[qh_base + p * ne01]; + } + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + int u = l / 16; + int bit_pos = (l % 16) * 2; + uchar v0 = (uchar)((qhv[(n*4 + 0)*2 + u] >> bit_pos) & 0x03); + uchar v1 = (uchar)((qhv[(n*4 + 1)*2 + u] >> bit_pos) & 0x03); + uchar v2 = (uchar)((qhv[(n*4 + 2)*2 + u] >> bit_pos) & 0x03); + uchar v3 = (uchar)((qhv[(n*4 + 3)*2 + u] >> bit_pos) & 0x03); + b->qh[n*32 + l] = v0 | (v1 << 2) | (v2 << 4) | (v3 << 6); + } + } + + __global char * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + for (int i = 0; i < 16; ++i) { + b->scales[i] = s_src[i]; + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ @@ -762,6 +1203,10 @@ kernel void kernel_convert_block_mxfp4_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_MXFP4; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -805,6 +1250,10 @@ kernel void kernel_restore_block_mxfp4_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_MXFP4; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl index e404f392bdd..02cdbdd9fb1 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl @@ -163,7 +163,7 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -248,6 +248,10 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl index 02290c17eb1..d403ed0cab1 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl @@ -115,7 +115,7 @@ kernel void kernel_gemm_moe_q4_0_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -198,6 +198,10 @@ kernel void kernel_gemm_moe_q4_0_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl index e2574ae0187..b2bddf3f73a 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q4_1_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -200,6 +200,10 @@ kernel void kernel_gemm_moe_q4_1_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..ab8228d18ca --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl @@ -0,0 +1,283 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q4_k(q4, a_f16, scale, minv) \ + a_f16.s0 = (half)((float)(q4.s0 & 0x000F) * scale - minv); \ + a_f16.s1 = (half)((float)((q4.s0 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s2 = (half)((float)((q4.s0 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s3 = (half)((float)((q4.s0 & 0xF000) >> 12) * scale - minv); \ + a_f16.s4 = (half)((float)(q4.s1 & 0x000F) * scale - minv); \ + a_f16.s5 = (half)((float)((q4.s1 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s6 = (half)((float)((q4.s1 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s7 = (half)((float)((q4.s1 & 0xF000) >> 12) * scale - minv); \ + a_f16.s8 = (half)((float)(q4.s2 & 0x000F) * scale - minv); \ + a_f16.s9 = (half)((float)((q4.s2 & 0x00F0) >> 4) * scale - minv); \ + a_f16.sa = (half)((float)((q4.s2 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sb = (half)((float)((q4.s2 & 0xF000) >> 12) * scale - minv); \ + a_f16.sc = (half)((float)(q4.s3 & 0x000F) * scale - minv); \ + a_f16.sd = (half)((float)((q4.s3 & 0x00F0) >> 4) * scale - minv); \ + a_f16.se = (half)((float)((q4.s3 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sf = (half)((float)((q4.s3 & 0xF000) >> 12) * scale - minv); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q4_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half (next 16 elements, same sub-block scale) + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl index 3524cb1bdbd..d1a35d58bb2 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl @@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q5_0_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -202,6 +202,10 @@ kernel void kernel_gemm_moe_q5_0_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl index 5fc2a523234..90d345ecf51 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl @@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q5_1_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -204,6 +204,10 @@ kernel void kernel_gemm_moe_q5_1_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..13c26f6f3b6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl @@ -0,0 +1,288 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q5_k(qs5x16, qh5x16, a_f16, scale, m) \ + a_f16.s0 = (half)((float)(( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \ + a_f16.s1 = (half)((float)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.s2 = (half)((float)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.s3 = (half)((float)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.s4 = (half)((float)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.s5 = (half)((float)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.s6 = (half)((float)(((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.s7 = (half)((float)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m)); \ + a_f16.s8 = (half)((float)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m)); \ + a_f16.s9 = (half)((float)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.sa = (half)((float)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.sb = (half)((float)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.sc = (half)((float)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.sd = (half)((float)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.se = (half)((float)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m)); \ + a_f16.sf = (half)((float)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m)); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q5_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uint * src0_qh, + __global uchar * src0_s, + __global half * src0_d, + __global half * src0_dm, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // qh is stored at sub-block granularity + uint qh_offset = row + sub * ne01 + expert_id * num_superblocks * 8 * ne01 + get_global_id(0); + uchar4 qhx32 = as_uchar4(src0_qh[qh_offset]); + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_k(as_ushort4(q4x16), qhx32.lo, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q5_k(as_ushort4(q4x16), qhx32.hi, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..85ccebec78c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl @@ -0,0 +1,267 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 + +#define dequantize_q6_k(qs16, qh16, a_f16, scale) \ + a_f16.s0 = (half)(((float)(( qs16.s0 & 0x000F) | ((uint)(( qh16 ) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s1 = (half)(((float)((( qs16.s0 >> 4) & 0x000F) | ((uint)(( qh16 >> 2) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s2 = (half)(((float)((( qs16.s0 >> 8) & 0x000F) | ((uint)(( qh16 >> 4) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s3 = (half)(((float)((( qs16.s0 >>12) & 0x000F) | ((uint)(( qh16 >> 6) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s4 = (half)(((float)(( qs16.s1 & 0x000F) | ((uint)(( qh16 >> 8) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s5 = (half)(((float)((( qs16.s1 >> 4) & 0x000F) | ((uint)(( qh16 >> 10) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s6 = (half)(((float)((( qs16.s1 >> 8) & 0x000F) | ((uint)(( qh16 >> 12) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s7 = (half)(((float)((( qs16.s1 >>12) & 0x000F) | ((uint)(( qh16 >> 14) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s8 = (half)(((float)(( qs16.s2 & 0x000F) | ((uint)(( qh16 >> 16) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s9 = (half)(((float)((( qs16.s2 >> 4) & 0x000F) | ((uint)(( qh16 >> 18) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sa = (half)(((float)((( qs16.s2 >> 8) & 0x000F) | ((uint)(( qh16 >> 20) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sb = (half)(((float)((( qs16.s2 >>12) & 0x000F) | ((uint)(( qh16 >> 22) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sc = (half)(((float)(( qs16.s3 & 0x000F) | ((uint)(( qh16 >> 24) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sd = (half)(((float)((( qs16.s3 >> 4) & 0x000F) | ((uint)(( qh16 >> 26) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.se = (half)(((float)((( qs16.s3 >> 8) & 0x000F) | ((uint)(( qh16 >> 28) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sf = (half)(((float)((( qs16.s3 >>12) & 0x000F) | ((uint)(( qh16 >> 30) & 0x3) << 4)) - 32.f) * scale); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q6_k_f32_ns( + __read_only image1d_buffer_t src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * 16; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; // 32-element group index + uint sb = sub / 8; // super-block index + uint j = sub % 8; // group within super-block + + // Load d for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + + // Load sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + uint qh_base = row + (sub * 2) * ne01 + expert_id * (num_superblocks * 16) * ne01 + get_global_id(0); + uint qh_first16 = src0_qh[qh_base]; + uint qh_second16 = src0_qh[qh_base + ne01]; + + // First half (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 ql nibbles (2 uints) from image + uint2 q4x16; + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantize first 16 elements (scale0) + dequantize_q6_k(as_ushort4(q4x16), qh_first16, reg_a, scale0); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q6_k(as_ushort4(q4x16), qh_second16, reg_a, scale1); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl index e4b44c1a56a..75129e20c65 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl @@ -82,6 +82,10 @@ __kernel void kernel_gemv_moe_mxfp4_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl index 6f4d3f53216..2d28db63ec5 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl @@ -37,6 +37,10 @@ __kernel void kernel_gemv_moe_q4_0_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl index 3739a215705..b98bdc0f12e 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q4_1_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..12464e9826e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl @@ -0,0 +1,155 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q4_k_to_fp32_packed8(ushort2 q4x8, float scale, float minv) { + float8 fp32x8; + fp32x8.s0 = (q4x8.s0 & 0x000F) * scale - minv; + fp32x8.s1 = ((q4x8.s0 & 0x00F0) >> 4) * scale - minv; + fp32x8.s2 = ((q4x8.s0 & 0x0F00) >> 8) * scale - minv; + fp32x8.s3 = ((q4x8.s0 & 0xF000) >> 12) * scale - minv; + fp32x8.s4 = (q4x8.s1 & 0x000F) * scale - minv; + fp32x8.s5 = ((q4x8.s1 & 0x00F0) >> 4) * scale - minv; + fp32x8.s6 = ((q4x8.s1 & 0x0F00) >> 8) * scale - minv; + fp32x8.s7 = ((q4x8.s1 & 0xF000) >> 12) * scale - minv; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_k_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s0), scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s1), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s2), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s3), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl index 938054cf982..b43613638a8 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q5_0_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl index f33a4ef2757..7a666006e68 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl @@ -39,6 +39,10 @@ __kernel void kernel_gemv_moe_q5_1_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..7d868d7abd9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl @@ -0,0 +1,160 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q5_k_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m); + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_k_f32_ns( + __global uint * src0_q, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // sub_block index = sb * 8 + j + uint expert_qh_offset = expert_id * num_superblocks * 8 * ne01; + uchar4 regQh = as_uchar4(src0_qh[expert_qh_offset + (sb * 8 + j) * ne01 + i01]); + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..c166bad5ba5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl @@ -0,0 +1,141 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q6_k_to_fp32_packed8(ushort2 ql8, ushort qh8, float d_scale) { + float8 fp32x8; + fp32x8.s0 = ((float)(( ql8.s0 & 0x000F) | ((uint)((qh8 ) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s1 = ((float)((( ql8.s0 >> 4) & 0x000F) | ((uint)((qh8 >> 2) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s2 = ((float)((( ql8.s0 >> 8) & 0x000F) | ((uint)((qh8 >> 4) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s3 = ((float)((( ql8.s0 >> 12)& 0x000F) | ((uint)((qh8 >> 6) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s4 = ((float)(( ql8.s1 & 0x000F) | ((uint)((qh8 >> 8) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s5 = ((float)((( ql8.s1 >> 4) & 0x000F) | ((uint)((qh8 >>10) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s6 = ((float)((( ql8.s1 >> 8) & 0x000F) | ((uint)((qh8 >>12) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s7 = ((float)((( ql8.s1 >> 12)& 0x000F) | ((uint)((qh8 >>14) & 0x3) << 4)) - 32.f) * d_scale; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q6_k_f32_ns( + __global uint * src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; // 8 sub-blocks of 32 per super-block + int scales_per_row = num_superblocks * 16; + + // Expert offsets in the transposed noshuffle layout + uint expert_ql_offset = expert_id * (ne00 / 8) * ne01; // 32 uints per super-block + uint expert_qh_offset = expert_id * (ne00 / 16) * ne01; // 16 uints per super-block + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; // super-block index + uint j = ib % 8; // 32-element group within super-block + + // Load d for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + + // Load 2 sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + // Load 4 uints of ql + uint ql_base = expert_ql_offset + (ib * 4) * ne01 + i01; + uint4 regQL; + regQL.s0 = src0_ql[ql_base]; + regQL.s1 = src0_ql[ql_base + ne01]; + regQL.s2 = src0_ql[ql_base + ne01 * 2]; + regQL.s3 = src0_ql[ql_base + ne01 * 3]; + + // Load 2 uints of qh + uint qh_base = expert_qh_offset + (ib * 2) * ne01 + i01; + uint2 regQH; + regQH.s0 = src0_qh[qh_base]; + regQH.s1 = src0_qh[qh_base + ne01]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s0), (ushort)(regQH.s0 & 0xFFFF), scale0); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s1), (ushort)(regQH.s0 >> 16), scale0); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s2), (ushort)(regQH.s1 & 0xFFFF), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s3), (ushort)(regQH.s1 >> 16), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 15443aa554a..15d231f70c0 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -13,6 +13,10 @@ #include // for qsort #include // for GGML_ASSERT +#ifdef GGML_USE_OPENMP +#include +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -3064,70 +3068,121 @@ void iq2xs_init_impl(enum ggml_type type) { } kmap_q2xs[index] = i; } - int8_t pos[8]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // The neighbour search runs in three passes: + // 1. Parallel: for each i, qsort and count its neighbours into n_per_i, + // and reduce the totals (num_neighbors, num_not_in_map). + // 2. Serial: prefix-sum n_per_i into offsets[], so each i has a + // pre-assigned slice of kneighbors_q2xs to write into. + // 3. Parallel: redo the qsort and write each i's neighbour list at + // offsets[i]. + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq2_data[gindex].neighbours = kneighbors_q2xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - kmap_q2xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q2xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q2xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q2xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int local_counter = offsets[i]; + kmap_q2xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq2xs_free_impl(enum ggml_type type) { @@ -3663,70 +3718,115 @@ void iq3xs_init_impl(int grid_size) { } kmap_q3xs[index] = i; } - int8_t pos[4]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // See explanation of parallelism in iq2xs_init_impl + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq3_data[gindex].neighbours = kneighbors_q3xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - kmap_q3xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q3xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q3xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q3xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int local_counter = offsets[i]; + kmap_q3xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq3xs_free_impl(int grid_size) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1cb8f563d85..d3805772183 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -199,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() { return &guid; } +struct ggml_backend_rpc_device_context { + std::string endpoint; + uint32_t device; + std::string name; + std::string description; + uint64_t last_graph_uid; +}; + struct ggml_backend_rpc_buffer_type_context { std::string endpoint; uint32_t device; @@ -211,7 +219,6 @@ struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; - uint64_t last_graph_uid; }; struct ggml_backend_rpc_buffer_context { @@ -691,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend); + ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context; GGML_ASSERT(cgraph->n_nodes > 0); - bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid; + bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid; if (reuse) { rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; @@ -701,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); RPC_STATUS_ASSERT(status); } else { - rpc_ctx->last_graph_uid = cgraph->uid; + rpc_dev_ctx->last_graph_uid = cgraph->uid; std::vector input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); @@ -770,7 +779,6 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { /* .endpoint = */ endpoint, /* .device = */ device, /* .name = */ dev_name, - /* .last_graph_uid = */ 0, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -1757,15 +1765,6 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } } -// device interface - -struct ggml_backend_rpc_device_context { - std::string endpoint; - uint32_t device; - std::string name; - std::string description; -}; - static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; @@ -1947,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { std::string dev_name = "RPC" + std::to_string(dev_id); std::string dev_desc = std::string(endpoint); ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .device = */ ind, - /* .name = */ dev_name, - /* .description = */ dev_desc + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc, + /* .last_graph_uid = */ 0, }; ggml_backend_dev_t dev = new ggml_backend_device { diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 8f44c6ed080..180de92202d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -39,6 +39,18 @@ if (WIN32) set(CMAKE_CXX_COMPILER "icx") set(CMAKE_CXX_COMPILER_ID "IntelLLVM") endif() + # Level Zero SDK path for Windows (only when GGML_SYCL_SUPPORT_LEVEL_ZERO is enabled) + if(GGML_SYCL_SUPPORT_LEVEL_ZERO) + if(DEFINED ENV{LEVEL_ZERO_V1_SDK_PATH}) + set(LEVEL_ZERO_V1_SDK_PATH $ENV{LEVEL_ZERO_V1_SDK_PATH}) + if(EXISTS "${LEVEL_ZERO_V1_SDK_PATH}") + target_include_directories(ggml-sycl PRIVATE "${LEVEL_ZERO_V1_SDK_PATH}/include") + set(LEVEL_ZERO_V1_SDK_LIB_PATH "${LEVEL_ZERO_V1_SDK_PATH}/lib") + else() + message(WARNING "LEVEL_ZERO_V1_SDK_PATH set but folder not found: ${LEVEL_ZERO_V1_SDK_PATH}") + endif() + endif() + endif() endif() macro(detect_and_find_package package_name) @@ -93,6 +105,23 @@ endif() target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") +message(STATUS "GGML_SYCL_SUPPORT_LEVEL_ZERO ${GGML_SYCL_SUPPORT_LEVEL_ZERO}") +if (GGML_SYCL_SUPPORT_LEVEL_ZERO) + # Link against Level Zero loader for direct device memory allocation. + # Avoids sycl::malloc_device triggering DMA-buf/TTM system RAM staging + # in the xe kernel driver during multi-GPU inference. + find_path(LEVEL_ZERO_INCLUDE_DIR level_zero/ze_api.h HINTS ${ONEAPI_ROOT}/include ${LEVEL_ZERO_V1_SDK_PATH}/include) + find_library(ZE_LOADER_LIB ze_loader HINTS ${ONEAPI_ROOT}/lib ${LEVEL_ZERO_V1_SDK_LIB_PATH} ENV LD_LIBRARY_PATH) + if(ZE_LOADER_LIB AND LEVEL_ZERO_INCLUDE_DIR) + target_link_libraries(ggml-sycl PRIVATE ${ZE_LOADER_LIB}) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_SUPPORT_LEVEL_ZERO) + message(STATUS "Level Zero loader found: ${ZE_LOADER_LIB}") + message(STATUS "Level Zero headers found: ${LEVEL_ZERO_INCLUDE_DIR}") + else() + message(WARNING "Level Zero loader or headers not found, Level Zero support disabled") + endif() +endif() + # Link against oneDNN set(GGML_SYCL_DNNL 0) if(GGML_SYCL_DNN) diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 05fd5ef46c7..ae08abad81b 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -11,6 +11,10 @@ // #include "common.hpp" +#include +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include +#endif #include "ggml-backend-impl.h" #include "ggml-impl.h" @@ -55,6 +59,20 @@ bool gpu_has_xmx(sycl::device &dev) { return dev.has(sycl::aspect::ext_intel_matrix); } +static int ggml_sycl_get_env(const char *env_name, int default_val) { + char *user_device_string = getenv(env_name); + int user_number = default_val; + + unsigned n; + if (user_device_string != NULL && + sscanf(user_device_string, " %u", &n) == 1) { + user_number = (int)n; + } else { + user_number = default_val; + } + return user_number; +} + int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { const int64_t max_range = std::numeric_limits::max(); int64_t sycl_down_blk_size = block_size; @@ -66,6 +84,61 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block return sycl_down_blk_size; } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_use_level_zero_device_alloc(sycl::queue &q) { + return ggml_sycl_get_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1) && + q.get_device().is_gpu() && + q.get_backend() == sycl::backend::ext_oneapi_level_zero; +} +#endif + +// Use Level Zero zeMemAllocDevice to avoid sycl::malloc_device triggering +// DMA-buf/TTM system RAM staging in the xe kernel driver during multi-GPU inference. +// The decision is made from the queue and runtime env because large buffers can be +// allocated before ggml_check_sycl() initializes g_ggml_sycl_enable_level_zero. +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + void *ptr = nullptr; + auto ze_ctx = sycl::get_native(q.get_context()); + auto ze_dev = sycl::get_native(q.get_device()); +#ifdef ZE_RELAXED_ALLOCATION_LIMITS_EXP_NAME + ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = { + ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC, + nullptr, + ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE, + }; + ze_device_mem_alloc_desc_t alloc_desc = { + ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, + &relaxed_desc, + 0, + 0, + }; +#else + ze_device_mem_alloc_desc_t alloc_desc = {ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, nullptr, 0, 0}; +#endif + ze_result_t r = zeMemAllocDevice(ze_ctx, &alloc_desc, size, 64, ze_dev, &ptr); + if (r == ZE_RESULT_SUCCESS && ptr) { + return ptr; + } + return nullptr; + } +#endif + return sycl::malloc_device(size, q); +} + +void ggml_sycl_free_device(void *ptr, sycl::queue &q) { + if (!ptr) return; +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + auto ze_ctx = sycl::get_native(q.get_context()); + zeMemFree(ze_ctx, ptr); + return; + } +#endif + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, q))); +} + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams) { for (int i = 0; i < ggml_sycl_info().device_count; ++i) { for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { @@ -75,8 +148,7 @@ void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector str } if (extra->data_device[i] != nullptr && streams.size()>0) { ggml_sycl_set_device(i); - SYCL_CHECK( - CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i])))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(extra->data_device[i], *(streams[i])))); } } delete extra; diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index eec36e8db9a..6d19538215e 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -238,6 +238,8 @@ struct ggml_sycl_device_info { std::array default_tensor_split = {}; int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0}; + + bool ext_oneapi_level_zero = true; // sycl::backend::ext_oneapi_level_zero used by all enumerated GPU devices }; const ggml_sycl_device_info & ggml_sycl_info(); @@ -310,6 +312,10 @@ struct ggml_tensor_extra_gpu { optimize_feature optimized_feature; }; +extern int g_ggml_sycl_enable_level_zero; +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q); +void ggml_sycl_free_device(void *ptr, sycl::queue &q); + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams={}); namespace sycl_ex = sycl::ext::oneapi::experimental; diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 5577bf73b28..4ae431a962e 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,6 +3,13 @@ #include "dequantize.hpp" #include "presets.hpp" +#if defined(__INTEL_LLVM_COMPILER) + #if __has_include() + #include + #define GGML_SYCL_DMMV_HAS_BF16 + #endif +#endif + static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -11,6 +18,16 @@ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat v.y() = x[ib + iqs + 1]; } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const sycl::ext::oneapi::bfloat16 *x = (const sycl::ext::oneapi::bfloat16 *)vx; + + // automatic bfloat16 -> float type cast if dfloat == float + v.x() = x[ib + iqs + 0]; + v.y() = x[ib + iqs + 1]; +} +#endif + static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; @@ -217,6 +234,28 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, } } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_mul_mat_vec_bf16_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + // The qk=1 kernel iterates with stride 2*GGML_SYCL_DMMV_X, so ncols must be a + // multiple of that — not just GGML_SYCL_DMMV_X — to avoid out-of-bounds reads. + GGML_ASSERT(ncols % (2*GGML_SYCL_DMMV_X) == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec<1, 1, convert_bf16>(vx, y, dst, ncols, + nrows, item_ct1); + }); + } +} +#endif + /* DPCT1110:4: The total declared local variable size in device function dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register @@ -1497,7 +1536,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec( bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || + src0->type == GGML_TYPE_BF16; if (src1_convert_f16) { scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, @@ -1565,6 +1605,11 @@ void ggml_sycl_op_dequantize_mul_mat_vec( case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; +#ifdef GGML_SYCL_DMMV_HAS_BF16 + case GGML_TYPE_BF16: + convert_mul_mat_vec_bf16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; +#endif default: printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index ebc587524bf..9c2449aba0c 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -6,7 +6,7 @@ #include -template +template void gated_delta_net_sycl(const float * q, const float * k, const float * v, @@ -28,7 +28,8 @@ void gated_delta_net_sycl(const float * q, int64_t sb3, const sycl::uint3 neqk1_magic, const sycl::uint3 rq3_magic, - float scale) { + float scale, + int K) { auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); const uint32_t h_idx = item_ct1.get_group(2); const uint32_t sequence = item_ct1.get_group(1); @@ -43,9 +44,13 @@ void gated_delta_net_sycl(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v; @@ -55,9 +60,13 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[col * S_v + i]; + s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -131,17 +140,32 @@ void gated_delta_net_sycl(const float * q, } attn_data += S_v * H; - } + // Write state back to global memory + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } + } + + if constexpr (!keep_rs_t) { +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net(const float * q_d, const float * k_d, const float * v_d, @@ -165,6 +189,7 @@ static void launch_gated_delta_net(const float * q_d, int64_t neqk1, int64_t rq3, float scale, + int K, dpct::queue_ptr stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size; @@ -182,9 +207,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 16; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, - sb3, neqk1_magic, rq3_magic, scale); + sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -193,9 +218,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 32; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, - sb3, neqk1_magic, rq3_magic, scale); + sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -204,9 +229,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 64; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl( + gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, - sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -216,9 +241,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 128; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl( + gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, - sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -290,14 +315,30 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dpct::queue_ptr stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 57cc4ffb6f7..b3fbb621196 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -30,6 +30,10 @@ #include #include +#include +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include +#endif #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC # include #endif @@ -68,6 +72,8 @@ int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_use_async_mem_op_requested = 1; +int g_ggml_sycl_enable_level_zero = 0; int g_ggml_sycl_enable_flash_attention = 1; @@ -92,7 +98,7 @@ static ggml_sycl_device_info ggml_sycl_init() { for (int i = 0; i < info.device_count; ++i) { info.devices[i].vmm = 0; dpct::device_info prop; - sycl::device device = dpct::dev_mgr::instance().get_device(i); + auto & device = dpct::dev_mgr::instance().get_device(i); SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, device))); @@ -111,6 +117,12 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); info.devices[i].hw_info = get_device_hw_info(&device); + // Only check GPU devices; CPU devices use OpenCL and would otherwise + // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. + if (device.is_gpu() && device.default_queue().get_backend() != sycl::backend::ext_oneapi_level_zero) { + GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); + info.ext_oneapi_level_zero = false; + } } for (int id = 0; id < info.device_count; ++id) { @@ -223,6 +235,11 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero); +#else + g_ggml_sycl_enable_level_zero = 0; +#endif #ifdef SYCL_FLASH_ATTN g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); @@ -253,6 +270,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n"); #endif +#if defined(GGML_SYCL_SUPPORT_LEVEL_ZERO) + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n"); +#endif GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); @@ -262,12 +284,19 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); #endif +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: %d\n", g_ggml_sycl_enable_level_zero); +#else + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: Level Zero disabled by compile flag\n"); +#endif #if GGML_SYCL_DNNL GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); #endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); + g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1); + GGML_LOG_INFO(" GGML_SYCL_USE_ASYNC_MEM_OP: %d\n", g_ggml_sycl_use_async_mem_op_requested); #ifdef SYCL_FLASH_ATTN GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention); @@ -283,11 +312,11 @@ static void ggml_check_sycl() try { fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); #endif */ - // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be - // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in - // other places. + // Async USM allocation/free is also useful outside the graph path: it avoids the host waits in the reorder + // staging path while preserving queue ordering semantics. Graph support still depends on the extension being + // available, but it no longer needs to control the non-graph fast path. #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC - g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph; + g_ggml_sycl_use_async_mem_op = g_ggml_sycl_use_async_mem_op_requested || !g_ggml_sycl_disable_graph; if (g_ggml_sycl_use_async_mem_op) { for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) { if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) { @@ -371,7 +400,7 @@ struct ggml_backend_sycl_buffer_context { ~ggml_backend_sycl_buffer_context() { if (dev_ptr != nullptr) { ggml_sycl_set_device(device); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(dev_ptr, *stream))); } //release extra used by tensors @@ -504,8 +533,43 @@ catch (sycl::exception const &exc) { std::exit(1); } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_is_l0_discrete_gpu(sycl::queue &q) { + if (!q.get_device().is_gpu() || q.get_backend() != sycl::backend::ext_oneapi_level_zero) { + return false; + } + + ze_device_handle_t ze_dev = sycl::get_native(q.get_device()); + ze_device_properties_t props = {}; + props.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + ze_result_t r = zeDeviceGetProperties(ze_dev, &props); + return r == ZE_RESULT_SUCCESS && !(props.flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED); +} +#endif + static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, const void *ptr_src, size_t size) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + // Use Level Zero direct copy for dGPU-to-dGPU transfers. + const bool l0_copy_supported = + ggml_sycl_is_l0_discrete_gpu(q_dst) && ggml_sycl_is_l0_discrete_gpu(q_src); + if (g_ggml_sycl_enable_level_zero && l0_copy_supported) { + auto ze_ctx = sycl::get_native(q_dst.get_context()); + auto ze_dev = sycl::get_native(q_dst.get_device()); + ze_command_queue_desc_t cq_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, nullptr, 0, 0, + 0, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL}; + ze_command_list_handle_t cl; + ze_result_t r = zeCommandListCreateImmediate(ze_ctx, ze_dev, &cq_desc, &cl); + if (r == ZE_RESULT_SUCCESS) { + r = zeCommandListAppendMemoryCopy(cl, ptr_dst, ptr_src, size, nullptr, 0, nullptr); + zeCommandListDestroy(cl); + if (r == ZE_RESULT_SUCCESS) { + return; + } + } + } +#endif + // Host-staged copy char *host_buf = (char *)malloc(size); q_src.memcpy(host_buf, (const char *)ptr_src, size).wait(); q_dst.memcpy((char *)ptr_dst, host_buf, size).wait(); @@ -675,8 +739,7 @@ ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 void * dev_ptr; - SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)ggml_sycl_malloc_device(size, *stream))); if (!dev_ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size); return nullptr; @@ -917,18 +980,10 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - // FIXME: do not crash if SYCL Buffer alloc fails - // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_sycl_set_device(i); const queue_ptr stream = ctx->streams[i]; char * buf; - /* - DPCT1009:208: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)ggml_sycl_malloc_device(size, *stream))); if (!buf) { char err_buf[1024]; snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size); @@ -1306,7 +1361,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { ggml_sycl_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(b.ptr, *qptr))); pool_size -= b.size; } } @@ -1374,9 +1429,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); - SYCL_CHECK( - CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( - look_ahead_size, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *)ggml_sycl_malloc_device(look_ahead_size, *qptr))); if (!ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size); return nullptr; @@ -1404,7 +1457,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } } GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(ptr, *qptr))); pool_size -= size; } }; @@ -2325,21 +2378,25 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); + { + const int64_t gemm_flops = (int64_t)row_diff * src1_ncols * ne10; + const bool use_mkl_direct = gemm_flops < 256 * 256 * 256; #if GGML_SYCL_DNNL - if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, - DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), - dst_dd_i, DnnlGemmWrapper::to_dt(), stream); - } - else + if (!g_ggml_sycl_disable_dnn && !use_mkl_direct) { + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, + DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif - { - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } } GGML_UNUSED(dst); @@ -3388,6 +3445,7 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_F16: + case GGML_TYPE_BF16: return true; default: return false; @@ -3405,7 +3463,7 @@ static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - return sycl::malloc(size, *stream, sycl::usm::alloc::device); + return ggml_sycl_malloc_device(size, *stream); } static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { @@ -3419,7 +3477,7 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - sycl::free(ptr, *stream); + ggml_sycl_free_device(ptr, *stream); } // RAII wrapper for temporary reorder buffers with optional host memory fallback. @@ -3751,8 +3809,13 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // The F16/BF16 qk=1 kernel iterates with stride 2*DMMV_X, requiring ne[0] to be + // a multiple of 2*DMMV_X. Quantized types use block-structured kernels that only + // need ne[0] % DMMV_X == 0. + const int64_t dmmv_x_required = (src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F16) ? + 2*GGML_SYCL_DMMV_X : GGML_SYCL_DMMV_X; return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + src0->ne[0] % dmmv_x_required == 0 && src1->ne[1] == 1; } static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -3856,35 +3919,17 @@ struct mmid_row_mapping { __dpct_inline__ static void k_copy_src1_to_contiguous( const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, - int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, - const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + const mmid_row_mapping *__restrict__ row_mapping, int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, - const sycl::nd_item<3> &item_ct1, int &src1_row) { - int32_t iid1 = item_ct1.get_group(2); - int32_t id = item_ct1.get_group(1); - - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + const sycl::nd_item<3> &item_ct1) { + const int32_t src1_row = item_ct1.get_group(2); - if (row_id_i != i02) { - return; - } + const int32_t iid1 = row_mapping[src1_row].i2; + const int32_t id = row_mapping[src1_row].i1; const int64_t i11 = id % ne11; const int64_t i12 = iid1; - if (item_ct1.get_local_id(2) == 0) { - src1_row = - dpct::atomic_fetch_add( - cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - /* - DPCT1065:194: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. - */ - item_ct1.barrier(); - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); @@ -3959,6 +4004,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused( src1_row_stride, stream); } +// counting sort of the routed rows by expert id (row_id_i, as chosen by the router): +// builds a projection of a memory layout where each expert's slice is contiguous +static void mmid_counting_sort_rows( + const ggml_tensor * ids, const char * ids_host, + int64_t n_ids, int64_t n_as, int64_t n_routed_rows, + std::vector & expert_counts, + std::vector & expert_row_offsets, + std::vector & routed_row_src) { + + // frequencies: how many routed rows each expert "owns" + expert_counts.assign(n_as, 0); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + expert_counts[row_id_i]++; + } + } + + // where each expert's slice starts (row indices) and the previous ends + expert_row_offsets.assign(n_as + 1, 0); + for (int64_t i02 = 0; i02 < n_as; i02++) { + expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02]; + } + + std::vector expert_row_next = expert_row_offsets; + routed_row_src.resize(n_routed_rows); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + // find and validate the next free row for a given expert (row_id_i) + const int64_t routed_row = expert_row_next[row_id_i]++; + GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]); + GGML_ASSERT(routed_row < expert_row_offsets[row_id_i + 1]); + routed_row_src[routed_row] = {(int32_t) id, (int32_t) iid1}; + } + } +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -4037,99 +4123,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + // how many "owned" routed rows to pass to each expert + std::vector expert_row_counts; + // where each expert's slice starts and the previous ends (row indices, right-exclusive) + std::vector expert_row_offsets; + // the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous + std::vector routed_row_src; - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows, + expert_row_counts, expert_row_offsets, routed_row_src); - if (row_id_i != i02) { - continue; - } + ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), n_routed_rows); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping)))); - num_src1_rows++; - } - } + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + char *__restrict src1_contiguous_get = + src1_contiguous.get(); + mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_src1_to_contiguous( + src1_original, src1_contiguous_get, + dev_row_mapping_get, + ne11, ne10, nb11, nb12, + item_ct1); + }); + }); + } + + for (int64_t i02 = 0; i02 < n_as; i02++) { + const int64_t num_src1_rows = expert_row_counts[i02]; if (num_src1_rows == 0) { continue; } - - ggml_sycl_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); - - const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); - - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); - sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor src1_row_acc(cgh); - - char *__restrict src1_contiguous_get = - src1_contiguous.get(); - int *__restrict dev_cur_src1_row_get = - dev_cur_src1_row.get(); - mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - size_t ids_nb_ct6 = ids->nb[1]; - size_t ids_nb_ct7 = ids->nb[0]; - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_src1_to_contiguous( - src1_original, src1_contiguous_get, - dev_cur_src1_row_get, - dev_row_mapping_get, ids_dev, i02, - ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, - item_ct1, src1_row_acc); - }); - }); - } + const int64_t expert_row_offset = expert_row_offsets[i02]; src0_row.data = src0_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); + src1_row.data = src1_contiguous.get() + expert_row_offset*nb11; src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.data = dst_contiguous.get() + expert_row_offset*nb1; dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); - sycl::range<3> grid_dims(1, 1, num_src1_rows); - stream->submit([&](sycl::handler &cgh) { - const char *__restrict dst_contiguous_get = - dst_contiguous.get(); - const mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_dst_from_contiguous(dst_original, - dst_contiguous_get, - dev_row_mapping_get, - ne0, nb1, nb2, item_ct1); - }); - }); - } + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + const char *__restrict dst_contiguous_get = + dst_contiguous.get(); + const mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_dst_from_contiguous(dst_original, + dst_contiguous_get, + dev_row_mapping_get, + ne0, nb1, nb2, item_ct1); + }); + }); } } } diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index d7770047424..16b2d65d271 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -85,6 +85,32 @@ static __dpct_inline__ int get_int_from_uint8_aligned( (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } +static __dpct_inline__ int byte_sub_4(const int a, const int b) { + const uint32_t ua = static_cast(a); + const uint32_t ub = static_cast(b); + return static_cast(((ua | 0x80808080u) - ub) ^ 0x80808080u); +} + +static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq_scalar( + const int vl, const int vh, const int u0, const int u1, const int8_t sc0, + const int8_t sc1, const float d, const float d80, const float d81) { + static_assert(QR6_K == 2, "q6_K MMVQ scalar fast path assumes QR6_K == 2"); + + const int vil0 = (vl >> 0) & 0x0F0F0F0F; + const int vih0 = ((vh >> 0) << 4) & 0x30303030; + const int vi0 = byte_sub_4(vil0 | vih0, 0x20202020); + + const int vil1 = (vl >> 4) & 0x0F0F0F0F; + const int vih1 = ((vh >> 4) << 4) & 0x30303030; + const int vi1 = byte_sub_4(vil1 | vih1, 0x20202020); + + const float sumf = + d80 * (dpct::dp4a(vi0, u0, 0) * sc0) + + d81 * (dpct::dp4a(vi1, u1, 0) * sc1); + + return d * sumf; +} + static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4, const uint8_t *values, int &val1, int &val2) { @@ -279,24 +305,8 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, const int *__restrict__ u, const int8_t *__restrict__ scales, const float &d, const float *__restrict__ d8) { - - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4*i]; - - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4*i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary( - (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d*sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called @@ -542,23 +552,8 @@ template <> struct reorder_vec_dot_q_sycl { __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float d, const float * __restrict__ d8) { - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4 * i]; - - const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary((vil | vih), 0x20202020, - dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d * sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, @@ -579,16 +574,15 @@ template <> struct reorder_vec_dot_q_sycl { const int8_t * scs = scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; + const int u0 = get_int_from_int8_aligned( + q8_1_quant_ptr + bq8_offset * QK8_1, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + q8_1_quant_ptr + (bq8_offset + 2) * QK8_1, iqs % QI8_1); + const float d80 = (*(q8_1_ds + bq8_offset + 0))[0]; + const float d81 = (*(q8_1_ds + bq8_offset + 2))[0]; -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1); - const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i); - d8[i] = ds_values[0]; - } - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scs[0], scs[4], *d, d80, d81); } }; #define VDR_Q4_0_Q8_1_MMVQ 2 @@ -1167,16 +1161,15 @@ vec_dot_q6_K_q8_1(const void *__restrict__ vbq, const int8_t * scales = bq6_K->scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + 2 * i].ds[0]; - } + const int u0 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 0].qs, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 2].qs, iqs % QI8_1); + const float d80 = bq8_1[bq8_offset + 0].ds[0]; + const float d81 = bq8_1[bq8_offset + 2].ds[0]; - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scales[0], scales[4], bq6_K->d, d80, d81); } diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 715a263a6d0..65785ae4566 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,6 +8,11 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +if (DEFINED ENV{VULKAN_SDK}) + list(APPEND CMAKE_PREFIX_PATH "$ENV{VULKAN_SDK}") +endif() +find_package(SPIRV-Headers CONFIG REQUIRED) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files add_definitions(/MP) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a0a556206d5..aa289220a90 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -49,7 +49,6 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include #include #include -#include #include #include #include @@ -500,6 +499,12 @@ static constexpr std::initializer_list topk_moe_late_softmax { GGM GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +// Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder +// pass so it keeps the chain contiguous and by the dispatcher to detect the fusion. +static constexpr std::initializer_list snake_pattern { GGML_OP_MUL, GGML_OP_SIN, + GGML_OP_SQR, GGML_OP_MUL, + GGML_OP_ADD }; + //node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] //node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] //node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] @@ -760,8 +765,8 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; @@ -847,6 +852,9 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_snake_f32; + vk_pipeline pipeline_snake_f16; + vk_pipeline pipeline_snake_bf16; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; @@ -855,6 +863,8 @@ struct vk_device_struct { vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; + vk_pipeline pipeline_ssm_conv_silu_f32; + vk_pipeline pipeline_ssm_conv_bias_silu_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; std::map pipeline_conv2d_f32[CONV_SHAPE_COUNT]; @@ -1353,6 +1363,8 @@ struct vk_op_rope_push_constants { uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t a_offset; + uint32_t d_offset; }; static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); @@ -1472,6 +1484,11 @@ struct vk_op_conv_transpose_1d_push_constants { int32_t s0; }; +struct vk_op_snake_push_constants { + uint32_t ne0; + uint32_t ne1; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -1506,6 +1523,7 @@ struct vk_op_gated_delta_net_push_constants { uint32_t sb1, sb2, sb3; uint32_t neq1, rq3; float scale; + uint32_t K; }; struct vk_op_ssm_scan_push_constants { @@ -3954,13 +3972,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ - if (device->mul_mat ## ID ## _l[TYPE]) { \ + if (device->mul_mat ## ID ## _l_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _m[TYPE]) { \ + if (device->mul_mat ## ID ## _m_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _s[TYPE]) { \ + if (device->mul_mat ## ID ## _s_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ @@ -4131,11 +4149,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l[TYPE]) \ + if (device->mul_mat ## ID ## _l_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m[TYPE]) \ + if (device->mul_mat ## ID ## _m_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s[TYPE]) \ + if (device->mul_mat ## ID ## _s_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); @@ -4568,6 +4586,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4576,6 +4595,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4839,6 +4859,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); @@ -4900,7 +4924,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); } - ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -5716,12 +5742,12 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } - device->mul_mat_l_int[i] = true; - device->mul_mat_m_int[i] = true; - device->mul_mat_s_int[i] = true; - device->mul_mat_id_l_int[i] = true; - device->mul_mat_id_m_int[i] = true; - device->mul_mat_id_s_int[i] = true; + device->mul_mat_l_int[i] = device->mul_mat_l[i]; + device->mul_mat_m_int[i] = device->mul_mat_m[i]; + device->mul_mat_s_int[i] = device->mul_mat_s[i]; + device->mul_mat_id_l_int[i] = device->mul_mat_id_l[i]; + device->mul_mat_id_m_int[i] = device->mul_mat_id_m[i]; + device->mul_mat_id_s_int[i] = device->mul_mat_id_s[i]; } @@ -7538,6 +7564,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_bf16_f32; + } else { + return ctx->device->pipeline_cpy_bf16_f32; + } + } if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_i32; @@ -9936,7 +9969,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SSM_CONV: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_ssm_conv_f32; + switch (ctx->num_additional_fused_ops) { + case 0: return ctx->device->pipeline_ssm_conv_f32; + case 1: return ctx->device->pipeline_ssm_conv_silu_f32; + case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32; + default: return nullptr; + } } return nullptr; case GGML_OP_OPT_STEP_ADAMW: @@ -10117,6 +10155,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + template static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; @@ -10767,6 +10814,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -10775,6 +10823,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const uint32_t K = (uint32_t)src_state->ne[1]; + const uint32_t s_off = S_v * H * n_tokens * n_seqs; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); @@ -10808,7 +10859,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + scale, + K }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, @@ -10872,11 +10924,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, pc, elements); } -static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; +static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { + ggml_tensor * conv = cgraph->nodes[node_idx]; + const ggml_tensor * src0 = conv->src[0]; + const ggml_tensor * src1 = conv->src[1]; + + // Pick the destination tensor (last node in the fused chain) and the optional bias. + // Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu. + ggml_tensor * dst = conv; + const ggml_tensor * bias = nullptr; + + if (ctx->num_additional_fused_ops == 1) { + dst = cgraph->nodes[node_idx + 1]; // silu + } else if (ctx->num_additional_fused_ops == 2) { + ggml_tensor * add = cgraph->nodes[node_idx + 1]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + dst = cgraph->nodes[node_idx + 2]; // silu + } + + // The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused. + const ggml_tensor * src2 = bias ? bias : src0; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, { + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, { (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], (uint32_t)src1->nb[1], (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], @@ -11239,6 +11308,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * (uint32_t)src0->ne[2], nb01, nb02, nb03, nb11, nb12, nb13, + 0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets }; return rope; @@ -11334,6 +11404,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(buf[i] != nullptr); } + // a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned. + // Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset. + pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type); + offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1); + std::array elements; elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; @@ -12053,6 +12128,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p)); } +// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b. +// Match the naive mul -> sin -> sqr -> mul -> add chain and run the +// dedicated kernel directly. The pattern is validated by +// ggml_vk_can_fuse_snake before this call. +static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + vk_pipeline pipeline = nullptr; + switch (x->type) { + case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break; + case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break; + case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break; + default: GGML_ABORT("unsupported type"); + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x); + vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a); + vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add); + + vk_op_snake_push_constants pc{}; + pc.ne0 = static_cast(x->ne[0]); + pc.ne1 = static_cast(x->ne[1]); + + std::array elements = { pc.ne0, pc.ne1, 1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { uint32_t op = static_cast(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -13261,7 +13375,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_MUL: - ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + if (ctx->num_additional_fused_ops) { + ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx); + } else { + ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + } break; case GGML_OP_DIV: @@ -13551,7 +13669,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SSM_CONV: - ggml_vk_ssm_conv(ctx, compute_ctx, node); + ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx); break; @@ -14448,6 +14566,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return true; } +// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2. +static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx, int num_extra) { + const ggml_tensor * conv = cgraph->nodes[node_idx]; + if (conv->op != GGML_OP_SSM_CONV) { + return false; + } + + const ggml_tensor * silu = nullptr; + const ggml_tensor * bias = nullptr; + + if (num_extra == 1) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) { + return false; + } + silu = cgraph->nodes[node_idx + 1]; + } else if (num_extra == 2) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) { + return false; + } + const ggml_tensor * add = cgraph->nodes[node_idx + 1]; + silu = cgraph->nodes[node_idx + 2]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + // bias must be channel-wise (one element per channel of the conv output) + if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) { + return false; + } + if (add->type != GGML_TYPE_F32) { + return false; + } + // The shader doesn't apply per-tensor offsets, so reject misaligned bias. + if (get_misalign_bytes(ctx, bias) != 0) { + return false; + } + } else { + return false; + } + + if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) { + return false; + } + if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + // The shader writes to the fused dst using its own strides, but the push constants don't + // carry a per-tensor offset, so the binding must be naturally aligned. + if (get_misalign_bytes(ctx, silu) != 0) { + return false; + } + return true; +} + static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, topk_moe_mode mode) { @@ -14578,6 +14752,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } +// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add. +// Verifies the chain shape, the closure x_in_add == x_in_mul0, and that +// the broadcast operands a and inv_b share a [1, C] layout. +static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + GGML_UNUSED(ctx); + if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) { + return false; + } + + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + const ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + if (x_in_add != x) { + return false; + } + if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) { + return false; + } + // Shader bindings: data_a is A_TYPE so it follows x's precision, while + // data_b and data_c are hardcoded float, so the broadcast operands must + // be F32 regardless of x's type. + if (a->type != GGML_TYPE_F32) return false; + if (inv_b->type != GGML_TYPE_F32) return false; + // Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline). + if (mul0->type != x->type) return false; + if (sin_node->type != x->type) return false; + if (sqr->type != x->type) return false; + if (mul1->type != x->type) return false; + if (add->type != x->type) return false; + if (!ggml_are_same_shape(a, inv_b)) { + return false; + } + if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) { + return false; + } + // Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b + // must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader. + if (x->ne[2] != 1 || x->ne[3] != 1) return false; + if (add->ne[2] != 1 || add->ne[3] != 1) return false; + if (a->ne[2] != 1 || a->ne[3] != 1) return false; + if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false; + // Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1], + // so every operand must be contiguous. + if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) || + !ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) { + return false; + } + return true; +} + // Check whether the tensors overlap in memory. // Fusions can potentially overwrite src tensors in ways that are not prevented // by ggml-alloc. If the fusion src is being applied in a way that's elementwise @@ -14864,6 +15097,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // they are overwritten, and one workgroup per row. So close enough. op_srcs_fused_elementwise[0] = true; op_srcs_fused_elementwise[1] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) { + ctx->num_additional_fused_ops = 2; + fusion_string = "SSM_CONV_BIAS_SILU"; + // ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs. + // The downstream add and silu are elementwise on the conv output. + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) { + ctx->num_additional_fused_ops = 1; + fusion_string = "SSM_CONV_SILU"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { @@ -14872,6 +15118,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg op_srcs_fused_elementwise[0] = false; op_srcs_fused_elementwise[1] = false; op_srcs_fused_elementwise[2] = false; + } else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 4; + fusion_string = "SNAKE"; + // elementwise=true: snake.comp is safe under exact aliasing because each + // thread reads data_x[idx] into a register before writing data_d[idx] + // with a data dependency on that register. The overlap check still + // rejects partial overlaps (different base or size). + std::fill_n(op_srcs_fused_elementwise, 5, true); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -15162,6 +15416,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_late_softmax)) { continue; } + if (keep_pattern(snake_pattern)) { + continue; + } // First, grab the next unused node. current_set.push_back(first_unused); @@ -15184,7 +15441,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (match_pattern(topk_moe_early_softmax_norm, j) || match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || - match_pattern(topk_moe_late_softmax, j)) { + match_pattern(topk_moe_late_softmax, j) || + match_pattern(snake_pattern, j)) { continue; } bool ok = true; @@ -15195,7 +15453,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) { ok = false; break; } @@ -15278,6 +15538,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } } } + // SSM_CONV + ADD + UNARY: pull the consuming UNARY forward + if (j > 0 && + graph->nodes[j]->op == GGML_OP_ADD && + graph->nodes[j-1]->op == GGML_OP_SSM_CONV) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_UNARY && + graph->nodes[k]->src[0] == graph->nodes[j]) { + current_set.push_back(k); + used[k] = true; + break; + } + } + } } } // Second pass grabs view nodes. @@ -15842,6 +16115,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index ca1a3ac25bd..b3b182fb084 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -19,7 +19,9 @@ void main() { if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) @@ -35,7 +37,9 @@ void main() { continue; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 9f8bfd3c182..d55e13253a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,9 @@ void main() { return; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + src0_idx(idx)]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 5e9f8308c1d..33c3202dbb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters { uint sb1, sb2, sb3; uint neq1, rq3; float scale; + uint K; }; layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; @@ -101,13 +102,21 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; FLOAT_TYPE s_shard[ROWS_PER_LANE]; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = int(n_tokens) - int(K); + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -161,9 +170,21 @@ void main() { } attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(t) - shift; + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } } - [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index ba4c2103f0c..f4130d223b1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -44,36 +44,81 @@ void im2col(const uint ow, const uint z_idx) { const uint KHKW = p.KH * p.KW; + // Precompute base input coordinates + const int base_iw = int(ow * p.s0) - p.p0; + const int base_ih = int(oh * p.s1) - p.p1; + + // Precompute step deltas + const uint delta_ic = BLOCK_SIZE / KHKW; + const uint delta_rem = BLOCK_SIZE % KHKW; + + const uint delta_ky = delta_rem / p.KW; + const uint delta_kx = delta_rem % p.KW; + + const uint delta_ic_offset = delta_ic * p.offset_delta; + + // If using BDA mode, precompute the base pointer and step size +#if BDA + const BDA_STORAGE_T base_dst_addr = p.dst_addr + D_SIZE * dst_row; + const uint bda_step = D_SIZE * BLOCK_SIZE; +#endif + uint wg_x = gl_WorkGroupID.x; do { const uint wg_offset = wg_x * 512; - [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { - const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; + uint chw_idx = wg_offset + gidx; + + uint ic = chw_idx / KHKW; + uint rem = chw_idx % KHKW; + + uint ky = rem / p.KW; + uint kx = rem % p.KW; + uint ic_offset = src_batch + ic * p.offset_delta; + + // Initialize running pointer/index for the destination buffer +#if BDA + BDA_STORAGE_T current_dst_addr = base_dst_addr + D_SIZE * chw_idx; +#else + uint current_dst_idx = dst_row + chw_idx; +#endif + + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { if (chw_idx >= p.CHW) { return; } - const uint ic = chw_idx / KHKW; - const uint rem = chw_idx - ic * KHKW; - const uint ky = rem / p.KW; - const uint kx = rem - ky * p.KW; - - const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + const int iiw = base_iw + int(kx * p.d0); + const int iih = base_ih + int(ky * p.d1); A_TYPE val = A_TYPE(0); - if (iih < p.IH && iiw < p.IW) { - val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + if (uint(iih) < p.IH && uint(iiw) < p.IW) { + val = data_a[ic_offset + uint(iih) * p.IW + uint(iiw)]; } #if BDA - D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); - out_ptr.d = D_TYPE(val); + D_ptr(current_dst_addr).d = D_TYPE(val); + current_dst_addr += bda_step; #else - data_d[dst_row + chw_idx] = D_TYPE(val); + data_d[current_dst_idx] = D_TYPE(val); + current_dst_idx += BLOCK_SIZE; #endif + + chw_idx += BLOCK_SIZE; + ic_offset += delta_ic_offset; + kx += delta_kx; + ky += delta_ky; + + // Handle X axis wrap + uint kx_wrap = uint(kx >= p.KW); + kx -= kx_wrap * p.KW; + ky += kx_wrap; + + // Handle Y axis wrap + uint ky_wrap = uint(ky >= p.KH); + ky -= ky_wrap * p.KH; + ic_offset += ky_wrap * p.offset_delta; } wg_x += gl_NumWorkGroups.x; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index 2e53459909d..03358793140 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -9,7 +9,7 @@ uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = p.a_offset + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -48,6 +48,7 @@ void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); @@ -84,6 +85,7 @@ void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -121,6 +123,7 @@ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -176,7 +179,7 @@ void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rop return; } - const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint idst = p.d_offset + i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 2e2a7e14c66..3602485b943 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -26,6 +26,9 @@ struct rope_params { uint nb11; uint nb12; uint nb13; + + uint a_offset; + uint d_offset; }; #endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp new file mode 100644 index 00000000000..8585538cbb0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" + +// Fused snake activation: y = x + sin(b * x)^2 * c +// data_a [ne0, ne1] per element activation x (A_TYPE) +// data_b [1, ne1] per channel multiplier (float) +// data_c [1, ne1] per channel inverse scale (float, precomputed as 1 / freq) +// data_d [ne0, ne1] output y (D_TYPE) +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {float data_b[];}; +layout (binding = 2) readonly buffer C {float data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t ne0; + uint32_t ne1; +} p; + +// Load A_TYPE to float +float load_val(uint32_t idx) { +#if defined(DATA_A_BF16) + return bf16_to_fp32(uint32_t(data_a[idx])); +#else + return float(data_a[idx]); +#endif +} + +// Store float as D_TYPE +void store_val(uint32_t idx, float v) { +#if defined(DATA_D_BF16) + data_d[idx] = D_TYPE(fp32_to_bf16(v)); +#else + data_d[idx] = D_TYPE(v); +#endif +} + +void main() { + const uint32_t i0 = gl_GlobalInvocationID.x; + const uint32_t i1 = gl_GlobalInvocationID.y; + if (i0 >= p.ne0 || i1 >= p.ne1) return; + + const uint32_t idx = i0 + i1 * p.ne0; + const float xi = load_val(idx); + const float s = sin(data_b[i1] * xi); + store_val(idx, xi + s * s * data_c[i1]); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp index 6802b1fc955..4cd9b8da359 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -6,12 +6,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(constant_id = 1) const uint TOKENS_PER_WG = 16; +layout(constant_id = 2) const bool APPLY_BIAS = false; +layout(constant_id = 3) const bool APPLY_SILU = false; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in; layout(binding = 0) readonly buffer Src0 { float src0[]; }; layout(binding = 1) readonly buffer Src1 { float src1[]; }; -layout(binding = 2) buffer Dst { float dst[]; }; +layout(binding = 2) readonly buffer Bias { float bias[]; }; +layout(binding = 3) buffer Dst { float dst[]; }; layout(push_constant) uniform PushConstants { uint nb01; uint nb02; @@ -45,6 +48,13 @@ void main() { } } + if (APPLY_BIAS) { + sum += bias[i1]; + } + if (APPLY_SILU) { + sum = sum / (1.0f + exp(-sum)); + } + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; dst[dst_idx] = sum; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d99b2b5d802..a1d735150fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -731,6 +731,7 @@ void process_shaders() { string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_bf16_f32","copy.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); @@ -738,6 +739,7 @@ void process_shaders() { string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_bf16_f32","contig_copy.comp",{{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); @@ -950,6 +952,10 @@ void process_shaders() { string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("snake_bf16", "snake.comp", {{"DATA_A_BF16", "1"}, {"DATA_D_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 62a523365b9..4c4eda1cbe5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -770,9 +770,14 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && + context.src2->ne[0] % kv_vec_head_align == 0; + // Compile with enough invocations to cover the largest reported subgroup. + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && + kv_vec_head_dims_aligned && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && @@ -808,7 +813,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = 1u; decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.wg_size = context.max_subgroup_size; if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 401c75c1230..921c12b41ac 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1234,6 +1234,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, const uint32_t h = (uint32_t) src2->ne[1]; const uint32_t n_tokens = (uint32_t) src2->ne[2]; const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const uint32_t K = (uint32_t) src5->ne[1]; const float scale = 1.0f / sqrtf((float) s_v); uint32_t scale_u32; memcpy(&scale_u32, &scale, sizeof(scale_u32)); @@ -1258,6 +1259,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, (uint32_t) src0->ne[1], (uint32_t) (src2->ne[3] / src0->ne[3]), + K, scale_u32, }; @@ -1832,7 +1834,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { @@ -1953,8 +1955,11 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector reduce_params; std::vector reduce_entries; if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; + const uint32_t reduce_wg_size = + std::max(reduce_sg_size, (uint32_t) std::min( + (uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -3542,8 +3547,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { const uint32_t kv_tile = decisions.kv_tile; - const uint32_t vec_nwg_cap = std::max( - 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index ae8036b9ac5..4133f0ab564 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -122,9 +122,9 @@ const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; -var q_shmem: array; -var kv_shmem: array; -var p_shmem: array; +var q_shmem: array; +var kv_shmem: array; +var p_shmem: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @@ -169,10 +169,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let head = f32(head_idx); let slope = select(1.0, - select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), - pow(params.m0, head + 1.0), - head < params.n_head_log2), - params.max_bias > 0.0); + select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), + pow(params.m0, head + 1.0), + head < params.n_head_log2), + params.max_bias > 0.0); for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { let q_tile_row = elem_idx / HEAD_DIM_QK; @@ -181,7 +181,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; q_shmem[elem_idx] = select( 0.0, - f32(Q[global_q_row_offset + q_col]) * params.scale, + Q_TYPE(Q[global_q_row_offset + q_col]) * params.scale, head_q_row < params.seq_len_q); } @@ -213,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = f32(k4.x); - kv_shmem[kv_off + 1u] = f32(k4.y); - kv_shmem[kv_off + 2u] = f32(k4.z); - kv_shmem[kv_off + 3u] = f32(k4.w); + kv_shmem[kv_off + 0u] = KV_TYPE(k4.x); + kv_shmem[kv_off + 1u] = KV_TYPE(k4.y); + kv_shmem[kv_off + 2u] = KV_TYPE(k4.z); + kv_shmem[kv_off + 3u] = KV_TYPE(k4.w); } workgroupBarrier(); @@ -233,18 +233,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var dot_val = 0.0; for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { let q_off = q_base + chunk * 4u; - let qv = vec4( + let qv = vec4( q_shmem[q_off + 0u], q_shmem[q_off + 1u], q_shmem[q_off + 2u], q_shmem[q_off + 3u]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let kv = vec4( + let kv = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - dot_val += dot(qv, kv); + dot_val += dot(vec4(qv), vec4(kv)); } #ifdef LOGIT_SOFTCAP dot_val = params.logit_softcap * tanh(dot_val); @@ -271,7 +271,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let kv_local = sg_inv_id + slot * subgroup_size; if (row_active && kv_local < kv_count) { let p = exp(local_scores[slot] - new_max); - p_shmem[subgroup_p_offset + kv_local] = p; + p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p); local_sum += p; } } @@ -285,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = f32(v4.x); - kv_shmem[kv_off + 1u] = f32(v4.y); - kv_shmem[kv_off + 2u] = f32(v4.z); - kv_shmem[kv_off + 3u] = f32(v4.w); + kv_shmem[kv_off + 0u] = KV_TYPE(v4.x); + kv_shmem[kv_off + 1u] = KV_TYPE(v4.y); + kv_shmem[kv_off + 2u] = KV_TYPE(v4.z); + kv_shmem[kv_off + 3u] = KV_TYPE(v4.w); } workgroupBarrier(); @@ -308,12 +308,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { let p = p_shmem[subgroup_p_offset + kv_local]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let v4 = vec4( + let v4 = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - acc += p * v4; + acc += f32(p) * vec4(v4); } out_regs[reg_idx] = acc; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl index f9d98fda40b..d68520f8282 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -39,6 +39,7 @@ struct Params { neq1: u32, rq3: u32, + K: u32, scale: f32, }; @@ -62,11 +63,14 @@ fn main( let iq3 = seq_id / params.rq3; let state_size = S_V * S_V; - let state_base = (seq_id * params.h + head_id) * state_size; + let state_in_base = (seq_id * params.K * params.h + head_id) * state_size; + let state_out_base = (seq_id * params.h + head_id) * state_size; + let state_size_per_snap = state_size * params.h * params.n_seqs; + let shift = i32(params.n_tokens) - i32(params.K); var state: array; for (var i = 0u; i < S_V; i++) { - state[i] = src_state[state_base + col * S_V + i]; + state[i] = src_state[state_in_base + col * S_V + i]; } var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; @@ -123,10 +127,22 @@ fn main( dst[attn_off + col] = attn_col * params.scale; attn_off += S_V * params.h; + if (params.K > 1u) { + let target_slot = i32(t) - shift; + if (target_slot >= 0 && target_slot < i32(params.K)) { + let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; + for (var i = 0u; i < S_V; i++) { + dst[slot_base + col * S_V + i] = state[i]; + } + } + } + workgroupBarrier(); } - for (var i = 0u; i < S_V; i++) { - dst[params.s_off + state_base + col * S_V + i] = state[i]; + if (params.K == 1u) { + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_out_base + col * S_V + i] = state[i]; + } } } diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index f1e4f991fae..e4ba9cfbd0f 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17 + GIT_TAG 253b94ce0d7e9284c265fefb485714944caff9d3 # ZenDNN-2026-WW19 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 6a83bb6b1ec..6051d082003 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -2,6 +2,10 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" + #include "zendnnl.hpp" #include @@ -19,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() { return zendnnl::common::data_type_t::f32; } else if constexpr (std::is_same_v) { return zendnnl::common::data_type_t::bf16; + } else if constexpr (std::is_same_v) { + return zendnnl::common::data_type_t::s8; } else { return zendnnl::common::data_type_t::none; } @@ -48,6 +54,17 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int params.num_threads = ctx->n_threads; zendnnl::lowoha::matmul::matmul_batch_params_t batch_params; + + if constexpr (std::is_same_v) { + params.dtypes.compute = zendnnl::common::data_type_t::s8; + const int64_t num_groups = k / QK8_0; + params.dynamic_quant = true; + params.quant_params.src_scale.buff = nullptr; + params.quant_params.src_scale.dt = zendnnl::common::data_type_t::bf16; + params.quant_params.src_scale.dims = {n, num_groups}; + params.packing.pack_format_b = 1; + } + zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C @@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6 (const ggml_bf16_t *)B, ldb, (float *)C, ldc); return false; + case GGML_TYPE_Q8_0: + if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32) + return false; + return ggml_zendnn_matmul( + ctx, m, n, k, + (const block_q8_0 *)A, lda, + (const float *)B, ldb, + (float *)C, ldc); default: return false; // unsupported type } @@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat( const int64_t r3 = ne13/ne03; void * work_data = ctx->work_data.get(); - if (src1->type != vec_dot_type) { + + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; @@ -171,7 +198,7 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { - const void* wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); if (!ggml_zendnn_sgemm(ctx, ne01, // m @@ -184,7 +211,7 @@ static void ggml_zendnn_compute_forward_mul_mat( static_cast(dst->data) + i12*nb2 + i13*nb3, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); } @@ -261,10 +288,15 @@ static void ggml_zendnn_compute_forward_mul_mat_id( const size_t nbw1 = row_size; const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; - const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0; + const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) ? ne13 * nbw3 : 0; + + // For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes), + // not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x. + const size_t f32_row_size = (size_t)ne10 * sizeof(float); + const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0) ? f32_row_size : row_size; // size for MoE gather/scatter buffers - const size_t wdata_cur_size = max_rows * row_size; + const size_t wdata_cur_size = max_rows * gather_row_size; const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01); // allocate single buffer for all needs @@ -279,7 +311,8 @@ static void ggml_zendnn_compute_forward_mul_mat_id( char * wdata_cur = work_data + src1_conv_size; char * dst_cur = wdata_cur + wdata_cur_size; - if (src1->type != vec_dot_type) { + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { GGML_ASSERT(src1->type == GGML_TYPE_F32); #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static) @@ -294,7 +327,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( } } - const void * wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void * wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; // process each expert with gather -> gemm -> scatter pattern for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) { @@ -315,9 +348,9 @@ static void ggml_zendnn_compute_forward_mul_mat_id( const int64_t i12 = row_mapping.i2; std::memcpy( - wdata_cur + ir1 * row_size, - (const char *) wdata + (i11 + i12*ne11) * row_size, - row_size + wdata_cur + ir1 * gather_row_size, + (const char *) wdata + (i11 + i12*ne11) * gather_row_size, + gather_row_size ); } @@ -333,7 +366,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( dst_cur, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) { GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); } @@ -577,6 +610,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const switch (weights->type) { case GGML_TYPE_F32: case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: return true; default: return false; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 191cf2fa106..476c3079795 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); - - // concat output and new_state into a single tensor - // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs - const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. + GGML_ASSERT(state->ne[0] == S_v * S_v * H); + GGML_ASSERT(state->ne[2] == n_seqs); + GGML_ASSERT(state->ne[3] == 1); + const int64_t K = state->ne[1]; + const int64_t state_rows = K * S_v * n_seqs; + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_GATED_DELTA_NET; diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 5a605ba344e..2c680ce9f5d 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -57ea0bc119d722d74594196cc5b494a34dd87be4 +0a37c2167fc5b81830a32d9b1691610180ed86d6