Commit e4a2036

Merge pull request #73 from SharpAI/fix/gemma4-quantizedkv-b440
fix: Gemma-4 QuantizedKVCache + kv_bits API + Test 9 (mlx-swift-lm b440)
2 parents 116ee91 + ed5f8f6 commit e4a2036

4 files changed

Lines changed: 313 additions & 26 deletions

README.md

Lines changed: 102 additions & 15 deletions
@@ -89,25 +89,76 @@ Benchmark results for `gemma-4-26b-a4b-it-4bit` (26B MoE, 4-bit) on M5 Pro 64 GB
 
 ---
 
-## 🧠 Supported Models & Methodologies
+## 📡 Supported Models & Methodologies
 
-`SwiftLM` dynamically maps Apple MLX primitives to standard HuggingFace architectures, enabling complete support for the latest frontier open-weights models across modalities (Text, Vision, Audio).
+`SwiftLM` dynamically maps Apple MLX primitives to standard HuggingFace architectures, enabling native Metal inference across the latest frontier open-weights models.
 
-### Text (LLMs)
-- **Gemma 4**: Fully supports both Dense (`gemma-4-e4b`) and Sparse Mixture of Experts (MoE) architectures (`gemma-4-26b`, `gemma-4-31b`).
-- **Qwen 2.5 & 3**: Robust support for sliding window attention limits and custom RoPE scaling.
-- **Mistral & Mixtral**: Out-of-the-box structural mappings.
-- **Phi-3 & Phi-3.5**: Full 128k context parsing via Swift chunked-prefill.
+### 💬 Text (LLMs)
 
-### Vision (VLMs)
+| Family | Models | Notes |
+|---|---|---|
+| **Gemma 4** | `gemma-4-e2b`, `gemma-4-e4b` (dense) · `gemma-4-26b-a4b`, `gemma-4-31b` (MoE) | Interleaved local + global attention; KV sharing; native quantized KV cache (issue #71 fix) |
+| **Gemma 3 / 3n** | `gemma-3-*`, `gemma-3n-*` | Google Gemma 3 and nano variants |
+| **Gemma / Gemma 2** | `gemma-*`, `gemma-2-*` | Original Gemma family |
+| **Qwen 3.5** | `Qwen3.5-7B`, `Qwen3.5-27B`, `Qwen3.5-122B-A10B`, `Qwen3.5-397B-A22B` | Dense + MoE; SSD streaming at 10× for 122B/397B |
+| **Qwen 3** | `Qwen3-*` (dense + MoE) | Sliding window + hybrid attention |
+| **Qwen 2.5** | `Qwen2.5-7B`, `Qwen2.5-14B`, `Qwen2.5-72B` | Robust RoPE scaling |
+| **Qwen 2** | `Qwen2-*` | Linear RoPE variants |
+| **Phi 4 / PhiMoE** | `phi-4-mlx`, `Phi-3.5-MoE` | Microsoft Phi family incl. MoE |
+| **Phi 3 / Phi** | `Phi-3`, `Phi-3.5-mini` | 128k context via chunked prefill |
+| **Mistral / Mixtral** | `Mistral-7B`, `Mistral-4`, `Mixtral-*` | GQA + sliding window variants |
+| **Llama / Llama 3** | `Llama-3.1-*`, `Llama-3.2-*`, `Llama-3.3-*` | YaRN + dynamic NTK RoPE scaling |
+| **GLM 4** | `GLM-4-*` | THUDM GLM-4 dense + MoE-Lite variants |
+| **DeepSeek V3** | `DeepSeek-V3-*` | MLA attention architecture |
+| **Falcon H1** | `Falcon-H1-*` | Falcon hybrid SSM+attention |
+| **LFM 2** | `LFM2-*`, `LFM2-MoE-*` | Liquid AI dense + MoE |
+| **OLMo 2 / OLMo 3 / OLMoE** | `OLMo-2-*`, `OLMo-3-*` | AllenAI open language models |
+| **Granite / GraniteMoE** | `Granite-*`, `GraniteMoE-Hybrid-*` | IBM Granite hybrid Mamba+attention |
+| **SmolLM 3** | `SmolLM3-*` | HuggingFace compact LM |
+| **MiniCPM** | `MiniCPM-*` | Lightweight efficient LM |
+| **InternLM 2** | `InternLM2-*` | Shanghai AI Lab series |
+| **Cohere / Command-R** | `Command-R-*`, `c4ai-*` | Cohere retrieval-tuned models |
+| **Jamba** | `Jamba-v0.1` | AI21 hybrid Mamba+attention |
+| **Exaone 4** | `EXAONE-4.0-*` | LG AI Research |
+| **MiMo / MiMo V2** | `MiMo-7B-*` | Xiaomi reasoning model |
+| **Ernie 4.5** | `ERNIE-4.5-*` | Baidu ERNIE series |
+| **Baichuan M1** | `Baichuan-M1-*` | Baichuan multimodal base |
+| **Bailing MoE** | `Ling-*` | Bailing/Ling MoE family |
+| **NemotronH** | `Nemotron-H-*` | NVIDIA Nemotron hybrid |
+| **Starcoder 2** | `starcoder2-*` | Code generation |
+| **OpenELM** | `OpenELM-*` | Apple on-device efficient LM |
+| **Apertus / AfMoE** | `Apertus-*` | Sparse MoE research models |
+| **BitNet** | `bitnet-*` | 1-bit weight quantization |
+| **MiniMax** | `MiniMax-Text-*` | Lightning attention architecture |
+| **Olmo3** | `Olmo3-*` | AllenAI Olmo3 series |
+
+### 👁️ Vision (VLMs)
 *Run with `--vision` flag.*
-- **Qwen2-VL & Qwen3-VL**: Real-time positional bounding and Metal image scaling.
-- **PaliGemma / LFM2-VL / Pixtral**: Base64 spatial decomposition.
 
-### Audio (ALMs)
-*Run with `--audio` flag.*
-- **Qwen2-Audio (7B-Instruct)**: Deep multi-modal spectrogram processing via Swift audio interleaving.
-- **Gemma-4 Audio Pipelines**: Ready for Audio-in/Text-out variants mapping `.audio_tower` extraction parameters natively off NVMe.
+| Family | Models | Notes |
+|---|---|---|
+| **Gemma 4** | `gemma-4-*` (VLM mode) | Native image tower via MLXVLM |
+| **Gemma 3** | `gemma-3-*` (VLM mode) | PaliGemma-style image projection |
+| **Qwen3-VL / Qwen3.5-VL** | `Qwen3-VL-*`, `Qwen3.5-VL-*` | Dynamic resolution with native RoPE |
+| **Qwen2-VL / Qwen2.5-VL** | `Qwen2-VL-2B/7B`, `Qwen2.5-VL-*` | Real-time positional bounding + Metal image scaling |
+| **LFM2-VL** | `LFM2-VL-1.6B` | Liquid AI multimodal |
+| **Pixtral** | `pixtral-12b` | Mistral vision model |
+| **PaliGemma** | `paligemma-*` | Google vision-language |
+| **Idefics 3** | `Idefics3-*` | HuggingFace multimodal |
+| **Mistral 3** | `Mistral-Small-3.1-*` | Mistral vision variant |
+| **FastVLM** | `FastVLM-*` | Apple on-device VLM |
+| **SmolVLM 2** | `SmolVLM2-*` | HuggingFace compact VLM |
+| **GLM OCR** | `glm-4v-*` | THUDM vision+OCR |
+| **QwenVL** | `Qwen-VL-*` | Original Qwen VL |
+
+### 🎧 Audio (ALMs)
+*Run with `--audio` flag. Only `gemma-4-e4b` variants include an audio tower.*
+
+| Family | Models | Notes |
+|---|---|---|
+| **Gemma 4 Omni** | `gemma-4-e4b-it-4bit`, `gemma-4-e4b-it-8bit` | Audio-in via vDSP STFT → Mel spectrogram (16kHz, 128 bins); text-out |
+
+
 
 ---
 
@@ -352,10 +403,46 @@ curl http://localhost:5413/v1/chat/completions \
 | `--min-p` | `0.0` | Default min-p sampling threshold relative to the highest probability token (0 disables) |
 | `--gpu-layers` | `model_default`| Restrict the number of layers allocated to GPU hardware |
 | `--stream-experts` | `false` | Enable SSD expert streaming for MoE models (10x speedup) |
-| `--turbo-kv` | `false` | Enable TurboQuant 3-bit KV cache compression |
+| `--turbo-kv` | `false` | Enable TurboQuant 3-bit KV cache compression (activates after 2048 tokens, server-wide) |
 | `--draft-model` | (none) | Draft model path/ID for speculative decoding (in-RAM models only) |
 | `--num-draft-tokens` | `4` | Number of draft tokens per speculation round |
 
+## 🔧 Per-Request API Parameters
+
+In addition to the standard OpenAI fields (`temperature`, `top_p`, `max_tokens`, etc.), SwiftLM accepts the following **SwiftLM-specific** fields on `POST /v1/chat/completions`:
+
+| Field | Type | Description |
+|---|---|---|
+| `kv_bits` | `int` (4 or 8) | Enable **MLX-native quantized KV cache** for this request. Uses `QuantizedKVCache` (standard group quantization) instead of `KVCacheSimple`. Separate from `--turbo-kv`. Reduces KV memory ~2–4× at mild quality cost. |
+| `enable_thinking` | `bool` | Force-enable or disable chain-of-thought thinking blocks for Gemma-4 / Qwen3. |
+| `kv_group_size` | `int` | Group size for `kv_bits` quantization (default: `64`). |
+| `top_k` | `int` | Per-request top-k sampling override (0 = disabled). |
+| `min_p` | `float` | Per-request min-p sampling threshold (0 = disabled). |
+| `repetition_penalty` | `float` | Token repetition penalty (e.g. `1.15`). |
+
+### `kv_bits` vs `--turbo-kv`: What's the difference?
+
+| | `kv_bits` (per-request) | `--turbo-kv` (server flag) |
+|---|---|---|
+| **Scope** | Per-request, sent in JSON body | Server-wide, set at startup |
+| **Algorithm** | MLX-native group quantization (4-bit / 8-bit) | Custom 3-bit PolarQuant + QJL Walsh-Hadamard |
+| **Activation** | From token 0 | After 2048 tokens |
+| **Memory savings** | ~2–4× vs FP16 | ~3.5× vs FP16 |
+| **Use case** | Targeted memory reduction per conversation | Extreme long-context (100K+) compression |
+
+### Example: Enable 4-bit KV cache per request
+```bash
+curl http://localhost:5413/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gemma-4-26b-a4b-it-4bit",
+    "kv_bits": 4,
+    "messages": [
+      {"role": "user", "content": "Summarize the history of computing in 3 sentences."}
+    ]
+  }'
+```
+
 ## 📦 Requirements
 
 - macOS 14.0+
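
The "~2–4× vs FP16" figure in the `kv_bits` comparison table above can be sanity-checked with a little arithmetic. The sketch below is illustrative only: `kvCompressionRatio` is a hypothetical helper, not part of SwiftLM or MLX, and it assumes each quantization group stores one FP16 scale and one FP16 bias alongside the packed values. The actual storage layout of `QuantizedKVCache` may differ.

```swift
// Hypothetical helper, not part of SwiftLM or MLX.
// Assumption: every group of `groupSize` values carries one FP16 scale
// and one FP16 bias in addition to the packed `bits`-wide payload.
func kvCompressionRatio(bits: Int, groupSize: Int = 64) -> Double {
    let fp16Bits = 16
    // Size of one unquantized FP16 group, in bits.
    let baseline = Double(groupSize * fp16Bits)
    // Packed payload plus the per-group scale and bias.
    let quantized = Double(groupSize * bits + 2 * fp16Bits)
    return baseline / quantized
}

print(kvCompressionRatio(bits: 4)) // about 3.56
print(kvCompressionRatio(bits: 8)) // about 1.88
```

Under these assumptions, a group size of 64 gives roughly 1.9× at 8 bits and 3.6× at 4 bits, which is in line with the range quoted in the table. A smaller `kv_group_size` would lower the ratio, because the per-group overhead is spread over fewer values.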

Sources/SwiftLM/Server.swift

Lines changed: 26 additions & 2 deletions
@@ -1048,9 +1048,20 @@ func handleChatCompletion(
         // These are accepted but may not affect generation if MLX doesn't support them
     }
 
+    // ── Validate kv_bits: only nil, 4, and 8 are supported ──
+    if let kb = chatReq.kvBits, kb != 4 && kb != 8 {
+        let errBody = "{\"error\":{\"message\":\"Invalid kv_bits value \(kb). Supported values are 4 and 8.\",\"type\":\"invalid_request_error\",\"code\":\"invalid_kv_bits\"}}"
+        return Response(
+            status: .badRequest,
+            headers: jsonHeaders(),
+            body: .init(byteBuffer: ByteBuffer(string: errBody))
+        )
+    }
+
     let params = GenerateParameters(
         maxTokens: tokenLimit,
         maxKVSize: config.ctxSize,
+        kvBits: chatReq.kvBits,
         temperature: temperature,
         topP: topP,
         topK: topK,
@@ -1200,9 +1211,13 @@ func handleChatCompletion(
     // raw <|image|>/<|audio|> token embeddings instead of the projected features.
     let isMultimodalRequest = lmInput.image != nil || lmInput.audio != nil
 
-    // Try to restore via token-by-token prefix match (llama-server style)
+    // Try to restore via token-by-token prefix match (llama-server style).
+    // Skip for quantized-KV requests: the prompt cache stores KV state produced
+    // with KVCacheSimple; restoring it into a QuantizedKVCache (or vice-versa)
+    // is unsafe and produces incorrect results or runtime failures.
+    let skipPromptCache = isMultimodalRequest || params.kvBits != nil
     var stream: AsyncStream<Generation>
-    if !isMultimodalRequest, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
+    if !skipPromptCache, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
         // Cache hit: KV state is pre-populated up to cachedCount tokens.
         // Only compute the remaining (new) tokens.
         var startIndex = cachedCount
@@ -1251,6 +1266,10 @@ func handleChatCompletion(
     let onPrefillDone: (() async -> Void)? = {
         if turboHasCompressed {
             print("[SwiftLM] 🧠 Skipping prompt cache save — TurboQuant has compressed \(cache.compactMap { ($0 as? KVCacheSimple)?.compressedOffset }.max() ?? 0) tokens. Saving would decode ~37 GB back to fp16.")
+        } else if params.kvBits != nil {
+            // kv_bits is set: the cache contains QuantizedKVCache layers whose token
+            // format is incompatible with the FP16 KVCacheSimple format expected by
+            // promptCache.save. Skip saving to prevent unsafe mixed-format restores.
         } else {
             await promptCache.save(tokens: promptTokens, cache: cache)
         }
@@ -2305,6 +2324,10 @@ struct ChatCompletionRequest: Decodable {
     let chatTemplateKwargs: [String: Bool]?
     /// Top-level thinking override emitted by Aegis-AI gateway
     let enableThinking: Bool?
+    /// Number of bits for native MLX quantized KV cache (nil = no quantization).
+    /// Only 4 and 8 are supported by the underlying MLX QuantizedKVCache.
+    /// Enables `QuantizedKVCache` instead of `KVCacheSimple`. Separate from `--turbo-kv`.
+    let kvBits: Int?
 
     enum CodingKeys: String, CodingKey {
         case model, messages, stream, temperature, tools, stop, seed
@@ -2319,6 +2342,7 @@ struct ChatCompletionRequest: Decodable {
         case responseFormat = "response_format"
         case chatTemplateKwargs = "chat_template_kwargs"
         case enableThinking = "enable_thinking"
+        case kvBits = "kv_bits"
     }
 }
 
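
The decoding, validation, and prompt-cache rules added to `Server.swift` above can be exercised in isolation. The following is a minimal sketch, not the server's actual code: `MiniChatRequest`, `isValidKVBits`, and `shouldSkipPromptCache` are hypothetical names introduced here, and only the `kv_bits` key mapping and the accepted values (nil, 4, 8) are taken from the diff.

```swift
import Foundation

// Hypothetical stand-in for the relevant slice of ChatCompletionRequest.
struct MiniChatRequest: Decodable {
    let model: String
    let kvBits: Int?

    enum CodingKeys: String, CodingKey {
        case model
        case kvBits = "kv_bits" // same snake_case mapping as in the diff
    }
}

// Mirrors the validation rule in the diff: nil, 4, and 8 are accepted.
func isValidKVBits(_ kvBits: Int?) -> Bool {
    guard let kb = kvBits else { return true }
    return kb == 4 || kb == 8
}

// Mirrors the skip condition in the diff: multimodal requests and
// quantized-KV requests both bypass the prompt cache.
func shouldSkipPromptCache(isMultimodal: Bool, kvBits: Int?) -> Bool {
    return isMultimodal || kvBits != nil
}

let body = Data(#"{"model": "gemma-4-26b-a4b-it-4bit", "kv_bits": 4}"#.utf8)
let request = try JSONDecoder().decode(MiniChatRequest.self, from: body)

print(isValidKVBits(request.kvBits))                                      // true
print(shouldSkipPromptCache(isMultimodal: false, kvBits: request.kvBits)) // true
```

One consequence worth noting: because any non-nil `kv_bits` skips both the restore and the save path, requests that set it do not benefit from prompt-cache reuse, so every such request pays the full prefill cost.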

mlx-swift-lm
