Skip to content

Commit d74517e

Browse files
michalharakalclaude
andcommitted
Add BitNet/Ternary quantization support (TQ1_0/TQ2_0)
- Add dequantTQ1_0 and dequantTQ2_0 functions in LlamaWeightLoader - Create TernaryTensorData interface and Ternary2BitTensorData implementation - Implement addition-only TernaryMatmul kernel (no FP multiplies) - Add comprehensive unit tests for dequantization and ternary ops - Update quant_format.md with TQ documentation - Mark BitNet support as implemented in roadmap The ternary matmul enables efficient inference for BitNet-style models where weights are constrained to {-1, 0, +1}, replacing multiplications with conditional additions/subtractions. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c672829 commit d74517e

8 files changed

Lines changed: 1355 additions & 4 deletions

File tree

kllama-enterprise.md

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ KLlama is a Kotlin Multiplatform LLM inference runtime. This document outlines t
2929
| iOS/Android native ||| Via bindings |
3030
| Browser (Wasm) ||| Via bindings |
3131
| Quantized inference | 🚧 Planned |||
32-
| **BitNet/Ternary native** | 🚧 Planned || Partial |
32+
| **BitNet/Ternary native** | ✅ TQ1_0/TQ2_0 dequant + ternary matmul || Partial |
3333
| SIMD optimization | Partial |||
3434
| Memory-mapped I/O | ✅ (JVM) |||
3535
| Multiple architectures ||||
@@ -467,7 +467,7 @@ class MappedGGUFReader(path: Path) {
467467

468468
**Impact**: Enable 7B, 13B, 70B models without OOM
469469

470-
### 1.2 BitNet / Ternary Quantization Support 🆕 HIGH PRIORITY
470+
### 1.2 BitNet / Ternary Quantization Support ✅ IMPLEMENTED
471471

472472
Native support for Microsoft's BitNet 1.58-bit models with ternary weights {-1, 0, +1}.
473473

@@ -477,6 +477,14 @@ Native support for Microsoft's BitNet 1.58-bit models with ternary weights {-1,
477477
- Unique differentiator (most frameworks don't have native ternary kernels)
478478
- We already have `Ternary` DType and `DenseTernaryTensorArray`
479479

480+
**What's Implemented:**
481+
- ✅ TQ1_0 dequantization (base-3 packed ternary format, ~1.69 bpw)
482+
- ✅ TQ2_0 dequantization (2-bit packed ternary format, ~2.06 bpw)
483+
-`Ternary2BitTensorData` - compact storage with TQ format encoding
484+
-`TernaryMatmul.matmul()` - addition-only kernel (no FP multiply)
485+
-`matmulAutoDispatch()` - automatic ternary detection and dispatch
486+
- ✅ Comprehensive unit tests for all components
487+
480488
**Architecture Integration:**
481489
```
482490
┌─────────────────────────────────────────────────────────────────┐
@@ -542,13 +550,18 @@ fun matmulTernarySIMD(input: FloatArray, weights: TernaryTensorData): FloatArray
542550
}
543551
```
544552

545-
**Existing Foundation:**
553+
**Implementation Status:**
546554
| Component | Status | Location |
547555
|-----------|--------|----------|
548556
| `Ternary` DType || `skainet-lang-core/.../types/Ternary.kt` |
549557
| `DenseTernaryTensorArray` || `skainet-lang-core/.../data/dense/` |
550558
| GGUF TQ1_0/TQ2_0 enum || `GGMLQuantizationType` |
551559
| Type promotion || Ternary → Int8 → FP32 |
560+
| `dequantTQ1_0()` || `LlamaWeightLoader.kt` |
561+
| `dequantTQ2_0()` || `LlamaWeightLoader.kt` |
562+
| `Ternary2BitTensorData` || `skainet-lang-core/.../data/TernaryTensorData.kt` |
563+
| `TernaryMatmul` || `skainet-lang-core/.../ops/TernaryMatmul.kt` |
564+
| Unit tests || `LlamaQuantDequantTest`, `TernaryTensorDataTest`, `TernaryMatmulTest` |
552565

553566
**Impact**:
554567
- **Speed**: 5-10x faster than FP32 (no FP multiply, integer add only)

quant_format.md

Lines changed: 378 additions & 0 deletions
Large diffs are not rendered by default.

skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/llama/LlamaWeightLoader.kt

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,123 @@ public class LlamaWeightLoader(
597597
}
598598
return out
599599
}
600+
601+
/**
602+
* Dequantize TQ2_0 (Ternary 2-bit) format to FP32.
603+
*
604+
* TQ2_0 layout per block (256 elements, 66 bytes):
605+
* - 64 bytes: quantized data (4 ternary values per byte, 2-bit each)
606+
* - 2 bytes: f16 scale
607+
*
608+
* Values encoded as {0, 1, 2} represent {-1, 0, +1}.
609+
* Dequantization: output[i] = (ternary[i] - 1) * scale
610+
*/
611+
internal fun dequantTQ2_0(raw: List<Any>, nElems: Int): FloatArray {
612+
val bytes = toByteArray(raw, "TQ2_0")
613+
val blockSize = 256
614+
val bytesPerBlock = 66 // 64 (qs) + 2 (f16 scale)
615+
val blockCount = bytes.size / bytesPerBlock
616+
val out = FloatArray(blockCount * blockSize)
617+
var offset = 0
618+
var outOff = 0
619+
620+
repeat(blockCount) {
621+
// Read quantized values first (64 bytes = 256 values at 2-bit each)
622+
val qs = bytes.copyOfRange(offset, offset + 64)
623+
offset += 64
624+
625+
// Read f16 scale (last 2 bytes)
626+
val scale = halfToFloat(
627+
(bytes[offset + 1].toInt() and 0xFF shl 8) or (bytes[offset].toInt() and 0xFF)
628+
)
629+
offset += 2
630+
631+
// Decode 2-bit values: 4 values per byte
632+
// Bit layout: [v3:v2:v1:v0] where each vN is 2 bits
633+
for (i in 0 until 64) {
634+
val b = qs[i].toInt() and 0xFF
635+
val v0 = (b and 0x03) - 1 // bits 0-1
636+
val v1 = ((b shr 2) and 0x03) - 1 // bits 2-3
637+
val v2 = ((b shr 4) and 0x03) - 1 // bits 4-5
638+
val v3 = ((b shr 6) and 0x03) - 1 // bits 6-7
639+
640+
out[outOff + i * 4 + 0] = v0 * scale
641+
out[outOff + i * 4 + 1] = v1 * scale
642+
out[outOff + i * 4 + 2] = v2 * scale
643+
out[outOff + i * 4 + 3] = v3 * scale
644+
}
645+
outOff += blockSize
646+
}
647+
return out
648+
}
649+
650+
/**
651+
* Dequantize TQ1_0 (Ternary base-3) format to FP32.
652+
*
653+
* TQ1_0 layout per block (256 elements, 54 bytes):
654+
* - 48 bytes: base-3 packed data (5 values per byte, 240 elements total)
655+
* - 4 bytes: 2-bit packed for remaining 16 elements
656+
* - 2 bytes: f16 scale
657+
*
658+
* Base-3 encoding: 5 ternary values packed into one byte (3^5 = 243 < 256).
659+
* Values {0, 1, 2} represent {-1, 0, +1}.
660+
* Dequantization: output[i] = (ternary[i] - 1) * scale
661+
*/
662+
internal fun dequantTQ1_0(raw: List<Any>, nElems: Int): FloatArray {
663+
val bytes = toByteArray(raw, "TQ1_0")
664+
val blockSize = 256
665+
val bytesPerBlock = 54 // 48 (base-3) + 4 (2-bit) + 2 (f16 scale)
666+
val blockCount = bytes.size / bytesPerBlock
667+
val out = FloatArray(blockCount * blockSize)
668+
var offset = 0
669+
var outOff = 0
670+
671+
repeat(blockCount) {
672+
// Read base-3 packed data (48 bytes = 240 elements)
673+
val qsBase3 = bytes.copyOfRange(offset, offset + 48)
674+
offset += 48
675+
676+
// Read 2-bit packed data for remaining 16 elements (4 bytes)
677+
val qs2bit = bytes.copyOfRange(offset, offset + 4)
678+
offset += 4
679+
680+
// Read f16 scale
681+
val scale = halfToFloat(
682+
(bytes[offset + 1].toInt() and 0xFF shl 8) or (bytes[offset].toInt() and 0xFF)
683+
)
684+
offset += 2
685+
686+
// Decode base-3 packed values (5 values per byte)
687+
// Each byte b encodes: v0 + v1*3 + v2*9 + v3*27 + v4*81
688+
var outIdx = 0
689+
for (i in 0 until 48) {
690+
var b = qsBase3[i].toInt() and 0xFF
691+
repeat(5) {
692+
val v = (b % 3) - 1 // Extract value and convert to {-1, 0, +1}
693+
out[outOff + outIdx] = v * scale
694+
outIdx++
695+
b /= 3
696+
}
697+
}
698+
699+
// Decode remaining 16 elements from 2-bit packing (4 bytes)
700+
for (i in 0 until 4) {
701+
val b = qs2bit[i].toInt() and 0xFF
702+
val v0 = (b and 0x03) - 1
703+
val v1 = ((b shr 2) and 0x03) - 1
704+
val v2 = ((b shr 4) and 0x03) - 1
705+
val v3 = ((b shr 6) and 0x03) - 1
706+
707+
out[outOff + 240 + i * 4 + 0] = v0 * scale
708+
out[outOff + 240 + i * 4 + 1] = v1 * scale
709+
out[outOff + 240 + i * 4 + 2] = v2 * scale
710+
out[outOff + 240 + i * 4 + 3] = v3 * scale
711+
}
712+
713+
outOff += blockSize
714+
}
715+
return out
716+
}
600717
}
601718

602719
/**
@@ -922,7 +1039,9 @@ public class LlamaWeightLoader(
9221039
GGMLQuantizationType.Q6_K,
9231040
GGMLQuantizationType.Q8_K,
9241041
GGMLQuantizationType.IQ4_NL,
925-
GGMLQuantizationType.IQ4_XS -> {
1042+
GGMLQuantizationType.IQ4_XS,
1043+
GGMLQuantizationType.TQ1_0,
1044+
GGMLQuantizationType.TQ2_0 -> {
9261045
when (quantPolicy) {
9271046
QuantPolicy.RAW_BYTES -> {
9281047
require(dtype == Int8::class) {
@@ -954,6 +1073,8 @@ public class LlamaWeightLoader(
9541073
GGMLQuantizationType.Q8_K -> dequantQ8K(raw, rt.nElements)
9551074
GGMLQuantizationType.IQ4_NL -> dequantIQ4NL(raw, rt.nElements)
9561075
GGMLQuantizationType.IQ4_XS -> dequantIQ4XS(raw, rt.nElements)
1076+
GGMLQuantizationType.TQ1_0 -> dequantTQ1_0(raw, rt.nElements)
1077+
GGMLQuantizationType.TQ2_0 -> dequantTQ2_0(raw, rt.nElements)
9571078
else -> error("Dequantization for ${rt.tensorType} not implemented yet")
9581079
}
9591080
@Suppress("UNCHECKED_CAST")

skainet-io/skainet-io-gguf/src/jvmTest/kotlin/sk/ainet/io/gguf/llama/LlamaQuantDequantTest.kt

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,124 @@ class LlamaQuantDequantTest {
190190
}
191191
assertContentEquals(expected.toList(), out.toList())
192192
}
193+
194+
@Test
195+
fun `dequant TQ2_0 block with scale 1 and all zeros yields minus ones`() {
196+
// TQ2_0: 66 bytes = 64 data + 2 f16 scale
197+
// All data bytes = 0x00 -> each 2-bit value is 0 -> (0-1) = -1
198+
// Scale = 1.0 (0x3C00)
199+
val raw = ByteArray(66) { 0x00 }
200+
raw[64] = 0x00 // scale low byte
201+
raw[65] = 0x3C // scale high byte (f16 1.0)
202+
val out = LlamaWeightLoader.dequantTQ2_0(raw.toList(), 256)
203+
assertContentEquals(FloatArray(256) { -1f }.toList(), out.toList())
204+
}
205+
206+
@Test
207+
fun `dequant TQ2_0 block with all ones yields zeros`() {
208+
// All data bytes = 0x55 -> each 2-bit value is 1 (01 01 01 01) -> (1-1) = 0
209+
// Scale = 1.0
210+
val raw = ByteArray(66) { 0x55 }
211+
raw[64] = 0x00; raw[65] = 0x3C
212+
val out = LlamaWeightLoader.dequantTQ2_0(raw.toList(), 256)
213+
assertContentEquals(FloatArray(256) { 0f }.toList(), out.toList())
214+
}
215+
216+
@Test
217+
fun `dequant TQ2_0 block with all twos yields plus ones`() {
218+
// All data bytes = 0xAA -> each 2-bit value is 2 (10 10 10 10) -> (2-1) = +1
219+
// Scale = 1.0
220+
val raw = ByteArray(66) { 0xAA.toByte() }
221+
raw[64] = 0x00; raw[65] = 0x3C
222+
val out = LlamaWeightLoader.dequantTQ2_0(raw.toList(), 256)
223+
assertContentEquals(FloatArray(256) { 1f }.toList(), out.toList())
224+
}
225+
226+
@Test
227+
fun `dequant TQ2_0 block applies scale correctly`() {
228+
// All twos (+1) with scale = 2.0 (0x4000)
229+
val raw = ByteArray(66) { 0xAA.toByte() }
230+
raw[64] = 0x00; raw[65] = 0x40 // f16 2.0
231+
val out = LlamaWeightLoader.dequantTQ2_0(raw.toList(), 256)
232+
assertContentEquals(FloatArray(256) { 2f }.toList(), out.toList())
233+
}
234+
235+
@Test
236+
fun `dequant TQ2_0 block with mixed values`() {
237+
// First byte = 0xE4 = 11 10 01 00 in binary
238+
// Values: v0=0 (-1), v1=1 (0), v2=2 (+1), v3=3 -> but 3 is invalid, should be clamped to +2
239+
// Actually TQ2_0 only uses values 0,1,2. If we see 3, (3-1)=2
240+
val raw = ByteArray(66) { 0x55 } // default to zeros
241+
raw[0] = 0xE4.toByte() // 11_10_01_00: v0=-1, v1=0, v2=+1, v3=+2 (if 3 is allowed)
242+
raw[64] = 0x00; raw[65] = 0x3C // scale = 1.0
243+
val out = LlamaWeightLoader.dequantTQ2_0(raw.toList(), 256)
244+
// First 4 elements: (0-1)=-1, (1-1)=0, (2-1)=+1, (3-1)=+2
245+
kotlin.test.assertEquals(-1f, out[0], 0.001f)
246+
kotlin.test.assertEquals(0f, out[1], 0.001f)
247+
kotlin.test.assertEquals(1f, out[2], 0.001f)
248+
kotlin.test.assertEquals(2f, out[3], 0.001f) // 3 encodes as +2 when scaled
249+
}
250+
251+
@Test
252+
fun `dequant TQ1_0 block with all zeros yields minus ones`() {
253+
// TQ1_0: 54 bytes = 48 base-3 + 4 2-bit + 2 f16 scale
254+
// All base-3 bytes = 0 means each decoded value is 0 -> (0-1) = -1
255+
// All 2-bit bytes = 0 means remaining 16 values are also -1
256+
val raw = ByteArray(54) { 0x00 }
257+
raw[52] = 0x00; raw[53] = 0x3C // scale = 1.0
258+
val out = LlamaWeightLoader.dequantTQ1_0(raw.toList(), 256)
259+
assertContentEquals(FloatArray(256) { -1f }.toList(), out.toList())
260+
}
261+
262+
@Test
263+
fun `dequant TQ1_0 block with base3 ones yields zeros`() {
264+
// Base-3 encoding: each byte encodes 5 values as v0 + v1*3 + v2*9 + v3*27 + v4*81
265+
// For all ones: 1 + 3 + 9 + 27 + 81 = 121 (0x79)
266+
// 2-bit packed: 0x55 = 01 01 01 01 = all ones
267+
val raw = ByteArray(54) { 0x00 }
268+
repeat(48) { raw[it] = 0x79 } // base-3 all ones
269+
repeat(4) { raw[48 + it] = 0x55 } // 2-bit all ones
270+
raw[52] = 0x00; raw[53] = 0x3C // scale = 1.0
271+
val out = LlamaWeightLoader.dequantTQ1_0(raw.toList(), 256)
272+
assertContentEquals(FloatArray(256) { 0f }.toList(), out.toList())
273+
}
274+
275+
@Test
276+
fun `dequant TQ1_0 block with base3 twos yields plus ones`() {
277+
// For all twos: 2 + 6 + 18 + 54 + 162 = 242 (0xF2)
278+
// 2-bit packed: 0xAA = 10 10 10 10 = all twos
279+
val raw = ByteArray(54) { 0x00 }
280+
repeat(48) { raw[it] = 0xF2.toByte() } // base-3 all twos
281+
repeat(4) { raw[48 + it] = 0xAA.toByte() } // 2-bit all twos
282+
raw[52] = 0x00; raw[53] = 0x3C // scale = 1.0
283+
val out = LlamaWeightLoader.dequantTQ1_0(raw.toList(), 256)
284+
assertContentEquals(FloatArray(256) { 1f }.toList(), out.toList())
285+
}
286+
287+
@Test
288+
fun `dequant TQ1_0 block applies scale correctly`() {
289+
// All twos with scale = 2.0
290+
val raw = ByteArray(54) { 0x00 }
291+
repeat(48) { raw[it] = 0xF2.toByte() } // base-3 all twos
292+
repeat(4) { raw[48 + it] = 0xAA.toByte() } // 2-bit all twos
293+
raw[52] = 0x00; raw[53] = 0x40 // scale = 2.0
294+
val out = LlamaWeightLoader.dequantTQ1_0(raw.toList(), 256)
295+
assertContentEquals(FloatArray(256) { 2f }.toList(), out.toList())
296+
}
297+
298+
@Test
299+
fun `dequant TQ1_0 base3 decoding for mixed values`() {
300+
// Test decoding first 5 values from one base-3 byte
301+
// Values: 0, 1, 2, 0, 1 -> 0 + 1*3 + 2*9 + 0*27 + 1*81 = 3 + 18 + 81 = 102 (0x66)
302+
val raw = ByteArray(54) { 0x79 } // default all ones
303+
raw[0] = 0x66 // first 5 values: -1, 0, +1, -1, 0
304+
repeat(4) { raw[48 + it] = 0x55 } // 2-bit all ones
305+
raw[52] = 0x00; raw[53] = 0x3C // scale = 1.0
306+
val out = LlamaWeightLoader.dequantTQ1_0(raw.toList(), 256)
307+
kotlin.test.assertEquals(-1f, out[0], 0.001f)
308+
kotlin.test.assertEquals(0f, out[1], 0.001f)
309+
kotlin.test.assertEquals(1f, out[2], 0.001f)
310+
kotlin.test.assertEquals(-1f, out[3], 0.001f)
311+
kotlin.test.assertEquals(0f, out[4], 0.001f)
312+
}
193313
}

0 commit comments

Comments
 (0)