|
| 1 | +package sk.ainet.apps.llm |
| 2 | + |
| 3 | +import kotlin.math.abs |
| 4 | +import kotlin.math.cos |
| 5 | +import kotlin.math.sin |
| 6 | +import kotlin.test.Test |
| 7 | +import kotlin.test.assertEquals |
| 8 | +import kotlin.test.assertTrue |
| 9 | + |
| 10 | +/** |
| 11 | + * RoPE rotation tests covering both conventions emitted by GGUF tooling. |
| 12 | + * |
| 13 | + * Llama family is INTERLEAVED (`(buf[2i], buf[2i+1])`); Qwen 2/3 / Phi / |
| 14 | + * Falcon are HALF_SPLIT (`(buf[i], buf[i+ropeDim/2])`). Mismatching the |
| 15 | + * convention vs. how the model was trained silently corrupts attention — |
| 16 | + * see ISSUE-74 (Qwen3 runtime degenerate output). |
| 17 | + */ |
| 18 | +class RopeUtilsTest { |
| 19 | + |
| 20 | + private fun assertCloseTo(expected: Float, actual: Float, tol: Float = 1e-5f) { |
| 21 | + assertTrue(abs(expected - actual) < tol, |
| 22 | + "expected $expected, got $actual (delta ${abs(expected - actual)})") |
| 23 | + } |
| 24 | + |
| 25 | + @Test |
| 26 | + fun interleavedRotationMatchesAdjacentPairFormula() { |
| 27 | + // Single head, headSize = ropeDim = 4 → 2 pairs at indices (0,1) and (2,3). |
| 28 | + // pos = 1, base = 10000. |
| 29 | + val buf = floatArrayOf(1f, 2f, 3f, 4f) |
| 30 | + val pos = 1 |
| 31 | + val ropeDim = 4 |
| 32 | + val base = 10000f |
| 33 | + applyRopeRotation(buf, nHeads = 1, headSize = ropeDim, ropeDim = ropeDim, pos = pos, base = base) |
| 34 | + |
| 35 | + // Reference: rotate (b[0], b[1]) by freq(pair=0), (b[2], b[3]) by freq(pair=1). |
| 36 | + val f0 = pos / base.toDouble().pow(0.0).toFloat() |
| 37 | + val f1 = pos / base.toDouble().pow(0.5).toFloat() // 2*1/4 = 0.5 |
| 38 | + val expected = floatArrayOf( |
| 39 | + 1f * cos(f0) - 2f * sin(f0), |
| 40 | + 1f * sin(f0) + 2f * cos(f0), |
| 41 | + 3f * cos(f1) - 4f * sin(f1), |
| 42 | + 3f * sin(f1) + 4f * cos(f1) |
| 43 | + ) |
| 44 | + for (i in buf.indices) assertCloseTo(expected[i], buf[i]) |
| 45 | + } |
| 46 | + |
| 47 | + @Test |
| 48 | + fun halfSplitRotationPairsFirstHalfWithSecondHalf() { |
| 49 | + // Single head, headSize = ropeDim = 4 → pairs are (b[0], b[2]) and (b[1], b[3]). |
| 50 | + // Distinct from interleaved which would pair (b[0], b[1]) and (b[2], b[3]). |
| 51 | + val buf = floatArrayOf(1f, 2f, 3f, 4f) |
| 52 | + val pos = 1 |
| 53 | + val ropeDim = 4 |
| 54 | + val base = 10000f |
| 55 | + applyRopeRotation( |
| 56 | + buf, nHeads = 1, headSize = ropeDim, ropeDim = ropeDim, |
| 57 | + pos = pos, base = base, ropeType = RopeType.HALF_SPLIT |
| 58 | + ) |
| 59 | + |
| 60 | + val f0 = pos / base.toDouble().pow(0.0).toFloat() |
| 61 | + val f1 = pos / base.toDouble().pow(0.5).toFloat() |
| 62 | + // Pair 0: (b[0]=1, b[2]=3) rotated by f0 → goes back to (b[0], b[2]) |
| 63 | + // Pair 1: (b[1]=2, b[3]=4) rotated by f1 → goes back to (b[1], b[3]) |
| 64 | + val expected = floatArrayOf( |
| 65 | + 1f * cos(f0) - 3f * sin(f0), // b[0] |
| 66 | + 2f * cos(f1) - 4f * sin(f1), // b[1] |
| 67 | + 1f * sin(f0) + 3f * cos(f0), // b[2] |
| 68 | + 2f * sin(f1) + 4f * cos(f1) // b[3] |
| 69 | + ) |
| 70 | + for (i in buf.indices) assertCloseTo(expected[i], buf[i]) |
| 71 | + } |
| 72 | + |
| 73 | + @Test |
| 74 | + fun halfSplitDoesNotEqualInterleavedForSameInput() { |
| 75 | + // Regression guard: two conventions must produce *different* outputs on |
| 76 | + // the same input (so a wiring mistake is observable). |
| 77 | + val a = floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f) |
| 78 | + val b = a.copyOf() |
| 79 | + val pos = 3 |
| 80 | + applyRopeRotation(a, nHeads = 1, headSize = 8, ropeDim = 8, pos = pos, base = 10000f, ropeType = RopeType.INTERLEAVED) |
| 81 | + applyRopeRotation(b, nHeads = 1, headSize = 8, ropeDim = 8, pos = pos, base = 10000f, ropeType = RopeType.HALF_SPLIT) |
| 82 | + |
| 83 | + var anyDiff = false |
| 84 | + for (i in a.indices) if (abs(a[i] - b[i]) > 1e-4f) anyDiff = true |
| 85 | + assertTrue(anyDiff, "INTERLEAVED and HALF_SPLIT must produce different rotations: a=${a.toList()} b=${b.toList()}") |
| 86 | + } |
| 87 | + |
| 88 | + @Test |
| 89 | + fun defaultRopeTypeIsInterleavedForBackwardsCompat() { |
| 90 | + // Callers that don't pass ropeType (existing Llama path) must keep the |
| 91 | + // pre-fix behavior, i.e. interleaved rotation. |
| 92 | + val withDefault = floatArrayOf(0.5f, 1.5f, 2.5f, 3.5f) |
| 93 | + val explicitInterleaved = withDefault.copyOf() |
| 94 | + applyRopeRotation(withDefault, nHeads = 1, headSize = 4, ropeDim = 4, pos = 2, base = 10000f) |
| 95 | + applyRopeRotation( |
| 96 | + explicitInterleaved, nHeads = 1, headSize = 4, ropeDim = 4, |
| 97 | + pos = 2, base = 10000f, ropeType = RopeType.INTERLEAVED |
| 98 | + ) |
| 99 | + for (i in withDefault.indices) assertEquals(explicitInterleaved[i], withDefault[i]) |
| 100 | + } |
| 101 | + |
| 102 | + @Test |
| 103 | + fun halfSplitMultipleHeadsAreIndependent() { |
| 104 | + // Each head is rotated independently — the half-split partitioning is |
| 105 | + // within the head, not across heads. |
| 106 | + val buf = floatArrayOf( |
| 107 | + 1f, 2f, 3f, 4f, // head 0 |
| 108 | + 5f, 6f, 7f, 8f // head 1 |
| 109 | + ) |
| 110 | + val pos = 2 |
| 111 | + applyRopeRotation( |
| 112 | + buf, nHeads = 2, headSize = 4, ropeDim = 4, |
| 113 | + pos = pos, base = 10000f, ropeType = RopeType.HALF_SPLIT |
| 114 | + ) |
| 115 | + val f0 = pos / 1f |
| 116 | + val f1 = pos / 100f |
| 117 | + // Head 0: (1,3) rotated by f0, (2,4) rotated by f1 |
| 118 | + assertCloseTo(1f * cos(f0) - 3f * sin(f0), buf[0]) |
| 119 | + assertCloseTo(2f * cos(f1) - 4f * sin(f1), buf[1]) |
| 120 | + assertCloseTo(1f * sin(f0) + 3f * cos(f0), buf[2]) |
| 121 | + assertCloseTo(2f * sin(f1) + 4f * cos(f1), buf[3]) |
| 122 | + // Head 1: (5,7) rotated by f0, (6,8) rotated by f1 — same freqs, different inputs |
| 123 | + assertCloseTo(5f * cos(f0) - 7f * sin(f0), buf[4]) |
| 124 | + assertCloseTo(6f * cos(f1) - 8f * sin(f1), buf[5]) |
| 125 | + assertCloseTo(5f * sin(f0) + 7f * cos(f0), buf[6]) |
| 126 | + assertCloseTo(6f * sin(f1) + 8f * cos(f1), buf[7]) |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +private fun Double.pow(exp: Double): Double = kotlin.math.exp(exp * kotlin.math.ln(this)) |
0 commit comments