diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile
index 8e830d46251a..3d1abe612f59 100644
--- a/.devops/intel.Dockerfile
+++ b/.devops/intel.Dockerfile
@@ -33,10 +33,10 @@ RUN mkdir -p /app/full \
FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS base
-ARG IGC_VERSION=v2.30.1
-ARG IGC_VERSION_FULL=2_2.30.1+20950
-ARG COMPUTE_RUNTIME_VERSION=26.09.37435.1
-ARG COMPUTE_RUNTIME_VERSION_FULL=26.09.37435.1-0
+ARG IGC_VERSION=v2.32.7
+ARG IGC_VERSION_FULL=2_2.32.7+21184
+ARG COMPUTE_RUNTIME_VERSION=26.14.37833.4
+ARG COMPUTE_RUNTIME_VERSION_FULL=26.14.37833.4-0
ARG IGDGMM_VERSION=22.9.0
RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-core-${IGC_VERSION_FULL}_amd64.deb \
diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix
index 4e5fd00a5552..30355d2fc991 100644
--- a/.devops/nix/package.nix
+++ b/.devops/nix/package.nix
@@ -103,6 +103,7 @@ let
vulkan-headers
vulkan-loader
shaderc
+ spirv-headers
];
in
@@ -146,7 +147,6 @@ effectiveStdenv.mkDerivation (finalAttrs: {
ninja
pkg-config
git
- spirv-headers
]
++ optionals useCuda [
cudaPackages.cuda_nvcc
diff --git a/.gitignore b/.gitignore
index 417e591db6dd..11358c72855e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -110,6 +110,7 @@ uv.lock
# Nix
+flake.lock
/result
# Test binaries
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 1486171b8c56..e5dea18aeb2e 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1570,6 +1570,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "862f827721df956049dff5ca81a57f29e575280bc622e290d3bf4e35eca29015":
# ref: https://huggingface.co/codefuse-ai/F2LLM-v2-4B
res = "f2llmv2"
+ if chkhsh == "62f6fb0a6fd5098caeabb19b07a5c1099cafc8b9c40eab6ea89ece4ec02fbc57":
+ # ref: https://huggingface.co/sarvamai/sarvam-30b
+ res = "sarvam-moe"
if res is None:
logger.warning("\n")
@@ -11591,6 +11594,34 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")
+@ModelBase.register("SarvamMoEForCausalLM", "modeling_sarvam_moe.SarvamMoEForCausalLM")
+class SarvamMoEModel(BailingMoeV2Model):
+ model_arch = gguf.MODEL_ARCH.BAILINGMOE2
+ # Sarvam-MoE shares the BailingMoeV2 architecture; only differences:
+ # - full rotary (no partial_rotary_factor)
+ # - expert bias is zero-mean normalized at load time
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+ if (rope_dim := hparams.get("head_dim")) is None:
+ rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+ # Override the partial-rotary value written by BailingMoeV2 with the full rotary dim
+ self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+ @classmethod
+ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
+ name, gen = item
+ if name.endswith(".expert_bias"):
+ # Sarvam normalizes expert bias to zero mean
+ inner = gen
+
+ def gen():
+ t = inner()
+ return t - t.mean()
+ return super().filter_tensors((name, gen))
+
+
@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
class GroveMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROVEMOE
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index 6e6cd057909a..8d73b1f5546a 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -155,6 +155,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
{"name": "f2llmv2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/codefuse-ai/F2LLM-v2-4B", },
+ {"name": "sarvam-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sarvamai/sarvam-30b", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index 7ebb4ec02979..f66facc856a7 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -737,6 +737,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
+## Compile-time Flags
+
+Pass these via `CXXFLAGS` or add a one-off `#define` to enable a flag on the spot.
+
+| Name | Function |
+|-----------------|----------------------------------------------------------------------------------|
+| DEBUG_SYCL_POOL | Enable device memory pool logging on teardown. Useful for profiling allocations. |
+
## Design Rule
- Open to all contributors.
diff --git a/flake.lock b/flake.lock
deleted file mode 100644
index d114f4422a36..000000000000
--- a/flake.lock
+++ /dev/null
@@ -1,58 +0,0 @@
-{
- "nodes": {
- "flake-parts": {
- "inputs": {
- "nixpkgs-lib": "nixpkgs-lib"
- },
- "locked": {
- "lastModified": 1730504689,
- "narHash": "sha256-hgmguH29K2fvs9szpq2r3pz2/8cJd2LPS+b4tfNFCwE=",
- "owner": "hercules-ci",
- "repo": "flake-parts",
- "rev": "506278e768c2a08bec68eb62932193e341f55c90",
- "type": "github"
- },
- "original": {
- "owner": "hercules-ci",
- "repo": "flake-parts",
- "type": "github"
- }
- },
- "nixpkgs": {
- "locked": {
- "lastModified": 1732014248,
- "narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=",
- "owner": "NixOS",
- "repo": "nixpkgs",
- "rev": "23e89b7da85c3640bbc2173fe04f4bd114342367",
- "type": "github"
- },
- "original": {
- "owner": "NixOS",
- "ref": "nixos-unstable",
- "repo": "nixpkgs",
- "type": "github"
- }
- },
- "nixpkgs-lib": {
- "locked": {
- "lastModified": 1730504152,
- "narHash": "sha256-lXvH/vOfb4aGYyvFmZK/HlsNsr/0CVWlwYvo2rxJk3s=",
- "type": "tarball",
- "url": "https://github.com/NixOS/nixpkgs/archive/cc2f28000298e1269cea6612cd06ec9979dd5d7f.tar.gz"
- },
- "original": {
- "type": "tarball",
- "url": "https://github.com/NixOS/nixpkgs/archive/cc2f28000298e1269cea6612cd06ec9979dd5d7f.tar.gz"
- }
- },
- "root": {
- "inputs": {
- "flake-parts": "flake-parts",
- "nixpkgs": "nixpkgs"
- }
- }
- },
- "root": "root",
- "version": 7
-}
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 3f01e858de79..43e22c5e5ee5 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -61,6 +61,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true);
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
@@ -1561,6 +1566,10 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
+ if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) {
+ NO_DEVICE_CODE;
+ return;
+ }
#ifdef VOLTA_MMA_AVAILABLE
if (ncols1*ncols2 < 32) {
NO_DEVICE_CODE;
diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu
index d60634cc0e96..c8281497d148 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cu
+++ b/ggml/src/ggml-cuda/fattn-tile.cu
@@ -34,6 +34,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
} break;
+ case 192: {
+ GGML_ASSERT(V->ne[0] == 128);
+ ggml_cuda_flash_attn_ext_tile_case<192, 128>(ctx, dst);
+ } break;
case 256: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 585f2c228532..7b0a5e5cf497 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -62,6 +62,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 64, 64)
+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64)
@@ -124,6 +130,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 128, 3, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 3, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64)
+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256)
@@ -193,6 +205,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 256, 2, 128, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64)
+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
@@ -264,6 +282,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 8, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 6, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 128, 6, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 5, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 3, 64, 64)
+
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
@@ -1250,7 +1274,20 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
}
}
- if constexpr (DKQ <= 512 && DKQ != 320) {
+ if constexpr (DKQ == 192) {
+ // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn)
+ if (use_gqa_opt && gqa_ratio % 16 == 0) {
+ launch_fattn_tile_switch_ncols1(ctx, dst);
+ return;
+ }
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ launch_fattn_tile_switch_ncols1(ctx, dst);
+ return;
+ }
+ GGML_ABORT("flash-attn tile (192/128): expected GQA ratio multiple of 8");
+ }
+
+ if constexpr (DKQ <= 512 && DKQ != 320 && DKQ != 192) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1(ctx, dst);
return;
@@ -1303,6 +1340,7 @@ extern DECL_FATTN_TILE_CASE( 80, 80);
extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
+extern DECL_FATTN_TILE_CASE(192, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(320, 256);
extern DECL_FATTN_TILE_CASE(512, 512);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 8256591b21d6..e045b04f7271 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -139,6 +139,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(V->ne[0] == 128);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
break;
+ case 192: {
+ // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn)
+ GGML_ASSERT(V->ne[0] == 128);
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
+ GGML_ASSERT(use_gqa_opt);
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ if (gqa_ratio % 16 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 16>(ctx, dst);
+ } else {
+ GGML_ASSERT(gqa_ratio % 8 == 0);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 8>(ctx, dst);
+ }
+ } break;
case 256:
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
@@ -368,6 +384,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
}
break;
+ case 192:
+ if (V->ne[0] != 128 || !gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ if (gqa_ratio % 8 != 0) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
case 320:
if (V->ne[0] != 256 || !gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
@@ -425,7 +449,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
- const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
+ // 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path.
+ const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// If Turing tensor cores are available, use them:
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
@@ -454,7 +479,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
int gqa_ratio_eff = 1;
- const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+ const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8;
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
gqa_ratio_eff *= 2;
}
@@ -468,7 +493,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use the WMMA kernel if possible:
- if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
+ if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 512 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
@@ -501,7 +526,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use MFMA flash attention for CDNA (MI100+):
- if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
+ if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
// MMA vs tile crossover benchmarked on MI300X @ d32768:
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
index fb26abeb0dab..b2661b931624 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
@@ -2,4 +2,5 @@
#include "../fattn-mma-f16.cuh"
+DECL_FATTN_MMA_F16_CASE(192, 128, 1, 16);
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
index 22d383173f36..6ae77bec8958 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(192, 128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
index f011a208cd27..fd41e71b1421 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
@@ -2,4 +2,5 @@
#include "../fattn-mma-f16.cuh"
+DECL_FATTN_MMA_F16_CASE(192, 128, 2, 16);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
index 84b674cd05a6..9f4bef11a443 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(192, 128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu
index f5fd0e2369cf..cc41fa52f135 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu
@@ -2,4 +2,5 @@
#include "../fattn-mma-f16.cuh"
+DECL_FATTN_MMA_F16_CASE(192, 128, 4, 16);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
index 5906398db912..859bea5c5253 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(192, 128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
index 4bc60d62f910..c975ce6b9b74 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
@@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(192, 128, 8, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu
new file mode 100644
index 000000000000..b571cca0df23
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(192, 128);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index 5e9a1cb2eb30..af05a9eff710 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -3,7 +3,10 @@
from glob import glob
import os
-HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576]
+HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 192, 256, 320, 512, 576]
+
+# DKQ -> DV override for asymmetric head dims.
+HEAD_SIZES_V_OVERRIDE = {576: 512, 320: 256, 192: 128}
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
@@ -62,7 +65,7 @@ def get_short_name(long_quant_name):
os.remove(filename)
for head_size_kq in HEAD_SIZES_KQ:
- head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
+ head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq)
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
@@ -85,15 +88,17 @@ def get_short_name(long_quant_name):
if head_size_kq == 72:
continue
# Skip compilation of unused ncols2 values for niche head sizes:
+ if head_size_kq == 192 and ncols2 not in (8, 16): # MiMo-V2.5
+ continue
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
continue
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
continue
- if head_size_kq not in (320, 576) and ncols2 in (16, 32):
+ if head_size_kq not in (192, 320, 576) and ncols2 in (16, 32):
continue
- head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
+ head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq)
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
for type in TYPES_MMQ:
diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index 8ddd1915c835..d3c125dbc3d1 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -2261,6 +2261,58 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
return true;
}
+static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * q = op->src[0];
+ const struct ggml_tensor * k = op->src[1];
+ const struct ggml_tensor * v = op->src[2];
+ const struct ggml_tensor * g = op->src[3];
+ const struct ggml_tensor * beta = op->src[4];
+ const struct ggml_tensor * state = op->src[5];
+ const struct ggml_tensor * dst = op;
+
+ if (!q || !k || !v || !g || !beta || !state) {
+ return false;
+ }
+
+ if (q->type != GGML_TYPE_F32 || k->type != GGML_TYPE_F32 || v->type != GGML_TYPE_F32 ||
+ g->type != GGML_TYPE_F32 || beta->type != GGML_TYPE_F32 || state->type != GGML_TYPE_F32 ||
+ dst->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (!ggml_is_contiguous_rows(q) || !ggml_is_contiguous_rows(k) || !ggml_is_contiguous_rows(v) ||
+ !ggml_is_contiguous(g) || !ggml_is_contiguous(beta) || !ggml_is_contiguous(state) ||
+ !ggml_is_contiguous(dst)) {
+ return false;
+ }
+
+ const int64_t S_v = v->ne[0];
+ const int64_t H = v->ne[1];
+ const int64_t n_tokens = v->ne[2];
+ const int64_t n_seqs = v->ne[3];
+
+ if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) {
+ return false;
+ }
+ if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] <= 0 || k->ne[1] <= 0 ||
+ q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] <= 0 || k->ne[3] <= 0 ||
+ (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
+ return false;
+ }
+ if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
+ return false;
+ }
+ if (ggml_nelements(state) != S_v * S_v * H * n_seqs) {
+ return false;
+ }
+ if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) {
+ return false;
+ }
+
+ GGML_UNUSED(sess);
+ return true;
+}
+
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
@@ -2777,33 +2829,34 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
switch (t->op) {
- case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT;
- case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT;
- case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID;
- case GGML_OP_MUL: return HTP_OP_MUL;
- case GGML_OP_ADD: return HTP_OP_ADD;
- case GGML_OP_ADD_ID: return HTP_OP_ADD_ID;
- case GGML_OP_SUB: return HTP_OP_SUB;
- case GGML_OP_DIV: return HTP_OP_DIV;
- case GGML_OP_CPY: return HTP_OP_CPY;
- case GGML_OP_CONT: return HTP_OP_CPY;
- case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS;
- case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
- case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
- case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
- case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
- case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
- case GGML_OP_SCALE: return HTP_OP_SCALE;
- case GGML_OP_SQR: return HTP_OP_SQR;
- case GGML_OP_SQRT: return HTP_OP_SQRT;
- case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX;
- case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV;
- case GGML_OP_ROPE: return HTP_OP_ROPE;
- case GGML_OP_REPEAT: return HTP_OP_REPEAT;
- case GGML_OP_CUMSUM: return HTP_OP_CUMSUM;
- case GGML_OP_FILL: return HTP_OP_FILL;
- case GGML_OP_DIAG: return HTP_OP_DIAG;
- case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
+ case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT;
+ case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT;
+ case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID;
+ case GGML_OP_MUL: return HTP_OP_MUL;
+ case GGML_OP_ADD: return HTP_OP_ADD;
+ case GGML_OP_ADD_ID: return HTP_OP_ADD_ID;
+ case GGML_OP_SUB: return HTP_OP_SUB;
+ case GGML_OP_DIV: return HTP_OP_DIV;
+ case GGML_OP_CPY: return HTP_OP_CPY;
+ case GGML_OP_CONT: return HTP_OP_CPY;
+ case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS;
+ case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS;
+ case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS;
+ case GGML_OP_ARGSORT: return HTP_OP_ARGSORT;
+ case GGML_OP_L2_NORM: return HTP_OP_L2_NORM;
+ case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM;
+ case GGML_OP_SCALE: return HTP_OP_SCALE;
+ case GGML_OP_SQR: return HTP_OP_SQR;
+ case GGML_OP_SQRT: return HTP_OP_SQRT;
+ case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX;
+ case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV;
+ case GGML_OP_GATED_DELTA_NET: return HTP_OP_GATED_DELTA_NET;
+ case GGML_OP_ROPE: return HTP_OP_ROPE;
+ case GGML_OP_REPEAT: return HTP_OP_REPEAT;
+ case GGML_OP_CUMSUM: return HTP_OP_CUMSUM;
+ case GGML_OP_FILL: return HTP_OP_FILL;
+ case GGML_OP_DIAG: return HTP_OP_DIAG;
+ case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(t)) {
case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU;
@@ -3341,6 +3394,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_ssm_conv(sess, op);
break;
+ case GGML_OP_GATED_DELTA_NET:
+ supp = ggml_hexagon_supported_gated_delta_net(sess, op);
+ break;
+
case GGML_OP_CUMSUM:
supp = ggml_hexagon_supported_cumsum(sess, op);
break;
diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
index 7c9e4cda5f1a..bcadac11f951 100644
--- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt
+++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
@@ -37,6 +37,7 @@ add_library(${HTP_LIB} SHARED
fill-ops.c
diag-ops.c
solve-tri-ops.c
+ gated-delta-net-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c
new file mode 100644
index 000000000000..2e84badc9b7f
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c
@@ -0,0 +1,955 @@
+#include
+#include
+#include
+
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+#define HTP_GDN_MAX_SV 128
+
+struct htp_gdn_context {
+ struct htp_ops_context * octx;
+ uint32_t rows_per_thread;
+ size_t state_bytes;
+ bool use_vtcm;
+ uint8_t * vtcm_state_base;
+ size_t vtcm_state_per_thread;
+};
+
+static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul,
+ const float * restrict dot, uint32_t n) {
+ HVX_Vector acc = Q6_V_vzero();
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vd = hvx_vmemu(dst + i * epv);
+ HVX_Vector vm = hvx_vmem(mul + i * epv);
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+ HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm);
+ hvx_vmemu(dst + i * epv) = out;
+ acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vd = hvx_vmemu(dst + off);
+ HVX_Vector vm = hvx_vmem(mul + off);
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm);
+ hvx_vec_store_u(dst + off, tail * sizeof(float), out);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
+ acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
+ }
+
+ return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
+}
+
+static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul,
+ const float * restrict dot, uint32_t n) {
+ HVX_Vector acc = Q6_V_vzero();
+ const HVX_Vector vmul = hvx_vec_splat_f32(mul);
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vd = hvx_vmemu(dst + i * epv);
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+ HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
+ hvx_vmemu(dst + i * epv) = out;
+ acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vd = hvx_vmemu(dst + off);
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul);
+ hvx_vec_store_u(dst + off, tail * sizeof(float), out);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
+ acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
+ }
+
+ return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
+}
+
+static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src,
+ float scale, const float * restrict dot, uint32_t n) {
+ HVX_Vector acc = Q6_V_vzero();
+ const HVX_Vector vscale = hvx_vec_splat_f32(scale);
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vd = hvx_vmemu(dst + i * epv);
+ HVX_Vector vs = hvx_vmem(src + i * epv);
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+ HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
+ hvx_vmemu(dst + i * epv) = out;
+ acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vd = hvx_vmemu(dst + off);
+ HVX_Vector vs = hvx_vmem(src + off);
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
+ hvx_vec_store_u(dst + off, tail * sizeof(float), out);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot);
+ acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
+ }
+
+ return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
+}
+
+static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1,
+ float * restrict dst2, float * restrict dst3, const float * restrict mul,
+ const float * restrict dot, uint32_t n, float * restrict sums) {
+ HVX_Vector acc0 = Q6_V_vzero();
+ HVX_Vector acc1 = Q6_V_vzero();
+ HVX_Vector acc2 = Q6_V_vzero();
+ HVX_Vector acc3 = Q6_V_vzero();
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vm = hvx_vmem(mul + i * epv);
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm);
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm);
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm);
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm);
+
+ hvx_vmemu(dst0 + i * epv) = out0;
+ hvx_vmemu(dst1 + i * epv) = out1;
+ hvx_vmemu(dst2 + i * epv) = out2;
+ hvx_vmemu(dst3 + i * epv) = out3;
+
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vm = hvx_vmem(mul + off);
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector zero = Q6_V_vzero();
+
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm);
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);
+
+ hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
+ hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
+ hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
+ hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
+
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
+ }
+
+ HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
+ hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
+}
+
+static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restrict dst1,
+ float * restrict dst2, float * restrict dst3, float mul,
+ const float * restrict dot, uint32_t n, float * restrict sums) {
+ HVX_Vector acc0 = Q6_V_vzero();
+ HVX_Vector acc1 = Q6_V_vzero();
+ HVX_Vector acc2 = Q6_V_vzero();
+ HVX_Vector acc3 = Q6_V_vzero();
+ const HVX_Vector vmul = hvx_vec_splat_f32(mul);
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul);
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul);
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul);
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul);
+
+ hvx_vmemu(dst0 + i * epv) = out0;
+ hvx_vmemu(dst1 + i * epv) = out1;
+ hvx_vmemu(dst2 + i * epv) = out2;
+ hvx_vmemu(dst3 + i * epv) = out3;
+
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector zero = Q6_V_vzero();
+
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul);
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);
+
+ hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
+ hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
+ hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
+ hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
+
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
+ }
+
+ HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
+ hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
+}
+
+static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restrict dst1,
+ float * restrict dst2, float * restrict dst3, const float * restrict src,
+ const float * restrict scale, const float * restrict dot, uint32_t n,
+ float * restrict sums) {
+ HVX_Vector acc0 = Q6_V_vzero();
+ HVX_Vector acc1 = Q6_V_vzero();
+ HVX_Vector acc2 = Q6_V_vzero();
+ HVX_Vector acc3 = Q6_V_vzero();
+ const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]);
+ const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]);
+ const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]);
+ const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]);
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vs = hvx_vmem(src + i * epv);
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+
+ HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0));
+ HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1));
+ HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2));
+ HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3));
+
+ hvx_vmemu(dst0 + i * epv) = out0;
+ hvx_vmemu(dst1 + i * epv) = out1;
+ hvx_vmemu(dst2 + i * epv) = out2;
+ hvx_vmemu(dst3 + i * epv) = out3;
+
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vs = hvx_vmem(src + off);
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector zero = Q6_V_vzero();
+
+ HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
+ HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1));
+ HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
+ HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));
+
+ hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
+ hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
+ hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
+ hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
+
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
+ }
+
+ HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
+ hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
+}
+
+static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1,
+ float * restrict dst2, float * restrict dst3, float * restrict dst4,
+ float * restrict dst5, float * restrict dst6, float * restrict dst7,
+ const float * restrict mul, const float * restrict dot, uint32_t n,
+ float * restrict sums) {
+ HVX_Vector acc0 = Q6_V_vzero();
+ HVX_Vector acc1 = Q6_V_vzero();
+ HVX_Vector acc2 = Q6_V_vzero();
+ HVX_Vector acc3 = Q6_V_vzero();
+ HVX_Vector acc4 = Q6_V_vzero();
+ HVX_Vector acc5 = Q6_V_vzero();
+ HVX_Vector acc6 = Q6_V_vzero();
+ HVX_Vector acc7 = Q6_V_vzero();
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vm = hvx_vmem(mul + i * epv);
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm);
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm);
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm);
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm);
+ HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vm);
+ HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vm);
+ HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vm);
+ HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vm);
+
+ hvx_vmemu(dst0 + i * epv) = out0;
+ hvx_vmemu(dst1 + i * epv) = out1;
+ hvx_vmemu(dst2 + i * epv) = out2;
+ hvx_vmemu(dst3 + i * epv) = out3;
+ hvx_vmemu(dst4 + i * epv) = out4;
+ hvx_vmemu(dst5 + i * epv) = out5;
+ hvx_vmemu(dst6 + i * epv) = out6;
+ hvx_vmemu(dst7 + i * epv) = out7;
+
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
+ acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
+ acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
+ acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
+ acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vm = hvx_vmem(mul + off);
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector zero = Q6_V_vzero();
+
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm);
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);
+ HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vm);
+ HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vm);
+ HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm);
+ HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm);
+
+ hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
+ hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
+ hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
+ hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
+ hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
+ hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
+ hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
+ hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);
+
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
+ acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
+ acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
+ acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
+ acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
+ }
+
+ HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
+ HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
+ hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
+ hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
+}
+
+static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restrict dst1,
+ float * restrict dst2, float * restrict dst3, float * restrict dst4,
+ float * restrict dst5, float * restrict dst6, float * restrict dst7,
+ float mul, const float * restrict dot, uint32_t n, float * restrict sums) {
+ HVX_Vector acc0 = Q6_V_vzero();
+ HVX_Vector acc1 = Q6_V_vzero();
+ HVX_Vector acc2 = Q6_V_vzero();
+ HVX_Vector acc3 = Q6_V_vzero();
+ HVX_Vector acc4 = Q6_V_vzero();
+ HVX_Vector acc5 = Q6_V_vzero();
+ HVX_Vector acc6 = Q6_V_vzero();
+ HVX_Vector acc7 = Q6_V_vzero();
+ const HVX_Vector vmul = hvx_vec_splat_f32(mul);
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul);
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul);
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul);
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul);
+ HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vmul);
+ HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vmul);
+ HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vmul);
+ HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vmul);
+
+ hvx_vmemu(dst0 + i * epv) = out0;
+ hvx_vmemu(dst1 + i * epv) = out1;
+ hvx_vmemu(dst2 + i * epv) = out2;
+ hvx_vmemu(dst3 + i * epv) = out3;
+ hvx_vmemu(dst4 + i * epv) = out4;
+ hvx_vmemu(dst5 + i * epv) = out5;
+ hvx_vmemu(dst6 + i * epv) = out6;
+ hvx_vmemu(dst7 + i * epv) = out7;
+
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
+ acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
+ acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
+ acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
+ acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector zero = Q6_V_vzero();
+
+ HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
+ HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul);
+ HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
+ HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);
+ HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vmul);
+ HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vmul);
+ HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul);
+ HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul);
+
+ hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
+ hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
+ hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
+ hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
+ hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
+ hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
+ hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
+ hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);
+
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
+ acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
+ acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
+ acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
+ acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
+ }
+
+ HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
+ HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
+ hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
+ hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
+}
+
+static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restrict dst1,
+ float * restrict dst2, float * restrict dst3, float * restrict dst4,
+ float * restrict dst5, float * restrict dst6, float * restrict dst7,
+ const float * restrict src, const float * restrict scale,
+ const float * restrict dot, uint32_t n, float * restrict sums) {
+ HVX_Vector acc0 = Q6_V_vzero();
+ HVX_Vector acc1 = Q6_V_vzero();
+ HVX_Vector acc2 = Q6_V_vzero();
+ HVX_Vector acc3 = Q6_V_vzero();
+ HVX_Vector acc4 = Q6_V_vzero();
+ HVX_Vector acc5 = Q6_V_vzero();
+ HVX_Vector acc6 = Q6_V_vzero();
+ HVX_Vector acc7 = Q6_V_vzero();
+ const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]);
+ const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]);
+ const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]);
+ const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]);
+ const HVX_Vector scale4 = hvx_vec_splat_f32(scale[4]);
+ const HVX_Vector scale5 = hvx_vec_splat_f32(scale[5]);
+ const HVX_Vector scale6 = hvx_vec_splat_f32(scale[6]);
+ const HVX_Vector scale7 = hvx_vec_splat_f32(scale[7]);
+
+ const uint32_t epv = 128 / sizeof(float);
+ const uint32_t nvec = n / epv;
+ const uint32_t tail = n % epv;
+ for (uint32_t i = 0; i < nvec; ++i) {
+ HVX_Vector vs = hvx_vmem(src + i * epv);
+ HVX_Vector vdot = hvx_vmem(dot + i * epv);
+
+ HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0));
+ HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1));
+ HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2));
+ HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3));
+ HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + i * epv), hvx_vec_mul_f32_f32(vs, scale4));
+ HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + i * epv), hvx_vec_mul_f32_f32(vs, scale5));
+ HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + i * epv), hvx_vec_mul_f32_f32(vs, scale6));
+ HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + i * epv), hvx_vec_mul_f32_f32(vs, scale7));
+
+ hvx_vmemu(dst0 + i * epv) = out0;
+ hvx_vmemu(dst1 + i * epv) = out1;
+ hvx_vmemu(dst2 + i * epv) = out2;
+ hvx_vmemu(dst3 + i * epv) = out3;
+ hvx_vmemu(dst4 + i * epv) = out4;
+ hvx_vmemu(dst5 + i * epv) = out5;
+ hvx_vmemu(dst6 + i * epv) = out6;
+ hvx_vmemu(dst7 + i * epv) = out7;
+
+ acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
+ acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
+ acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
+ acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
+ acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
+ acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
+ acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
+ acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
+ }
+
+ if (tail) {
+ const uint32_t off = nvec * epv;
+ HVX_Vector vs = hvx_vmem(src + off);
+ HVX_Vector vdot = hvx_vmem(dot + off);
+ HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
+ HVX_Vector zero = Q6_V_vzero();
+
+ HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
+ HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1));
+ HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
+ HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));
+ HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + off), hvx_vec_mul_f32_f32(vs, scale4));
+ HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + off), hvx_vec_mul_f32_f32(vs, scale5));
+ HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6));
+ HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7));
+
+ hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
+ hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
+ hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
+ hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
+ hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
+ hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
+ hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
+ hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);
+
+ acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
+ acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
+ acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
+ acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
+ acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
+ acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
+ acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
+ acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
+ }
+
+ HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
+ HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
+ hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
+ hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
+}
+
+static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
+ struct htp_ops_context * octx = gctx->octx;
+
+ const struct htp_tensor * q = octx->src[0];
+ const struct htp_tensor * k = octx->src[1];
+ const struct htp_tensor * v = octx->src[2];
+ const struct htp_tensor * g = octx->src[3];
+ const struct htp_tensor * beta = octx->src[4];
+ const struct htp_tensor * state = octx->src[5];
+ const struct htp_tensor * dst = octx->dst;
+
+ const uint32_t S_v = v->ne[0];
+ const uint32_t H = v->ne[1];
+ const uint32_t n_tokens = v->ne[2];
+ const uint32_t n_seqs = v->ne[3];
+
+ const uint32_t total_rows = H * n_seqs;
+ if (ith >= total_rows) {
+ return;
+ }
+
+ const uint32_t rq3 = n_seqs / q->ne[3];
+ const uint32_t rk3 = n_seqs / k->ne[3];
+ const float scale = 1.0f / sqrtf((float) S_v);
+
+ float * dst_base = (float *) (uintptr_t) dst->data;
+ float * state_out_base = dst_base + (uint64_t) S_v * H * n_tokens * n_seqs;
+ const float * state_in_base = (const float *) (uintptr_t) state->data;
+
+ const bool kda = (g->ne[0] == S_v);
+ float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
+ float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
+ float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
+ float local_sums[4] __attribute__((aligned(128)));
+
+ for (uint32_t ir = ith; ir < total_rows; ir += nth) {
+ const uint32_t iv1 = ir % H;
+ const uint32_t iv3 = ir / H;
+
+ const uint32_t iq1 = iv1 % q->ne[1];
+ const uint32_t ik1 = iv1 % k->ne[1];
+ const uint32_t iq3 = iv3 / rq3;
+ const uint32_t ik3 = iv3 / rk3;
+
+ float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
+ const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
+
+ memcpy(s_out, s_in, gctx->state_bytes);
+ float * s_work = s_out;
+
+ float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v;
+
+ for (uint32_t t = 0; t < n_tokens; ++t) {
+ const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data +
+ (uint64_t) iq3 * q->nb[3] + (uint64_t) t * q->nb[2] + (uint64_t) iq1 * q->nb[1]);
+ const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data +
+ (uint64_t) ik3 * k->nb[3] + (uint64_t) t * k->nb[2] + (uint64_t) ik1 * k->nb[1]);
+ const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data +
+ (uint64_t) iv3 * v->nb[3] + (uint64_t) t * v->nb[2] + (uint64_t) iv1 * v->nb[1]);
+ const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data +
+ (uint64_t) iv3 * g->nb[3] + (uint64_t) t * g->nb[2] + (uint64_t) iv1 * g->nb[1]);
+ const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
+ (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]);
+
+ memcpy(local_q, q_t, (size_t) S_v * sizeof(float));
+ memcpy(local_k, k_t, (size_t) S_v * sizeof(float));
+
+ if (kda) {
+ hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);
+
+ uint32_t j = 0;
+ for (; j + 4 <= S_v; j += 4) {
+ float * row0 = s_work + (uint64_t) (j + 0) * S_v;
+ float * row1 = s_work + (uint64_t) (j + 1) * S_v;
+ float * row2 = s_work + (uint64_t) (j + 2) * S_v;
+ float * row3 = s_work + (uint64_t) (j + 3) * S_v;
+ gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
+ float local_delta_b[4] __attribute__((aligned(128)));
+ for (uint32_t r = 0; r < 4; ++r) {
+ local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
+ }
+ gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
+ for (uint32_t r = 0; r < 4; ++r) {
+ attn_data[j + r] = local_sums[r] * scale;
+ }
+ }
+ for (; j < S_v; ++j) {
+ float * row = s_work + (uint64_t) j * S_v;
+ const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
+ const float dj = (v_t[j] - sum) * beta_val;
+ attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
+ }
+ } else {
+ const float gate = expf(g_t[0]);
+ uint32_t j = 0;
+ for (; j + 4 <= S_v; j += 4) {
+ float * row0 = s_work + (uint64_t) (j + 0) * S_v;
+ float * row1 = s_work + (uint64_t) (j + 1) * S_v;
+ float * row2 = s_work + (uint64_t) (j + 2) * S_v;
+ float * row3 = s_work + (uint64_t) (j + 3) * S_v;
+ gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
+ float local_delta_b[4] __attribute__((aligned(128)));
+ for (uint32_t r = 0; r < 4; ++r) {
+ local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
+ }
+ gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
+ for (uint32_t r = 0; r < 4; ++r) {
+ attn_data[j + r] = local_sums[r] * scale;
+ }
+ }
+ for (; j < S_v; ++j) {
+ float * row = s_work + (uint64_t) j * S_v;
+ const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
+ const float dj = (v_t[j] - sum) * beta_val;
+ attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
+ }
+ }
+
+ attn_data += (uint64_t) S_v * H;
+ }
+ }
+}
+
+static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
+ struct htp_ops_context * octx = gctx->octx;
+
+ const struct htp_tensor * q = octx->src[0];
+ const struct htp_tensor * k = octx->src[1];
+ const struct htp_tensor * v = octx->src[2];
+ const struct htp_tensor * g = octx->src[3];
+ const struct htp_tensor * beta = octx->src[4];
+ const struct htp_tensor * state = octx->src[5];
+ const struct htp_tensor * dst = octx->dst;
+
+ const uint32_t S_v = v->ne[0];
+ const uint32_t H = v->ne[1];
+ const uint32_t n_seqs = v->ne[3];
+
+ const uint32_t total_rows = H * n_seqs;
+ if (ith >= total_rows) {
+ return;
+ }
+
+ const uint32_t rq3 = n_seqs / q->ne[3];
+ const uint32_t rk3 = n_seqs / k->ne[3];
+ const float scale = 1.0f / sqrtf((float) S_v);
+
+ float * dst_base = (float *) (uintptr_t) dst->data;
+ float * state_out_base = dst_base + (uint64_t) S_v * H * n_seqs;
+ const float * state_in_base = (const float *) (uintptr_t) state->data;
+
+ const bool kda = (g->ne[0] == S_v);
+ float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
+ float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
+ float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
+ float local_sums[8] __attribute__((aligned(128)));
+
+ dma_queue * dma = octx->ctx->dma[ith];
+
+ uint8_t * spad = NULL;
+ if (gctx->use_vtcm) {
+ spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith;
+ }
+
+ for (uint32_t ir = ith; ir < total_rows; ir += nth) {
+ const uint32_t iv1 = ir % H;
+ const uint32_t iv3 = ir / H;
+
+ const uint32_t iq1 = iv1 % q->ne[1];
+ const uint32_t ik1 = iv1 % k->ne[1];
+ const uint32_t iq3 = iv3 / rq3;
+ const uint32_t ik3 = iv3 / rk3;
+
+ float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
+ const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
+ float * s_work;
+
+ if (spad) {
+ dma_queue_push(dma, dma_make_ptr(spad, s_in),
+ S_v * sizeof(float), S_v * sizeof(float),
+ S_v * sizeof(float), S_v);
+ dma_queue_pop(dma);
+ s_work = (float *) spad;
+ } else {
+ s_work = s_out;
+ memcpy(s_work, s_in, gctx->state_bytes);
+ }
+
+ float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v;
+
+ const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data +
+ (uint64_t) iq3 * q->nb[3] + (uint64_t) iq1 * q->nb[1]);
+ const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data +
+ (uint64_t) ik3 * k->nb[3] + (uint64_t) ik1 * k->nb[1]);
+ const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data +
+ (uint64_t) iv3 * v->nb[3] + (uint64_t) iv1 * v->nb[1]);
+ const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data +
+ (uint64_t) iv3 * g->nb[3] + (uint64_t) iv1 * g->nb[1]);
+ const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
+ (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]);
+
+ memcpy(local_q, q_t, (size_t) S_v * sizeof(float));
+ memcpy(local_k, k_t, (size_t) S_v * sizeof(float));
+
+ if (kda) {
+ hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);
+
+ uint32_t j = 0;
+ for (; j + 8 <= S_v; j += 8) {
+ float * row0 = s_work + (uint64_t) (j + 0) * S_v;
+ float * row1 = s_work + (uint64_t) (j + 1) * S_v;
+ float * row2 = s_work + (uint64_t) (j + 2) * S_v;
+ float * row3 = s_work + (uint64_t) (j + 3) * S_v;
+ float * row4 = s_work + (uint64_t) (j + 4) * S_v;
+ float * row5 = s_work + (uint64_t) (j + 5) * S_v;
+ float * row6 = s_work + (uint64_t) (j + 6) * S_v;
+ float * row7 = s_work + (uint64_t) (j + 7) * S_v;
+ gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
+ local_gate, local_k, S_v, local_sums);
+ float local_delta_b[8] __attribute__((aligned(128)));
+ for (uint32_t r = 0; r < 8; ++r) {
+ local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
+ }
+ gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
+ local_k, local_delta_b, local_q, S_v, local_sums);
+ for (uint32_t r = 0; r < 8; ++r) {
+ attn_data[j + r] = local_sums[r] * scale;
+ }
+ }
+ for (; j + 4 <= S_v; j += 4) {
+ float * row0 = s_work + (uint64_t) (j + 0) * S_v;
+ float * row1 = s_work + (uint64_t) (j + 1) * S_v;
+ float * row2 = s_work + (uint64_t) (j + 2) * S_v;
+ float * row3 = s_work + (uint64_t) (j + 3) * S_v;
+ gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
+ float local_delta_b[4] __attribute__((aligned(128)));
+ for (uint32_t r = 0; r < 4; ++r) {
+ local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
+ }
+ gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
+ for (uint32_t r = 0; r < 4; ++r) {
+ attn_data[j + r] = local_sums[r] * scale;
+ }
+ }
+ for (; j < S_v; ++j) {
+ float * row = s_work + (uint64_t) j * S_v;
+ const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
+ const float dj = (v_t[j] - sum) * beta_val;
+ attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
+ }
+ } else {
+ const float gate = expf(g_t[0]);
+ uint32_t j = 0;
+ for (; j + 8 <= S_v; j += 8) {
+ float * row0 = s_work + (uint64_t) (j + 0) * S_v;
+ float * row1 = s_work + (uint64_t) (j + 1) * S_v;
+ float * row2 = s_work + (uint64_t) (j + 2) * S_v;
+ float * row3 = s_work + (uint64_t) (j + 3) * S_v;
+ float * row4 = s_work + (uint64_t) (j + 4) * S_v;
+ float * row5 = s_work + (uint64_t) (j + 5) * S_v;
+ float * row6 = s_work + (uint64_t) (j + 6) * S_v;
+ float * row7 = s_work + (uint64_t) (j + 7) * S_v;
+ gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
+ gate, local_k, S_v, local_sums);
+ float local_delta_b[8] __attribute__((aligned(128)));
+ for (uint32_t r = 0; r < 8; ++r) {
+ local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
+ }
+ gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
+ local_k, local_delta_b, local_q, S_v, local_sums);
+ for (uint32_t r = 0; r < 8; ++r) {
+ attn_data[j + r] = local_sums[r] * scale;
+ }
+ }
+ for (; j + 4 <= S_v; j += 4) {
+ float * row0 = s_work + (uint64_t) (j + 0) * S_v;
+ float * row1 = s_work + (uint64_t) (j + 1) * S_v;
+ float * row2 = s_work + (uint64_t) (j + 2) * S_v;
+ float * row3 = s_work + (uint64_t) (j + 3) * S_v;
+ gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
+ float local_delta_b[4] __attribute__((aligned(128)));
+ for (uint32_t r = 0; r < 4; ++r) {
+ local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
+ }
+ gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
+ for (uint32_t r = 0; r < 4; ++r) {
+ attn_data[j + r] = local_sums[r] * scale;
+ }
+ }
+ for (; j < S_v; ++j) {
+ float * row = s_work + (uint64_t) j * S_v;
+ const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
+ const float dj = (v_t[j] - sum) * beta_val;
+ attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
+ }
+ }
+
+ if (spad) {
+ dma_queue_push(dma, dma_make_ptr(s_out, spad),
+ S_v * sizeof(float), S_v * sizeof(float),
+ S_v * sizeof(float), S_v);
+ dma_queue_pop(dma);
+ }
+ }
+}
+
+int op_gated_delta_net(struct htp_ops_context * octx) {
+ const struct htp_tensor * q = octx->src[0];
+ const struct htp_tensor * k = octx->src[1];
+ const struct htp_tensor * v = octx->src[2];
+ const struct htp_tensor * g = octx->src[3];
+ const struct htp_tensor * beta = octx->src[4];
+ const struct htp_tensor * state = octx->src[5];
+ const struct htp_tensor * dst = octx->dst;
+
+ if (!q || !k || !v || !g || !beta || !state || !dst) {
+ return HTP_STATUS_INVAL_PARAMS;
+ }
+
+ if (q->type != HTP_TYPE_F32 || k->type != HTP_TYPE_F32 || v->type != HTP_TYPE_F32 ||
+ g->type != HTP_TYPE_F32 || beta->type != HTP_TYPE_F32 || state->type != HTP_TYPE_F32 ||
+ dst->type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ const uint32_t S_v = v->ne[0];
+ const uint32_t H = v->ne[1];
+ const uint32_t n_tokens = v->ne[2];
+ const uint32_t n_seqs = v->ne[3];
+
+ if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+ if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+ if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] == 0 || k->ne[1] == 0 ||
+ q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] == 0 || k->ne[3] == 0 ||
+ (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+ if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+ if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+ return HTP_STATUS_OK;
+ }
+
+ struct htp_gdn_context gctx;
+ gctx.octx = octx;
+ gctx.rows_per_thread = (H * n_seqs + octx->n_threads - 1) / octx->n_threads;
+ gctx.state_bytes = (size_t) S_v * S_v * sizeof(float);
+
+ size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
+ state_aligned = (state_aligned + 127) & ~(size_t)127;
+
+ gctx.use_vtcm = false;
+ gctx.vtcm_state_base = NULL;
+ gctx.vtcm_state_per_thread = 0;
+
+ if (n_tokens == 1 && octx->ctx->vtcm_base) {
+ size_t vtcm_total = state_aligned * octx->n_threads;
+ if (octx->ctx->vtcm_size >= vtcm_total) {
+ gctx.use_vtcm = true;
+ gctx.vtcm_state_base = octx->ctx->vtcm_base;
+ gctx.vtcm_state_per_thread = state_aligned;
+ }
+ }
+
+ if (n_tokens == 1) {
+ worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads);
+ } else {
+ worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_pp_thread, &gctx, octx->n_threads);
+ }
+
+ return HTP_STATUS_OK;
+}
diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h
index e9c563ca887b..92f02eac6e31 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ctx.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h
@@ -106,5 +106,6 @@ int op_cumsum(struct htp_ops_context * octx);
int op_fill(struct htp_ops_context * octx);
int op_diag(struct htp_ops_context * octx);
int op_solve_tri(struct htp_ops_context * octx);
+int op_gated_delta_net(struct htp_ops_context * octx);
#endif /* HTP_CTX_H */
diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h
index ef96ad38278c..6203e3848b94 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ops.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -84,6 +84,7 @@ enum htp_op_code {
HTP_OP_DIAG,
HTP_OP_SOLVE_TRI,
HTP_OP_L2_NORM,
+ HTP_OP_GATED_DELTA_NET,
HTP_OP_INVALID
};
diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c
index e18f1a0e61eb..fa1e0698f4aa 100644
--- a/ggml/src/ggml-hexagon/htp/main.c
+++ b/ggml/src/ggml-hexagon/htp/main.c
@@ -594,6 +594,9 @@ static int execute_op(struct htp_ops_context * octx) {
case HTP_OP_SOLVE_TRI:
return op_solve_tri(octx);
+ case HTP_OP_GATED_DELTA_NET:
+ return op_gated_delta_net(octx);
+
case HTP_OP_INVALID:
break;
diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt
index 8e589fa238dc..8f44c6ed080a 100644
--- a/ggml/src/ggml-sycl/CMakeLists.txt
+++ b/ggml/src/ggml-sycl/CMakeLists.txt
@@ -135,7 +135,11 @@ endif()
if (GGML_SYCL_TARGET STREQUAL "INTEL")
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
- target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required)
+ if (NOT GGML_SYCL_DEVICE_ARCH)
+ target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required)
+ else()
+ message(STATUS "Skipping -ze-intel-greater-than-4GB-buffer-required for spir64_gen AOT")
+ endif()
# Link against Intel oneMKL
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
@@ -160,7 +164,15 @@ if (GGML_SYCL_HOST_MEM_FALLBACK)
endif()
if (GGML_SYCL_DEVICE_ARCH)
- target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
- target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
+ message(STATUS "GGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} (AOT via spir64_gen)")
+ target_compile_options(
+ ggml-sycl PRIVATE
+ -fsycl-targets=spir64_gen
+ "SHELL:-Xsycl-target-backend=spir64_gen \"-device ${GGML_SYCL_DEVICE_ARCH}\""
+ )
+ target_link_options(
+ ggml-sycl PRIVATE
+ -fsycl-targets=spir64_gen
+ "SHELL:-Xsycl-target-backend=spir64_gen \"-device ${GGML_SYCL_DEVICE_ARCH}\""
+ )
endif()
-
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 5abf22906518..eec36e8db9a2 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -25,6 +25,7 @@
#include "presets.hpp"
#include "type.hpp"
#include "sycl_hw.hpp"
+#include "fattn-buffers.hpp"
namespace syclexp = sycl::ext::oneapi::experimental;
@@ -404,12 +405,16 @@ struct ggml_backend_sycl_context {
std::unique_ptr pools[GGML_SYCL_MAX_DEVICES];
std::unordered_map>> scratchpad_map;
+ std::unique_ptr fattn_bufs[GGML_SYCL_MAX_DEVICES];
+
std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES];
static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device);
static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device);
+ static std::unique_ptr new_fattn_kv_buffers(queue_ptr qptr, int device);
+
ggml_sycl_pool & pool(int device) {
if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(stream(device,0), device);
@@ -421,6 +426,17 @@ struct ggml_backend_sycl_context {
return pool(device);
}
+ ggml_sycl_fattn_kv_buffers & fattn_buffers(int device) {
+ if (fattn_bufs[device] == nullptr) {
+ fattn_bufs[device] = new_fattn_kv_buffers(stream(device, 0), device);
+ }
+ return *fattn_bufs[device];
+ }
+
+ ggml_sycl_fattn_kv_buffers & fattn_buffers() {
+ return fattn_buffers(device);
+ }
+
#ifdef GGML_SYCL_GRAPH
std::unique_ptr> exec_graph = nullptr;
#endif
diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp
index 67b9c06f3e44..576f19d79ae9 100644
--- a/ggml/src/ggml-sycl/convert.cpp
+++ b/ggml/src/ggml-sycl/convert.cpp
@@ -252,6 +252,23 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
#endif
}
+template
+static void dequantize_row_q5_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+ const int64_t nb = k / QK_K;
+
+ dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+ stream->submit([&](sycl::handler & cgh) {
+ sycl::local_accessor scale_local_acc(sycl::range<1>(K_SCALE_SIZE), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q5_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
+ });
+ });
+}
+
template
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
@@ -643,7 +660,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return dequantize_row_q4_K_sycl;
}
case GGML_TYPE_Q5_K:
- return dequantize_row_q5_K_sycl;
+ if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+ return dequantize_row_q5_K_sycl_reorder;
+ } else {
+ return dequantize_row_q5_K_sycl;
+ }
case GGML_TYPE_Q6_K:
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q6_K_sycl_reorder;
@@ -718,7 +739,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return dequantize_row_q4_K_sycl;
}
case GGML_TYPE_Q5_K:
- return dequantize_row_q5_K_sycl;
+ if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+ return dequantize_row_q5_K_sycl_reorder;
+ } else {
+ return dequantize_row_q5_K_sycl;
+ }
case GGML_TYPE_Q6_K:
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q6_K_sycl_reorder;
diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp
index 19fa88680d69..2324bfacd220 100644
--- a/ggml/src/ggml-sycl/dequantize.hpp
+++ b/ggml/src/ggml-sycl/dequantize.hpp
@@ -537,6 +537,63 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
#endif
}
+template
+static void dequantize_block_q5_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ uint8_t * scales_local, const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
+ const int64_t ib = item_ct1.get_group(2);
+
+#if QK_K == 256
+ // assume 64 threads
+ const int64_t tid = item_ct1.get_local_id(2);
+ const int64_t il = tid / 16; // 0...3
+ const int64_t ir = tid % 16; // 0...15
+ const int64_t is = 2 * il;
+
+ dst_t * y = yy + ib * QK_K + 64 * il + 2 * ir;
+
+ const uint8_t * base = static_cast(vx);
+
+ // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales (K_SCALE_SIZE per block)] [dm (half2 per block)]
+ const size_t qs_offset = ib * (QK_K / 2);
+ const size_t qh_offset = n_blocks * (QK_K / 2) + ib * (QK_K / 8);
+ const size_t scales_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + ib * K_SCALE_SIZE;
+ const size_t dm_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + n_blocks * K_SCALE_SIZE + ib * sizeof(ggml_half2);
+
+ const uint8_t * qs_ptr = base + qs_offset;
+ const uint8_t * qh_ptr = base + qh_offset;
+ const uint8_t * scales_ptr = base + scales_offset;
+ const ggml_half2 dm_values = *reinterpret_cast(base + dm_offset);
+
+ const float dall = dm_values.x();
+ const float dmin = dm_values.y();
+
+ const uint8_t * ql = qs_ptr + 32 * il + 2 * ir;
+ const uint8_t * qh = qh_ptr + 2 * ir;
+
+ if (tid < K_SCALE_SIZE) {
+ scales_local[tid] = scales_ptr[tid];
+ }
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, scales_local, sc, m);
+ const float d1 = dall * sc; const float m1 = dmin * m;
+ get_scale_min_k4(is + 1, scales_local, sc, m);
+ const float d2 = dall * sc; const float m2 = dmin * m;
+
+ uint8_t hm = 1 << (2 * il);
+ y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+ y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+ hm <<= 1;
+ y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+ y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+ GGML_UNUSED(ib); GGML_UNUSED(tid); GGML_UNUSED(yy); GGML_UNUSED(scales_local); GGML_UNUSED(n_blocks);
+ GGML_ABORT("Q5_K reorder dequantize not supported for QK_K != 256");
+#endif
+}
+
template
static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
diff --git a/ggml/src/ggml-sycl/fattn-buffers.cpp b/ggml/src/ggml-sycl/fattn-buffers.cpp
new file mode 100644
index 000000000000..46cf6d551f17
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn-buffers.cpp
@@ -0,0 +1,56 @@
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "common.hpp"
+
+sycl::half * ggml_sycl_fattn_kv_buffers::kv_buffer::ensure_half(size_t n_elems) {
+ const size_t need_bytes = n_elems * sizeof(sycl::half);
+
+ if (capacity >= need_bytes) {
+ return ptr;
+ }
+
+ if (ptr) {
+ SYCL_CHECK(CHECK_TRY_ERROR(qptr->wait()));
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
+ ptr = nullptr;
+ capacity = 0;
+ }
+
+ size_t cap = 0;
+ while (cap < need_bytes) {
+ cap += CHUNK_SIZE;
+ }
+
+ void * dev_ptr;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(dev_ptr = sycl::malloc_device(
+ cap, *qptr)));
+
+ if (!dev_ptr) {
+ GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, cap);
+ GGML_ABORT("fattn buffer alloc failed");
+ }
+
+ ptr = static_cast(dev_ptr);
+ capacity = cap;
+ return ptr;
+}
+
+ggml_sycl_fattn_kv_buffers::kv_buffer::~kv_buffer() {
+#ifdef DEBUG_SYCL_POOL
+ GGML_LOG_INFO("ggml_sycl_fattn_kv_buffer[%d]: %.2f MiB\n", device, capacity / 1024.0 / 1024.0);
+#endif
+ if (ptr) {
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
+ }
+}
diff --git a/ggml/src/ggml-sycl/fattn-buffers.hpp b/ggml/src/ggml-sycl/fattn-buffers.hpp
new file mode 100644
index 000000000000..c00461de620f
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn-buffers.hpp
@@ -0,0 +1,63 @@
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_FATTN_BUFFERS_HPP
+#define GGML_SYCL_FATTN_BUFFERS_HPP
+
+#include
+
+typedef sycl::queue *queue_ptr;
+
+struct ggml_sycl_fattn_kv_buffers {
+ // buffers grow in chunks of this size
+ static constexpr size_t CHUNK_SIZE = 16ull << 20; // 16 MiB
+
+ struct kv_buffer {
+ kv_buffer(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
+ ~kv_buffer();
+
+ kv_buffer(const kv_buffer &) = delete;
+ kv_buffer & operator=(const kv_buffer &) = delete;
+
+ sycl::half * ensure_half(size_t n_elems);
+
+ private:
+ sycl::half * ptr = nullptr;
+ size_t capacity = 0;
+ queue_ptr qptr = nullptr;
+ [[maybe_unused]] int device = 0;
+ };
+
+ kv_buffer K;
+ kv_buffer V;
+
+ ggml_sycl_fattn_kv_buffers(queue_ptr qptr, int device) : K(qptr, device), V(qptr, device) {}
+
+ ggml_sycl_fattn_kv_buffers(const ggml_sycl_fattn_kv_buffers &) = delete;
+ ggml_sycl_fattn_kv_buffers & operator=(const ggml_sycl_fattn_kv_buffers &) = delete;
+};
+
+/**
+ * Imitates `ggml_sycl_pool_alloc` to keep the code calling alloc unchanged.
+ */
+struct ggml_sycl_fattn_alloc {
+ ggml_sycl_fattn_kv_buffers::kv_buffer & buf;
+ sycl::half * ptr = nullptr;
+
+ explicit ggml_sycl_fattn_alloc(ggml_sycl_fattn_kv_buffers::kv_buffer & buf_) : buf(buf_) {}
+
+ sycl::half * alloc(size_t n_elems) {
+ ptr = buf.ensure_half(n_elems);
+ return ptr;
+ }
+};
+#endif
diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp
index ed00d03c3b67..03f0c2623c84 100644
--- a/ggml/src/ggml-sycl/fattn-common.hpp
+++ b/ggml/src/ggml-sycl/fattn-common.hpp
@@ -5,6 +5,7 @@
#include "common.hpp"
#include "convert.hpp"
#include "vecdotq.hpp"
+#include "fattn-buffers.hpp"
#include "ggml.h"
@@ -918,12 +919,13 @@ void launch_fattn(
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
ggml_sycl_pool & pool = ctx.pool();
+ ggml_sycl_fattn_kv_buffers & fbuf = ctx.fattn_buffers();
dpct::queue_ptr main_stream = ctx.stream();
const int id = ggml_sycl_get_device();
const int nsm = ggml_sycl_info().devices[id].nsm;
- ggml_sycl_pool_alloc K_f16(pool);
- ggml_sycl_pool_alloc V_f16(pool);
+ ggml_sycl_fattn_alloc K_f16(fbuf.K);
+ ggml_sycl_fattn_alloc V_f16(fbuf.V);
ggml_sycl_pool_alloc KV_max(pool);
ggml_sycl_pool_alloc dst_tmp(pool);
ggml_sycl_pool_alloc dst_tmp_meta(pool);
diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp
index 03f8dd907485..ca4574547756 100644
--- a/ggml/src/ggml-sycl/getrows.cpp
+++ b/ggml/src/ggml-sycl/getrows.cpp
@@ -183,6 +183,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
+ case GGML_TYPE_BF16:
+ get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::ext::oneapi::bfloat16 *)dst->src[0]->data,
+ src1_i32, (float *)dst->data, ctx.stream());
+ break;
case GGML_TYPE_F32:
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 29ecedb5de9e..e7768b8bf614 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -1286,6 +1286,23 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {}
~ggml_sycl_pool_leg() {
+#ifdef DEBUG_SYCL_POOL
+ int n_cached = 0;
+ size_t bytes_cached = 0;
+ for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
+ if (buffer_pool[i].ptr != nullptr) {
+ ++n_cached;
+ bytes_cached += buffer_pool[i].size;
+ }
+ }
+ GGML_LOG_INFO("%s: %d buffers, cached = %.2f MiB\n", __func__,
+ n_cached, bytes_cached / 1024.0 / 1024.0);
+ const auto slots = format_slots_in_alloc_order();
+ if (!slots.empty()) {
+ GGML_LOG_INFO("%s: slots MiB: %s\n", __func__, slots.c_str());
+ }
+#endif
+
for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
ggml_sycl_buffer & b = buffer_pool[i];
if (b.ptr != nullptr) {
@@ -1296,6 +1313,26 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
GGML_ASSERT(pool_size == 0);
}
+#ifdef DEBUG_SYCL_POOL
+ std::string format_slots_in_alloc_order() const {
+ std::string line;
+ char buf[32];
+ bool first = true;
+ for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
+ if (buffer_pool[i].ptr == nullptr) {
+ continue;
+ }
+ if (!first) {
+ line += '/';
+ }
+ first = false;
+ snprintf(buf, sizeof(buf), "%.2f", buffer_pool[i].size / 1024.0 / 1024.0);
+ line += buf;
+ }
+ return line;
+ }
+#endif
+
void * alloc(size_t size, size_t * actual_size) override {
#ifdef DEBUG_sycl_MALLOC
int nnz = 0;
@@ -1459,6 +1496,10 @@ std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(q
return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device));
}
+std::unique_ptr ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) {
+ return std::unique_ptr(new ggml_sycl_fattn_kv_buffers(qptr, device));
+}
+
// TBD pool with virtual memory management
// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
@@ -3303,6 +3344,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
case GGML_TYPE_Q8_0:
return true;
case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return !g_ggml_sycl_prioritize_dmmv;
default:
@@ -3325,6 +3367,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return true;
default:
@@ -3541,6 +3584,54 @@ static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
return true;
}
+static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+ GGML_ASSERT(size % sizeof(block_q5_K) == 0);
+ GGML_ASSERT(offset % sizeof(block_q5_K) == 0);
+
+ const int nblocks = size / sizeof(block_q5_K);
+
+ sycl_reorder_temp_buffer tmp(stream, size);
+ if (!tmp) {
+ GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
+ return false;
+ }
+ uint8_t * tmp_buf = static_cast(tmp.ptr);
+
+ sycl::event copy_event;
+ SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
+ if (!g_ggml_sycl_use_async_mem_op) {
+ copy_event.wait();
+ }
+
+ auto * qs_ptr = data_device;
+ auto * qh_ptr = qs_ptr + (QK_K / 2) * nblocks;
+ auto * scales_ptr = qh_ptr + (QK_K / 8) * nblocks;
+ auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+ auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
+ const block_q5_K * x = (const block_q5_K *) tmp_buf;
+ const int ib = i;
+
+ for (int j = 0; j < QK_K / 2; ++j) {
+ qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+ }
+
+ for (int j = 0; j < QK_K / 8; ++j) {
+ qh_ptr[ib * (QK_K / 8) + j] = x[ib].qh[j];
+ }
+
+ for (int j = 0; j < K_SCALE_SIZE; ++j) {
+ scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+ }
+
+ dm_ptr[ib] = x[ib].dm;
+ });
+ if (!g_ggml_sycl_use_async_mem_op) {
+ reorder_event.wait_and_throw();
+ }
+ return true;
+}
+
static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
@@ -3607,6 +3698,8 @@ static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
case GGML_TYPE_Q4_K:
return reorder_qw_q4_k(data_device, size, 0, stream);
+ case GGML_TYPE_Q5_K:
+ return reorder_qw_q5_k(data_device, size, 0, stream);
case GGML_TYPE_Q6_K:
return reorder_qw_q6_k(data_device, size, 0, stream);
default:
@@ -4922,6 +5015,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
{
switch (op->src[0]->type) {
case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
@@ -5104,11 +5198,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ACC:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_PAD:
- // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
if (ggml_get_op_params_i32(op, 8) != 0) {
return false;
}
- return ggml_is_contiguous(op->src[0]);
+ return true;
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp
index 8fa2198f35af..49998f13ba8b 100644
--- a/ggml/src/ggml-sycl/mmvq.cpp
+++ b/ggml/src/ggml-sycl/mmvq.cpp
@@ -839,6 +839,26 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
}
}
+static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+ const int nrows, dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+
+ const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+ constexpr size_t num_subgroups = 16;
+ GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+ const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
+ const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+ stream->submit([&](sycl::handler & cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+ [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_reorder>(vx, vy, dst, ncols,
+ nrows, nd_item);
+ });
+ });
+}
+
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
const int nrows, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
@@ -1125,6 +1145,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n");
reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
} else {
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n");
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
@@ -1145,7 +1166,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
}
break;
case GGML_TYPE_Q5_K:
- mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+ ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+ GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n");
+ reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ } else {
+ GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n");
+ mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ }
break;
case GGML_TYPE_Q6_K:
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
diff --git a/ggml/src/ggml-sycl/pad.cpp b/ggml/src/ggml-sycl/pad.cpp
index f989c5e4b8bb..ee93bb518016 100644
--- a/ggml/src/ggml-sycl/pad.cpp
+++ b/ggml/src/ggml-sycl/pad.cpp
@@ -13,7 +13,8 @@
//#include "common.hpp"
#include "pad.hpp"
-static void pad_f32(const float * src, float * dst,
+static void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03,
+ float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3,
@@ -27,7 +28,6 @@ static void pad_f32(const float * src, float * dst,
return;
}
- // operation
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
@@ -37,12 +37,8 @@ static void pad_f32(const float * src, float * dst,
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
- const int64_t ne02 = ne2 - lp2 - rp2;
- const int64_t ne01 = ne1 - lp1 - rp1;
- const int64_t ne00 = ne0 - lp0 - rp0;
- const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
- i02 * (ne00 * ne01) + i01 * ne00 + i00;
+ const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
dst[dst_idx] = src[src_idx];
} else {
@@ -50,20 +46,19 @@ static void pad_f32(const float * src, float * dst,
}
}
-static void pad_f32_sycl(const float *src, float *dst, const int lp0,
- const int rp0, const int lp1, const int rp1,
- const int lp2, const int rp2, const int lp3,
- const int rp3, const int ne0, const int ne1,
- const int ne2, const int ne3,
+static void pad_f32_sycl(const float * src, size_t s00, size_t s01, size_t s02, size_t s03,
+ float * dst, const int lp0, const int rp0, const int lp1, const int rp1,
+ const int lp2, const int rp2, const int lp3, const int rp3,
+ const int ne0, const int ne1, const int ne2, const int ne3,
dpct::queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
- dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
+ sycl::range<3> grid(ne2 * ne3, ne1, num_blocks);
stream->parallel_for(
- sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
+ sycl::nd_range<3>(grid * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
- pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
- ne2, ne3, item_ct1);
+ pad_f32(src, s00, s01, s02, s03, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+ ne0, ne1, ne2, ne3, item_ct1);
});
}
@@ -71,22 +66,27 @@ void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
- dpct::queue_ptr stream = ctx.stream();
+ dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_is_contiguous(src0));
- const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
- const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
- const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
- const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
- const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
- const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
- const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
- const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
+ const size_t ts = ggml_type_size(src0->type);
+ const size_t s00 = src0->nb[0] / ts;
+ const size_t s01 = src0->nb[1] / ts;
+ const size_t s02 = src0->nb[2] / ts;
+ const size_t s03 = src0->nb[3] / ts;
- pad_f32_sycl(src0_d, dst_d,
+ const int32_t lp0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t rp0 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t lp1 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t rp1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t lp2 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t rp2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t lp3 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t rp3 = ((const int32_t *)(dst->op_params))[7];
+
+ pad_f32_sycl(src0_d, s00, s01, s02, s03, dst_d,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp
index 1f5b62740a8e..806028ef3a32 100644
--- a/ggml/src/ggml-sycl/quants.hpp
+++ b/ggml/src/ggml-sycl/quants.hpp
@@ -79,6 +79,31 @@ template <> struct block_q_t {
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
};
+template <> struct block_q_t {
+ struct traits {
+ static constexpr uint32_t qk = QK_K;
+ static constexpr uint32_t qi = QI5_K;
+ static constexpr uint32_t qr = QR5_K;
+ static constexpr uint32_t vdr_mmvq = 2;
+ };
+
+ // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales] [dm]
+ static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) {
+ auto qs_offset = block_index * (QK_K / 2);
+ auto qh_offset = n_blocks * (QK_K / 2) + block_index * (QK_K / 8);
+ return { qs_offset, qh_offset };
+ }
+
+ static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) {
+ auto nblocks = (nrows * (ncols / QK_K));
+ auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 8);
+ return { total_qs_bytes + block_index * K_SCALE_SIZE,
+ total_qs_bytes + nblocks * K_SCALE_SIZE + block_index * sizeof(ggml_half2) };
+ }
+
+ static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+};
+
template <> struct block_q_t {
struct traits {
static constexpr uint32_t qk = QK_K;
diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp
index 9253168e5ea2..d77700474247 100644
--- a/ggml/src/ggml-sycl/vecdotq.hpp
+++ b/ggml/src/ggml-sycl/vecdotq.hpp
@@ -357,38 +357,31 @@ template <> struct reorder_vec_dot_q_sycl {
using q8_0_block = ggml_sycl_reordered::block_q_t;
using q8_0_traits = typename q8_0_block::traits;
- __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) {
- int sumi = 0;
-
-#pragma unroll
- for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
- // Q8_0 values are signed int8, no nibble extraction needed
- // Direct dp4a: each int packs 4 int8 values
- sumi = dpct::dp4a(v[i], u[i], sumi);
- }
-
- const sycl::float2 ds8f = ds8.convert();
-
- // Q8_0 has no bias term (values are signed), so just scale
- return d8_0 * sumi * ds8f.x();
- }
-
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset,
const std::pair d_offset, const int8_t * q8_1_quant_ptr,
const sycl::half2 * q8_1_ds, const int & iqs) {
- const int8_t * bq8_0 = static_cast(vbq) + ibx_offset.first;
- const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first));
- int v[q8_0_traits::vdr_mmvq];
- int u[q8_0_traits::vdr_mmvq];
+ const uint8_t * base = static_cast(vbq);
+ const int8_t * qs = reinterpret_cast(base + ibx_offset.first);
+ const ggml_half d = *reinterpret_cast(base + d_offset.first);
+
+ int v[q8_0_traits::vdr_mmvq];
+ int u[q8_0_traits::vdr_mmvq];
#pragma unroll
for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
- v[i] = get_int_from_int8(bq8_0, iqs + i);
+ v[i] = get_int_from_int8(qs, iqs + i);
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
}
- return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds);
- };
+ int sumi = 0;
+#pragma unroll
+ for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) {
+ sumi = dpct::dp4a(v[i], u[i], sumi);
+ }
+
+ const sycl::half2 ds_values = *q8_1_ds;
+ return static_cast(d) * static_cast(ds_values[0]) * sumi;
+ }
};
static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
@@ -481,6 +474,65 @@ template <> struct reorder_vec_dot_q_sycl {
}
};
+template <> struct reorder_vec_dot_q_sycl {
+ static constexpr ggml_type gtype = GGML_TYPE_Q5_K;
+
+ using q5_k_block = ggml_sycl_reordered::block_q_t;
+ using q5_k_traits = typename q5_k_block::traits;
+
+ __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset,
+ const std::pair d_offset, const int8_t * q8_1_quant_ptr,
+ const sycl::half2 * q8_1_ds, const int & iqs) {
+ const uint8_t * base = static_cast(vbq);
+ const uint8_t * qs = base + ibx_offset.first; // low 4 bits
+ const uint8_t * qh_base = base + ibx_offset.second; // high bit
+ const uint8_t * scs = base + d_offset.first;
+ const ggml_half2 * dms = reinterpret_cast(base + d_offset.second);
+
+ const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2));
+ const int * ql_ptr = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
+ const int * qh_ptr = (const int *) (qh_base + 4 * ((iqs / 2) % 4));
+ const uint16_t * scales = (const uint16_t *) scs;
+
+ int vl[2];
+ int vh[2];
+ int u[2 * QR5_K];
+ float d8[QR5_K];
+
+ vl[0] = ql_ptr[0];
+ vl[1] = ql_ptr[4];
+
+ vh[0] = qh_ptr[0] >> bq8_offset;
+ vh[1] = qh_ptr[4] >> bq8_offset;
+
+ uint16_t aux[2];
+ const int j = (QR5_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
+ if (j < 2) {
+ aux[0] = scales[j + 0] & 0x3f3f;
+ aux[1] = scales[j + 2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
+ }
+
+ const uint8_t * sc = (const uint8_t *) aux;
+ const uint8_t * m = sc + 2;
+
+ for (int i = 0; i < QR5_K; ++i) {
+ const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
+ sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
+
+ d8[i] = ds_values[0];
+
+ const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
+ u[2 * i + 0] = q8[0];
+ u[2 * i + 1] = q8[4];
+ }
+
+ return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, *dms, d8);
+ }
+};
+
template <> struct reorder_vec_dot_q_sycl {
static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 163f222ef612..f43cf546ca08 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -503,6 +503,14 @@ struct llm_tokenizer_bpe : llm_tokenizer {
};
byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding
break;
+ case LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE:
+ // Sarvam uses SPM-style BPE (same shape as Gemma4): spaces replaced with U+2581
+ // by the normalizer, BPE merges over the whole text on raw UTF-8.
+ regex_exprs = {
+ "[^\\n]+|[\\n]+",
+ };
+ byte_encode = false;
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -2005,6 +2013,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "gemma4") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4;
escape_whitespaces = true;
+ } else if (
+ tokenizer_pre == "sarvam-moe") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE;
+ escape_whitespaces = true;
+ clean_spaces = false;
} else if (
tokenizer_pre == "jina-v1-en" ||
tokenizer_pre == "jina-v2-code" ||
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index dd38f45d3a22..8b040b912e2f 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -59,6 +59,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49,
LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50,
+ LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51,
};
struct LLM_KV;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a55b5b4c2337..922ad493a34d 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -8861,8 +8861,10 @@ static std::vector> make_test_cases_eval() {
if (nh == 1 && hsk != 320 && hsk != 576) continue;
for (int nr3 : { 1, 3, }) {
if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
- for (int nr2 : { 1, 4, 12, 20, 32 }) {
+ for (int nr2 : { 1, 4, 8, 12, 16, 20, 32 }) {
+ if (nr2 == 8 && hsk != 192) continue;
if (nr2 == 12 && hsk != 128) continue;
+ if (nr2 == 16 && hsk != 192) continue;
if (nr2 == 20 && (nh != 1 || hsk != 576)) continue;
if (nr2 == 32 && (nh != 1 || hsk != 320)) continue;
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt
index df4b9ecce3f2..dbc420d83704 100644
--- a/vendor/cpp-httplib/CMakeLists.txt
+++ b/vendor/cpp-httplib/CMakeLists.txt
@@ -41,7 +41,7 @@ if (LLAMA_BUILD_BORINGSSL)
set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)")
set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository")
- set(BORINGSSL_VERSION "0.20260413.0" CACHE STRING "BoringSSL version")
+ set(BORINGSSL_VERSION "0.20260508.0" CACHE STRING "BoringSSL version")
message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}")