Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .devops/intel.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ RUN mkdir -p /app/full \

FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS base

ARG IGC_VERSION=v2.30.1
ARG IGC_VERSION_FULL=2_2.30.1+20950
ARG COMPUTE_RUNTIME_VERSION=26.09.37435.1
ARG COMPUTE_RUNTIME_VERSION_FULL=26.09.37435.1-0
ARG IGC_VERSION=v2.32.7
ARG IGC_VERSION_FULL=2_2.32.7+21184
ARG COMPUTE_RUNTIME_VERSION=26.14.37833.4
ARG COMPUTE_RUNTIME_VERSION_FULL=26.14.37833.4-0
ARG IGDGMM_VERSION=22.9.0
RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-core-${IGC_VERSION_FULL}_amd64.deb \
Expand Down
2 changes: 1 addition & 1 deletion .devops/nix/package.nix
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ let
vulkan-headers
vulkan-loader
shaderc
spirv-headers
];
in

Expand Down Expand Up @@ -146,7 +147,6 @@ effectiveStdenv.mkDerivation (finalAttrs: {
ninja
pkg-config
git
spirv-headers
]
++ optionals useCuda [
cudaPackages.cuda_nvcc
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ uv.lock

# Nix

flake.lock
/result

# Test binaries
Expand Down
31 changes: 31 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,6 +1570,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "862f827721df956049dff5ca81a57f29e575280bc622e290d3bf4e35eca29015":
# ref: https://huggingface.co/codefuse-ai/F2LLM-v2-4B
res = "f2llmv2"
if chkhsh == "62f6fb0a6fd5098caeabb19b07a5c1099cafc8b9c40eab6ea89ece4ec02fbc57":
# ref: https://huggingface.co/sarvamai/sarvam-30b
res = "sarvam-moe"

if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -11591,6 +11594,34 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("SarvamMoEForCausalLM", "modeling_sarvam_moe.SarvamMoEForCausalLM")
class SarvamMoEModel(BailingMoeV2Model):
model_arch = gguf.MODEL_ARCH.BAILINGMOE2
# Sarvam-MoE shares the BailingMoeV2 architecture; only differences:
# - full rotary (no partial_rotary_factor)
# - expert bias is zero-mean normalized at load time

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if (rope_dim := hparams.get("head_dim")) is None:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
# Override the partial-rotary value written by BailingMoeV2 with the full rotary dim
self.gguf_writer.add_rope_dimension_count(rope_dim)

@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
if name.endswith(".expert_bias"):
# Sarvam normalizes expert bias to zero mean
inner = gen

def gen():
t = inner()
return t - t.mean()
return super().filter_tensors((name, gen))


@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
class GroveMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROVEMOE
Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
{"name": "f2llmv2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/codefuse-ai/F2LLM-v2-4B", },
{"name": "sarvam-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sarvamai/sarvam-30b", },
]

# some models are known to be broken upstream, so we will skip them as exceptions
Expand Down
8 changes: 8 additions & 0 deletions docs/backend/SYCL.md
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|

## Compile-time Flags

Pass these via `CXXFLAGS` or add a one-off `#define` to enable a flag on the spot.

| Name | Function |
|-----------------|----------------------------------------------------------------------------------|
| DEBUG_SYCL_POOL | Enable device memory pool logging on teardown. Useful for profiling allocations. |

## Design Rule

- Open to all contributors.
Expand Down
58 changes: 0 additions & 58 deletions flake.lock

This file was deleted.

9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
Expand Down Expand Up @@ -1561,6 +1566,10 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) {
NO_DEVICE_CODE;
return;
}
#ifdef VOLTA_MMA_AVAILABLE
if (ncols1*ncols2 < 32) {
NO_DEVICE_CODE;
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/fattn-tile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
} break;
case 192: {
GGML_ASSERT(V->ne[0] == 128);
ggml_cuda_flash_attn_ext_tile_case<192, 128>(ctx, dst);
} break;
case 256: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
Expand Down
40 changes: 39 additions & 1 deletion ggml/src/ggml-cuda/fattn-tile.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 64, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64)
Expand Down Expand Up @@ -124,6 +130,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 128, 3, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 3, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256)
Expand Down Expand Up @@ -193,6 +205,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
Expand Down Expand Up @@ -264,6 +282,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 8, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 6, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 128, 6, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 5, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 3, 64, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
Expand Down Expand Up @@ -1250,7 +1274,20 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
}
}

if constexpr (DKQ <= 512 && DKQ != 320) {
if constexpr (DKQ == 192) {
// MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn)
if (use_gqa_opt && gqa_ratio % 16 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
}
GGML_ABORT("flash-attn tile (192/128): expected GQA ratio multiple of 8");
}

if constexpr (DKQ <= 512 && DKQ != 320 && DKQ != 192) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
Expand Down Expand Up @@ -1303,6 +1340,7 @@ extern DECL_FATTN_TILE_CASE( 80, 80);
extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(192, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(320, 256);
extern DECL_FATTN_TILE_CASE(512, 512);
Expand Down
33 changes: 29 additions & 4 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(V->ne[0] == 128);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
break;
case 192: {
// MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn)
GGML_ASSERT(V->ne[0] == 128);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const bool use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(use_gqa_opt);
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 16>(ctx, dst);
} else {
GGML_ASSERT(gqa_ratio % 8 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 8>(ctx, dst);
}
} break;
case 256:
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
Expand Down Expand Up @@ -368,6 +384,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
}
break;
case 192:
if (V->ne[0] != 128 || !gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
if (gqa_ratio % 8 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 320:
if (V->ne[0] != 256 || !gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
Expand Down Expand Up @@ -425,7 +449,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}

// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path.
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0;

// If Turing tensor cores are available, use them:
if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
Expand Down Expand Up @@ -454,7 +479,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const

if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
int gqa_ratio_eff = 1;
const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8;
while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
gqa_ratio_eff *= 2;
}
Expand All @@ -468,7 +493,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}

// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 512 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
Expand Down Expand Up @@ -501,7 +526,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}

// Use MFMA flash attention for CDNA (MI100+):
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
// MMA vs tile crossover benchmarked on MI300X @ d32768:
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

#include "../fattn-mma-f16.cuh"

DECL_FATTN_MMA_F16_CASE(192, 128, 1, 16);
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
DECL_FATTN_MMA_F16_CASE(192, 128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

#include "../fattn-mma-f16.cuh"

DECL_FATTN_MMA_F16_CASE(192, 128, 2, 16);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
DECL_FATTN_MMA_F16_CASE(192, 128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

#include "../fattn-mma-f16.cuh"

DECL_FATTN_MMA_F16_CASE(192, 128, 4, 16);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
DECL_FATTN_MMA_F16_CASE(192, 128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
Loading
Loading