package sk.ainet.models.gemma

import sk.ainet.context.ExecutionContext
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.types.DType
import kotlin.reflect.KClass
/**
 * Global AltUp weights shared across all layers.
 *
 * @param projWeight Projects the embedding into (numInputs - 1) additional states [hiddenSize, hiddenSize, numInputs - 1]
 * @param unembdProjWeight Projects the states back for the final output combination [hiddenSize, hiddenSize, numInputs - 1]
 */
public data class AltUpGlobalWeights<T : DType>(
    val projWeight: Tensor<T, Float>,
    val unembdProjWeight: Tensor<T, Float>
)
/**
 * Per-layer AltUp weights.
 *
 * @param predictCoef Prediction coefficients [numInputs, numInputs * numInputs]
 * @param correctCoef Correction coefficients [numInputs, numInputs]
 * @param correctScale Per-element scaling for the correction [hiddenSize]
 * @param routerWeight Router projection [hiddenSize, numInputs]
 * @param routerNorm Router normalization [hiddenSize]
 */
public data class AltUpLayerWeights<T : DType>(
    val predictCoef: Tensor<T, Float>,
    val correctCoef: Tensor<T, Float>,
    val correctScale: Tensor<T, Float>,
    val routerWeight: Tensor<T, Float>,
    val routerNorm: Tensor<T, Float>
)
/**
 * AltUp (Alternating Updates) implementation for Gemma 3n E4B.
 *
 * AltUp maintains multiple parallel hidden states (E4B: 4) but routes only
 * the "active" state through the expensive transformer layers. The other
 * states are cheaply predicted and corrected using learned per-layer
 * coefficients.
 *
 * Architecture (from GGUF inspection):
 * - Global: altup_proj [2048, 2048, 3] creates 3 extra states from the embedding
 * - Per-layer: the router projects the hidden state to routing logits;
 *   predict_coef/correct_coef control state updates; correct_scale modulates
 *   corrections element-wise
 * - Global: altup_unembd_proj [2048, 2048, 3] recombines the states for output
 *
 * @param ctx ExecutionContext for tensor operations
 * @param dtype Data type class
 * @param numInputs Number of parallel inputs (E4B: 4)
 * @param activeIdx Index of the active input (0)
 * @param hiddenSize Model hidden dimension
 * @param globalWeights Global projection/unprojection weights
 * @param layerWeights Per-layer AltUp weights
 */
public class AltUp<T : DType>(
    private val ctx: ExecutionContext,
    private val dtype: KClass<T>,
    private val numInputs: Int,
    public val activeIdx: Int,
    private val hiddenSize: Int,
    private val globalWeights: AltUpGlobalWeights<T>,
    private val layerWeights: List<AltUpLayerWeights<T>>
) {

    private val numExtra = numInputs - 1 // 3 for E4B
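
    // Hypothetical end-to-end sketch (illustration only; `runLayer` is a placeholder
    // for the real Gemma 3n layer forward pass, which lives outside this class):
    //
    //   var states = altUp.initialize(embedding)            // numInputs parallel states
    //   for (layer in layerWeights.indices) {
    //       val predicted = altUp.predict(layer, states)
    //       val active = runLayer(layer, predicted[altUp.activeIdx])
    //       states = altUp.correct(layer, active, predicted)
    //   }
    //   val hidden = altUp.finalize(states)                 // single [hiddenSize] vector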

    /**
     * Initialize AltUp states from a single embedding vector.
     *
     * The active state (index 0) is the embedding itself.
     * Additional states are created by projecting the embedding through the altup_proj slices.
     *
     * @param embedding The token embedding [hiddenSize]
     * @return List of [numInputs] state tensors
     */
    public fun initialize(embedding: Tensor<T, Float>): List<Tensor<T, Float>> {
        val states = mutableListOf(embedding)

        // Project the embedding into additional states using altup_proj [hiddenSize, hiddenSize, numExtra]
        val projBuf = globalWeights.projWeight.expectFloatBuffer()
        val embBuf = embedding.expectFloatBuffer()
        val h = hiddenSize

        for (k in 0 until numExtra) {
            val out = FloatArray(h)
            val offset = k * h * h // slice k of the flattened [numExtra, h, h] buffer
            for (i in 0 until h) {
                var sum = 0f
                for (j in 0 until h) {
                    sum += projBuf[offset + i * h + j] * embBuf[j]
                }
                out[i] = sum
            }
            states.add(ctx.fromFloatArray<T, Float>(embedding.shape, dtype, out))
        }

        return states
    }
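
    // Worked shapes for E4B (hiddenSize = 2048, numInputs = 4): altup_proj holds three
    // contiguous 2048x2048 slices, so slice k starts at offset k * 2048 * 2048, and
    // initialize(emb) returns [emb, P0 @ emb, P1 @ emb, P2 @ emb].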

    /**
     * Predict phase: generate predictions for all states using per-layer coefficients.
     *
     * Applies predict_coef as a static mixing matrix over the parallel states.
     * Note: only the first numInputs * numInputs coefficients are read here; the
     * per-layer router (routerWeight/routerNorm) is not consulted in this
     * simplified path.
     *
     * @param layerIdx Layer index to get per-layer weights
     * @param states Current parallel states
     * @return Predicted states
     */
    public fun predict(layerIdx: Int, states: List<Tensor<T, Float>>): List<Tensor<T, Float>> {
        val lw = layerWeights[layerIdx]
        val coeffBuf = lw.predictCoef.expectFloatBuffer()
        // predict_coef shape: [numInputs, numInputs * numInputs]; only the first
        // numInputs * numInputs values are used, reinterpreted as an
        // [numInputs, numInputs] mixing matrix.
        val n = numInputs

        return List(n) { i ->
            var result = states[i]
            for (j in 0 until n) {
                if (i != j) {
                    // Cross-state coefficient from the flattened matrix (row i, column j)
                    val coeff = coeffBuf[i * n + j]
                    if (coeff != 0f) {
                        result = addScaled(result, states[j], coeff)
                    }
                }
            }
            result
        }
    }
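
    // Example with numInputs = 4: output state i = 1 becomes
    //   states[1] + c[1*4+0] * states[0] + c[1*4+2] * states[2] + c[1*4+3] * states[3]
    // i.e. a static linear mix with an implicit self-weight of 1.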

    /**
     * Correct phase: update all states after the active state passes through the layer.
     *
     * innovation = layerOutput - predictions[activeIdx]
     * corrected[i] = predictions[i] + correctCoef[i, activeIdx] * (correctScale * innovation)
     *
     * @param layerIdx Layer index
     * @param layerOutput Output of the transformer layer for the active state
     * @param predictions Predicted states from [predict]
     * @return Corrected states
     */
    public fun correct(
        layerIdx: Int,
        layerOutput: Tensor<T, Float>,
        predictions: List<Tensor<T, Float>>
    ): List<Tensor<T, Float>> {
        val lw = layerWeights[layerIdx]
        val innovation = addScaled(layerOutput, predictions[activeIdx], -1f)

        // Apply the element-wise scale to the innovation; the modulo lets a
        // [hiddenSize] scale broadcast over longer (multi-token) buffers
        val scaleBuf = lw.correctScale.expectFloatBuffer()
        val innBuf = innovation.expectFloatBuffer()
        val scaledInnovation = FloatArray(innBuf.size) { innBuf[it] * scaleBuf[it % scaleBuf.size] }
        val scaledInnovationTensor = ctx.fromFloatArray<T, Float>(innovation.shape, dtype, scaledInnovation)

        val coeffBuf = lw.correctCoef.expectFloatBuffer()
        val n = numInputs

        return List(n) { i ->
            if (i == activeIdx) {
                layerOutput
            } else {
                val coeff = coeffBuf[i * n + activeIdx]
                addScaled(predictions[i], scaledInnovationTensor, coeff)
            }
        }
    }
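
    // Example with activeIdx = 0: corrected state 2 is
    //   predictions[2] + correctCoef[2*n + 0] * (correctScale * (layerOutput - predictions[0]))
    // with the scale applied element-wise, while state 0 is simply replaced by layerOutput.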

    /**
     * Finalize: combine all states into a single output using altup_unembd_proj.
     *
     * output = states[activeIdx] + sum over k of (unembd_proj[k] @ states[k + 1])
     *
     * @param states Final parallel states after all layers
     * @return Combined output tensor
     */
    public fun finalize(states: List<Tensor<T, Float>>): Tensor<T, Float> {
        val unprojBuf = globalWeights.unembdProjWeight.expectFloatBuffer()
        val h = hiddenSize
        val result = states[activeIdx].expectFloatBuffer().copyOf()

        // Add the projected extra states
        for (k in 0 until numExtra) {
            val stateBuf = states[k + 1].expectFloatBuffer()
            val offset = k * h * h
            for (i in 0 until h) {
                var sum = 0f
                for (j in 0 until h) {
                    sum += unprojBuf[offset + i * h + j] * stateBuf[j]
                }
                result[i] += sum
            }
        }

        return ctx.fromFloatArray<T, Float>(states[activeIdx].shape, dtype, result)
    }
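
    // Mirrors initialize(): for E4B this adds three 2048x2048 matrix-vector products
    // (one per extra state) on top of the active state, e.g.
    //   val hidden = altUp.finalize(states)  // [hiddenSize]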

    /** Returns a + bScale * b, computed element-wise on the float buffers. */
    private fun addScaled(a: Tensor<T, Float>, b: Tensor<T, Float>, bScale: Float): Tensor<T, Float> {
        val aBuf = a.expectFloatBuffer()
        val bBuf = b.expectFloatBuffer()
        val out = FloatArray(aBuf.size) { aBuf[it] + bScale * bBuf[it] }
        return ctx.fromFloatArray<T, Float>(a.shape, dtype, out)
    }

    /** Borrows the backing FloatArray when available; otherwise copies into a new one. */
    private fun Tensor<T, Float>.expectFloatBuffer(): FloatArray {
        val data = this.data
        if (data is sk.ainet.lang.tensor.data.FloatArrayTensorData<*>) return data.buffer
        return data.copyToFloatArray()
    }
}