
perf(q6_k): SIMD-fy dequantQ6_KBlock via ByteVector ql + qh extraction#564

Merged
michalharakal merged 1 commit into develop from feature/jvm-q6k-simd-dequant on Apr 28, 2026
Conversation

@michalharakal
Contributor

Summary

Replaces the scalar 32-iteration inner loop in dequantQ6_KBlock — the dequant step under matmulQ6_KVec — with a fused ByteVector pipeline. Per floatStep-wide chunk of l:

  1. Load three byte slices: ql[qlBase + l], ql[qlBase + l + 32], qh[qhBase + l]
  2. Assemble four 6-bit codes per lane:
    • q1 = (ql0 & 0x0F) | ((qh & 0x03) << 4) − 32
    • q2 = (ql32 & 0x0F) | ((qh >>> 2 & 0x03) << 4) − 32
    • q3 = (ql0 >>> 4) | ((qh >>> 4 & 0x03) << 4) − 32
    • q4 = (ql32 >>> 4) | ((qh >>> 6 & 0x03) << 4) − 32
  3. Widen to FloatVector, multiply by per-sub-block d·scale, store to four 32-element regions of the scratch FloatArray.
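Steps 1–2 can be sketched as scalar Java, following the canonical ggml Q6_K layout. This is an illustrative reference for the per-lane arithmetic the ByteVector pipeline performs lane-parallel; the method name, array parameters, and the `sc` scale indexing are assumptions based on ggml's dequantize_row_q6_K, not the actual kernel source:

```java
import java.util.Arrays;

public class Q6KScalarSketch {
    // Scalar reference for one 128-value chunk of a Q6_K block (names and
    // scale indexing are assumptions, not the JvmQuantizedVectorKernels code).
    // ql: 64 low-nibble bytes, qh: 32 high-bit bytes, sc: 8 signed scales
    // for this chunk, d: block super-scale, out: 128 dequantized floats.
    static void dequantQ6KChunk(byte[] ql, byte[] qh, byte[] sc, float d, float[] out) {
        for (int l = 0; l < 32; l++) {
            int is = l / 16;                         // sub-block selector: 0 or 1
            int lo0  = ql[l]      & 0xFF;            // zero-extend signed bytes
            int lo32 = ql[l + 32] & 0xFF;
            int hi   = qh[l]      & 0xFF;
            int q1 = ((lo0  & 0x0F) | ((hi         & 0x03) << 4)) - 32;
            int q2 = ((lo32 & 0x0F) | (((hi >>> 2) & 0x03) << 4)) - 32;
            int q3 = ((lo0  >>> 4)  | (((hi >>> 4) & 0x03) << 4)) - 32;
            int q4 = ((lo32 >>> 4)  | (((hi >>> 6) & 0x03) << 4)) - 32;
            out[l]      = d * sc[is]     * q1;       // four 32-element regions
            out[l + 32] = d * sc[is + 2] * q2;
            out[l + 64] = d * sc[is + 4] * q3;
            out[l + 96] = d * sc[is + 6] * q4;
        }
    }

    public static void main(String[] args) {
        byte[] ql = new byte[64], qh = new byte[32], sc = new byte[8];
        Arrays.fill(sc, (byte) 1);
        float[] out = new float[128];
        dequantQ6KChunk(ql, qh, sc, 1.0f, out);      // all-zero codes -> -32
        System.out.println(out[0] + " " + out[127]); // -32.0 -32.0
    }
}
```

Note the `& 0xFF` zero-extensions: JVM bytes are signed, which is exactly the wrinkle the ByteVector AND/LSHR lanewise ops sidestep by operating on lane bits directly.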

Inline replacement — matmulQ6_KVec's outer structure (scratch + SIMD dot product) is unchanged.
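That unchanged outer shape can be pictured as follows. This is a minimal sketch with illustrative names only: the real matmulQ6_KVec dequantizes packed Q6_K bytes (the step this PR vectorizes) and uses a SIMD dot product, whereas here a plain float copy and scalar loop stand in:

```java
public class ScratchDotSketch {
    static final int QK_K = 256;                     // values per Q6_K super-block

    // Minimal sketch of the scratch + dot-product outer structure.
    // The arraycopy stands in for the per-block dequant step; the inner
    // loop stands in for the SIMD dot product. Illustrative only.
    static float dotOverBlocks(float[][] blocks, float[] x) {
        float[] scratch = new float[QK_K];           // reused for every block
        float acc = 0f;
        for (int b = 0; b < blocks.length; b++) {
            System.arraycopy(blocks[b], 0, scratch, 0, QK_K); // "dequant" stub
            for (int i = 0; i < QK_K; i++) {
                acc += scratch[i] * x[b * QK_K + i];
            }
        }
        return acc;
    }

    public static void main(String[] args) {
        float[][] blocks = new float[2][QK_K];
        float[] x = new float[2 * QK_K];
        for (float[] blk : blocks) java.util.Arrays.fill(blk, 1f);
        java.util.Arrays.fill(x, 1f);
        System.out.println(dotOverBlocks(blocks, x)); // 512.0
    }
}
```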

Why this matters

Q6_K is the dominant format for the embedding, lm_head, and FFN matrices of Gemma 4 E2B Q4_K_M. The scalar dequant was the hot path inside the per-cell, per-block loop under matmulQ6_KVec. Vectorizing it brings Q6_K up to the SIMD pipeline standard set by Q4_K (#562/#563) and Q8_0 (already SIMD).

Scope

  • Inline replacement in JvmQuantizedVectorKernels.dequantQ6_KBlock. No new SPI surface (full Q6KMatmulKernel SPI is a fair follow-up if a native FFM provider needs to register here).
  • matmulQ6_KVec outer loop / parallelism / scratch-allocation strategy unchanged.

Out of scope (next M5 follow-ups)

  • Full fused Q6_K matmul (eliminate scratch FloatArray entirely) — more involved because sub-block scales differ across l ∈ 0..15 vs 16..31, requiring 16 parallel accumulators per output cell per block.
  • Q6KMatmulKernel sibling SPI (similar shape to Q4KMatmulKernel from feat(kernel): SIMD-fused Q4_K matmul kernel + Q4KMatmulKernel SPI #562).
  • Q4_0 SIMD — interleaved nibble layout makes a clean SIMD pipeline harder than Q4_K/Q6_K.
  • Native (FFM) Q4_K / Q6_K kernel — priority 100, calls hand-tuned NEON/AVX2 via MemorySegment. Closes M5.
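For context on the first follow-up above: in dequantized order, each of a block's 16 sub-block scales multiplies a disjoint run of 16 values, so a fused kernel that skips the scratch must defer the d·scale multiply by keeping one partial sum per scale. A hedged sketch of that accumulator layout (all names illustrative, operating on already-extracted codes):

```java
public class FusedQ6KDotSketch {
    // Why a fused Q6_K matmul needs 16 accumulators per output cell per
    // block: in dequantized order scale s applies to values [16s, 16s+16),
    // so dot = d * sum_s( sc[s] * partial[s] ). All names are illustrative.
    static float fusedBlockDot(int[] codes, byte[] sc, float d, float[] x) {
        float[] partial = new float[16];         // one accumulator per scale
        for (int j = 0; j < 256; j++) {
            partial[j / 16] += codes[j] * x[j];  // scale multiply deferred
        }
        float dot = 0f;
        for (int s = 0; s < 16; s++) {
            dot += sc[s] * partial[s];
        }
        return d * dot;
    }

    public static void main(String[] args) {
        int[] codes = new int[256];
        java.util.Arrays.fill(codes, 1);
        byte[] sc = new byte[16];
        java.util.Arrays.fill(sc, (byte) 1);
        float[] x = new float[256];
        java.util.Arrays.fill(x, 1f);
        System.out.println(fusedBlockDot(codes, sc, 1.0f, x)); // 256.0
    }
}
```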

Test plan

  • ./gradlew :skainet-backends:skainet-backend-cpu:jvmTest — 218/218 tests pass, including Q6KMatmulTest's parity vs the canonical ggml DequantOps.dequantQ6KFromBytes reference (covers single output row + multi-row × multi-block within 1e-4 relative tolerance).

🤖 Generated with Claude Code

Replaces the scalar 32-iteration inner loop in dequantQ6_KBlock — the
hot path under matmulQ6_KVec — with a fused ByteVector pipeline. Per
floatStep-wide chunk of l: load slices of ql[qlBase+l],
ql[qlBase+l+32], qh[qhBase+l]; assemble q1..q4 = (qlNibble) |
((qhSlice) << 4) − 32 via byte AND/LSHR/OR ops; widen to FloatVector;
multiply by per-sub-block d·scale; store to four 32-element regions
of the scratch FloatArray.

Inline replacement; doesn't change matmulQ6_KVec's outer structure
(scratch FloatArray + SIMD dot product remain). Future: full fused
matmul (no scratch) is a fair follow-up but more involved because
sub-block scales differ across l ∈ 0..15 vs 16..31, requiring 16
parallel accumulators per output cell per block.

Q6_K is the dominant format for Gemma 4 E2B Q4_K_M's embedding,
lm_head, and FFN matrices, so the dequant cost is non-trivial in
real LLM decode.

Tests: cpu jvmTest 218/218 pass, including Q6KMatmulTest's parity vs
the canonical ggml dequant reference (DequantOps.dequantQ6KFromBytes).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal marked this pull request as ready for review April 28, 2026 20:57
@michalharakal michalharakal merged commit 00b80c0 into develop Apr 28, 2026
6 checks passed
@michalharakal michalharakal deleted the feature/jvm-q6k-simd-dequant branch April 28, 2026 20:57
@michalharakal michalharakal mentioned this pull request Apr 28, 2026