Skip to content

Commit 2bae056

Browse files
committed
Merge remote-tracking branch 'origin/develop' into feature/gemma4
# Conflicts: # gradle/libs.versions.toml
2 parents 6d022d8 + 3c19515 commit 2bae056

5 files changed

Lines changed: 193 additions & 14 deletions

File tree

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
antora-playbook.yml
4848
4949
- name: Upload artifact
50-
uses: actions/upload-pages-artifact@v3
50+
uses: actions/upload-pages-artifact@v5
5151
with:
5252
path: docs/build/site
5353

gradle/libs.versions.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
skainet = "0.20.0"
33
agp = "9.2.0"
44
jacksonDatabind = "2.21.2"
5-
jsonSchemaValidator = "3.0.1"
5+
jsonSchemaValidator = "3.0.2"
66
jsonSchemaValidatorVersion = "0.5.4"
77
junit = "4.13.2"
88
junitJupiter = "6.0.3"
9-
kotlin = "2.3.20"
9+
kotlin = "2.3.21"
1010
kotlinxCoroutines = "1.10.2"
1111
kotlinBrowser = "0.5.0"
1212
android-minSdk = "24"
1313
android-compileSdk = "36"
1414
kotlinxSerializationJson = "1.11.0"
15-
ktorClientCore = "3.4.2"
15+
ktorClientCore = "3.4.3"
1616
ktorClientPlugins = "3.1.1"
1717
logbackClassic = "1.5.32"
1818
kover = "0.9.8"

llm-core/src/commonMain/kotlin/sk/ainet/apps/llm/RopeUtils.kt

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,46 @@ import kotlin.math.cos
44
import kotlin.math.pow
55
import kotlin.math.sin
66

7+
/**
 * RoPE rotation conventions used by GGUF-based models.
 *
 * Both conventions apply the same per-pair rotation and differ only in which two
 * elements of a head form a pair. They produce identical attention results only
 * when the weights are stored in the matching layout; mismatching them silently
 * corrupts attention.
 *
 * | Convention    | llama.cpp name                  | Pair indexing                | Used by                      |
 * |---------------|---------------------------------|------------------------------|------------------------------|
 * | [INTERLEAVED] | `LLAMA_ROPE_TYPE_NORM` (mode 0) | `(buf[2i], buf[2i+1])`       | LLaMA, Mistral               |
 * | [HALF_SPLIT]  | `LLAMA_ROPE_TYPE_NEOX` (mode 2) | `(buf[i], buf[i+ropeDim/2])` | Qwen 2/3, Phi, Falcon, Gemma |
 *
 * llama.cpp picks the convention per architecture (`llama_model_rope_type()` in
 * `src/llama-model.cpp` at the time of writing). [forArchitecture] mirrors that
 * mapping for [CpuAttentionBackend] (and any other backend) so that GGUF tensors
 * load as-is — no weight permutation at conversion time.
 */
public enum class RopeType {
    /** Interleaved adjacent-pair rotation (llama.cpp NORM, mode 0). */
    INTERLEAVED,

    /** Half-split rotation (llama.cpp NEOX, mode 2) — first half rotates with second half. */
    HALF_SPLIT;

    public companion object {
        /**
         * Lower-cased GGUF architecture names that use [HALF_SPLIT].
         *
         * The Gemma family and the MoE siblings of Qwen / Phi are listed because
         * llama.cpp classifies them as NEOX; leaving them out would route those
         * models to the [INTERLEAVED] default and rotate the wrong element pairs.
         *
         * NOTE(review): "mpt" is kept for backward compatibility; llama.cpp appears
         * to treat MPT as a no-RoPE (ALiBi) architecture — confirm before relying on it.
         */
        private val HALF_SPLIT_ARCHITECTURES: Set<String> = setOf(
            "qwen2", "qwen2moe", "qwen3", "qwen3moe", "qwen35",
            "phi2", "phi3", "phi4", "phimoe",
            "gemma", "gemma2", "gemma3", "gemma3n",
            "falcon", "mpt", "stablelm", "starcoder2",
        )

        /**
         * Map a GGUF `general.architecture` string to the RoPE convention the
         * model was trained / converted with. Mirrors the per-architecture
         * switch in llama.cpp (`llama_model_rope_type()`).
         *
         * The lookup is case-insensitive. Unknown architectures default to
         * [INTERLEAVED] (the LLaMA-family default); new families that need
         * [HALF_SPLIT] must be added to the list explicitly.
         *
         * @param arch value of the GGUF `general.architecture` key, in any letter case.
         * @return the pair layout to pass to the RoPE rotation for that architecture.
         */
        public fun forArchitecture(arch: String): RopeType =
            if (arch.lowercase() in HALF_SPLIT_ARCHITECTURES) HALF_SPLIT else INTERLEAVED
    }
}
46+
747
/**
848
* Compute RoPE (Rotary Position Embedding) frequency for a given pair index and position.
949
*
@@ -63,15 +103,16 @@ public fun applyRopeRotation(
63103
precomputedCos: FloatArray? = null,
64104
precomputedSin: FloatArray? = null,
65105
ropeStride: Int = ropeDim / 2,
66-
precomputedMatchBase: Float? = null
106+
precomputedMatchBase: Float? = null,
107+
ropeType: RopeType = RopeType.INTERLEAVED
67108
) {
68109
val usePrecomputed = precomputedCos != null && precomputedSin != null &&
69110
(precomputedMatchBase == null || base == precomputedMatchBase)
111+
val halfDim = ropeDim / 2
70112

71113
for (h in 0 until nHeads) {
72114
val headOffset = h * headSize
73-
for (pair in 0 until ropeDim / 2) {
74-
val i = pair * 2
115+
for (pair in 0 until halfDim) {
75116
val fcr: Float
76117
val fci: Float
77118
if (usePrecomputed) {
@@ -81,10 +122,16 @@ public fun applyRopeRotation(
81122
fcr = ropeCos(pair, pos, ropeDim, base)
82123
fci = ropeSin(pair, pos, ropeDim, base)
83124
}
84-
val v0 = buf[headOffset + i]
85-
val v1 = buf[headOffset + i + 1]
86-
buf[headOffset + i] = v0 * fcr - v1 * fci
87-
buf[headOffset + i + 1] = v0 * fci + v1 * fcr
125+
// INTERLEAVED rotates (2i, 2i+1) — adjacent pairs (Llama / Gemma / Mistral).
126+
// HALF_SPLIT rotates (i, i + ropeDim/2) — first-half / second-half (Qwen / Phi / Falcon).
127+
val (idxA, idxB) = when (ropeType) {
128+
RopeType.INTERLEAVED -> headOffset + pair * 2 to headOffset + pair * 2 + 1
129+
RopeType.HALF_SPLIT -> headOffset + pair to headOffset + pair + halfDim
130+
}
131+
val v0 = buf[idxA]
132+
val v1 = buf[idxB]
133+
buf[idxA] = v0 * fcr - v1 * fci
134+
buf[idxB] = v0 * fci + v1 * fcr
88135
}
89136
}
90137
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package sk.ainet.apps.llm
2+
3+
import kotlin.math.abs
4+
import kotlin.math.cos
5+
import kotlin.math.sin
6+
import kotlin.test.Test
7+
import kotlin.test.assertEquals
8+
import kotlin.test.assertTrue
9+
10+
/**
 * Exercises [applyRopeRotation] under both pair layouts emitted by GGUF tooling.
 *
 * The Llama family is INTERLEAVED (`(buf[2i], buf[2i+1])`); Qwen 2/3 / Phi /
 * Falcon are HALF_SPLIT (`(buf[i], buf[i+ropeDim/2])`). Using the layout that
 * does not match how the model was trained silently corrupts attention —
 * see ISSUE-74 (Qwen3 runtime degenerate output).
 */
class RopeUtilsTest {

    /** Fails unless [actual] lies strictly within [tol] (absolute) of [expected]. */
    private fun assertCloseTo(expected: Float, actual: Float, tol: Float = 1e-5f) {
        val delta = abs(expected - actual)
        assertTrue(delta < tol, "expected $expected, got $actual (delta $delta)")
    }

    /** First component of the pair `(x, y)` after rotating it by [theta]. */
    private fun rotatedFirst(x: Float, y: Float, theta: Float): Float = x * cos(theta) - y * sin(theta)

    /** Second component of the pair `(x, y)` after rotating it by [theta]. */
    private fun rotatedSecond(x: Float, y: Float, theta: Float): Float = x * sin(theta) + y * cos(theta)

    @Test
    fun interleavedRotationMatchesAdjacentPairFormula() {
        // One head with headSize == ropeDim == 4, so the two pairs sit at (0,1) and (2,3).
        val ropeDim = 4
        val pos = 1
        val base = 10000f
        val buf = floatArrayOf(1f, 2f, 3f, 4f)

        applyRopeRotation(buf, nHeads = 1, headSize = ropeDim, ropeDim = ropeDim, pos = pos, base = base)

        // Reference angles: pair 0 uses exponent 0, pair 1 uses exponent 2*1/4 = 0.5.
        val theta0 = pos / base.toDouble().pow(0.0).toFloat()
        val theta1 = pos / base.toDouble().pow(0.5).toFloat()
        val expected = floatArrayOf(
            rotatedFirst(1f, 2f, theta0),
            rotatedSecond(1f, 2f, theta0),
            rotatedFirst(3f, 4f, theta1),
            rotatedSecond(3f, 4f, theta1),
        )
        expected.forEachIndexed { i, want -> assertCloseTo(want, buf[i]) }
    }

    @Test
    fun halfSplitRotationPairsFirstHalfWithSecondHalf() {
        // One head with headSize == ropeDim == 4: the pairs are (b[0], b[2]) and (b[1], b[3]),
        // unlike the interleaved layout, which would pair (b[0], b[1]) and (b[2], b[3]).
        val ropeDim = 4
        val pos = 1
        val base = 10000f
        val buf = floatArrayOf(1f, 2f, 3f, 4f)

        applyRopeRotation(
            buf, nHeads = 1, headSize = ropeDim, ropeDim = ropeDim,
            pos = pos, base = base, ropeType = RopeType.HALF_SPLIT
        )

        val theta0 = pos / base.toDouble().pow(0.0).toFloat()
        val theta1 = pos / base.toDouble().pow(0.5).toFloat()
        // Pair 0 is (1, 3) rotated by theta0 and written back to b[0] / b[2];
        // pair 1 is (2, 4) rotated by theta1 and written back to b[1] / b[3].
        val expected = floatArrayOf(
            rotatedFirst(1f, 3f, theta0),  // b[0]
            rotatedFirst(2f, 4f, theta1),  // b[1]
            rotatedSecond(1f, 3f, theta0), // b[2]
            rotatedSecond(2f, 4f, theta1), // b[3]
        )
        expected.forEachIndexed { i, want -> assertCloseTo(want, buf[i]) }
    }

    @Test
    fun halfSplitDoesNotEqualInterleavedForSameInput() {
        // Regression guard: the two layouts must yield *different* outputs for the
        // same input, otherwise a wiring mistake would go unnoticed.
        val pos = 3
        val a = floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f)
        val b = a.copyOf()

        applyRopeRotation(a, nHeads = 1, headSize = 8, ropeDim = 8, pos = pos, base = 10000f, ropeType = RopeType.INTERLEAVED)
        applyRopeRotation(b, nHeads = 1, headSize = 8, ropeDim = 8, pos = pos, base = 10000f, ropeType = RopeType.HALF_SPLIT)

        val anyDiff = a.indices.any { i -> abs(a[i] - b[i]) > 1e-4f }
        assertTrue(anyDiff, "INTERLEAVED and HALF_SPLIT must produce different rotations: a=${a.toList()} b=${b.toList()}")
    }

    @Test
    fun defaultRopeTypeIsInterleavedForBackwardsCompat() {
        // Call sites that omit ropeType (the existing Llama path) must keep the
        // pre-fix behavior, i.e. the interleaved rotation.
        val withDefault = floatArrayOf(0.5f, 1.5f, 2.5f, 3.5f)
        val explicitInterleaved = withDefault.copyOf()

        applyRopeRotation(withDefault, nHeads = 1, headSize = 4, ropeDim = 4, pos = 2, base = 10000f)
        applyRopeRotation(
            explicitInterleaved, nHeads = 1, headSize = 4, ropeDim = 4,
            pos = 2, base = 10000f, ropeType = RopeType.INTERLEAVED
        )

        withDefault.forEachIndexed { i, value -> assertEquals(explicitInterleaved[i], value) }
    }

    @Test
    fun halfSplitMultipleHeadsAreIndependent() {
        // The half-split partitioning happens inside each head, never across heads.
        val pos = 2
        val buf = floatArrayOf(
            1f, 2f, 3f, 4f, // head 0
            5f, 6f, 7f, 8f  // head 1
        )

        applyRopeRotation(
            buf, nHeads = 2, headSize = 4, ropeDim = 4,
            pos = pos, base = 10000f, ropeType = RopeType.HALF_SPLIT
        )

        val theta0 = pos / 1f
        val theta1 = pos / 100f
        val expected = floatArrayOf(
            // Head 0: (1,3) rotated by theta0, (2,4) rotated by theta1.
            rotatedFirst(1f, 3f, theta0),
            rotatedFirst(2f, 4f, theta1),
            rotatedSecond(1f, 3f, theta0),
            rotatedSecond(2f, 4f, theta1),
            // Head 1: (5,7) rotated by theta0, (6,8) rotated by theta1 — same angles, different inputs.
            rotatedFirst(5f, 7f, theta0),
            rotatedFirst(6f, 8f, theta1),
            rotatedSecond(5f, 7f, theta0),
            rotatedSecond(6f, 8f, theta1),
        )
        expected.forEachIndexed { i, want -> assertCloseTo(want, buf[i]) }
    }
}
129+
130+
/**
 * Raises the receiver to the power [exp], computed as `e^(exp * ln(this))`.
 *
 * Test-local helper. Only meaningful for a positive receiver, because `ln` of a
 * non-positive value is not a finite number.
 */
private fun Double.pow(exp: Double): Double {
    val logOfReceiver = kotlin.math.ln(this)
    return kotlin.math.exp(exp * logOfReceiver)
}

llm-runtime/kllama/src/commonMain/kotlin/sk/ainet/apps/kllama/CpuAttentionBackend.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sk.ainet.apps.kllama
33
import kotlin.math.sqrt
44
import sk.ainet.apps.llm.KvCache
55
import sk.ainet.apps.llm.HeapKvCache
6+
import sk.ainet.apps.llm.RopeType
67
import sk.ainet.apps.llm.applyRopeRotation
78
import sk.ainet.apps.llm.softmaxInPlace
89
import sk.ainet.context.ExecutionContext
@@ -31,7 +32,8 @@ public class CpuAttentionBackend<T : DType>(
3132
private val dtype: KClass<T>,
3233
kvCache: KvCache? = null,
3334
private val ropeFreqBase: Float = 10000f,
34-
maxContextLength: Int? = null
35+
maxContextLength: Int? = null,
36+
private val ropeType: RopeType = RopeType.forArchitecture(weights.metadata.architecture)
3537
) : AttentionBackend<T> {
3638

3739
private val dim = weights.metadata.embeddingLength
@@ -113,8 +115,8 @@ public class CpuAttentionBackend<T : DType>(
113115

114116
require(headSize % 2 == 0) { "RoPE requires even head size; got $headSize" }
115117

116-
applyRopeRotation(qBuf, nHeads, headSize, ropeDim, pos, ropeFreqBase, ropeReal, ropeImag, ropeStride)
117-
applyRopeRotation(kBuf, nKvHeads, headSize, ropeDim, pos, ropeFreqBase, ropeReal, ropeImag, ropeStride)
118+
applyRopeRotation(qBuf, nHeads, headSize, ropeDim, pos, ropeFreqBase, ropeReal, ropeImag, ropeStride, ropeType = ropeType)
119+
applyRopeRotation(kBuf, nKvHeads, headSize, ropeDim, pos, ropeFreqBase, ropeReal, ropeImag, ropeStride, ropeType = ropeType)
118120
}
119121

120122
private fun attentionGqa(layerIdx: Int, qBuf: FloatArray, pos: Int): FloatArray {

0 commit comments

Comments
 (0)