From 36043f562c6c652e7155a1c95ebc05a8b479d544 Mon Sep 17 00:00:00 2001 From: Adam Bien Date: Sun, 12 Apr 2026 12:34:16 +0200 Subject: [PATCH] Devstral 2 support (Mistral 3 architecture, Tekken tokenizer, YaRN RoPE) --- .../gpullama3/inference/InferenceCore.java | 90 ++++++ .../gpullama3/inference/operation/RoPE.java | 46 +++ .../inference/state/DevstralState.java | 77 +++++ .../beehive/gpullama3/model/ModelType.java | 8 + .../gpullama3/model/devstral/Devstral.java | 73 +++++ .../model/devstral/DevstralConfiguration.java | 56 ++++ .../model/devstral/package-info.java | 23 ++ .../gpullama3/model/format/ChatFormat.java | 2 + .../model/format/DevstralChatFormat.java | 82 ++++++ .../model/loader/DevstralModelLoader.java | 177 ++++++++++++ .../gpullama3/model/loader/ModelLoader.java | 2 + .../tokenizer/DevstralTokenizer.java | 237 ++++++++++++++++ .../gpullama3/tokenizer/Vocabulary.java | 5 + .../TransformerComputeKernelsLayered.java | 262 ++++++++++++++++++ .../QuantizationPlannerFactory.java | 5 + .../model/fp16/DevstralFP16LayerPlanner.java | 20 ++ .../model/q8_0/DevstralQ8_0LayerPlanner.java | 20 ++ .../type/fp16/DevstralFP16FFNLayers.java | 200 +++++++++++++ .../type/q8_0/DevstralQ8_0FFNLayers.java | 198 +++++++++++++ 19 files changed, 1583 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/inference/state/DevstralState.java create mode 100644 src/main/java/org/beehive/gpullama3/model/devstral/Devstral.java create mode 100644 src/main/java/org/beehive/gpullama3/model/devstral/DevstralConfiguration.java create mode 100644 src/main/java/org/beehive/gpullama3/model/devstral/package-info.java create mode 100644 src/main/java/org/beehive/gpullama3/model/format/DevstralChatFormat.java create mode 100644 src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java create mode 100644 src/main/java/org/beehive/gpullama3/tokenizer/DevstralTokenizer.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index b679c25b..9beade35 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -12,6 +12,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; @@ -179,6 +180,95 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p return state.logits; } + /** + * Forward pass for Devstral 2 models where head_dim != dim/num_heads. + * Q projection outputs qDim (num_heads * head_dim) instead of dim. 
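+ * <p>Shape sketch (values taken from the Devstral Small 2 numbers documented in
+ * {@code package-info.java}; illustrative only, not read from the model at this point):
+ * <pre>{@code
+ * // dim = 5120, num_heads = 32, head_dim = 128, num_kv_heads = 8
+ * // qDim  = 32 * 128 = 4096  -> state.q holds qDim floats, not dim
+ * // kvDim =  8 * 128 = 1024  -> state.k / state.v and the KV caches use kvDim
+ * }</pre>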
+ */ + public static FloatTensor forwardJavaDevstral(Model model, State state, int token, int position) { + final DevstralConfiguration config = (DevstralConfiguration) model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); // 128 (independent head_dim) + int qDim = config.qDim(); // 4096 = 32 * 128 + int kvDim = config.kvDim(); // 1024 = 8 * 128 + int kvMul = config.kvMul(); + float sqrtHeadSize = (float) Math.sqrt(headSize); + + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + + for (int l = 0; l < config.numberOfLayers(); l++) { + rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps()); + + weights.wq[l].matmul(state.xb, state.q, qDim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE over qDim (not dim) + for (int i = 0; i < qDim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? state.q : state.k; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + int curLayer = l; + + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= position; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + score /= sqrtHeadSize; + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset, position + 1); + + int xbOffset = h * headSize; + state.xb.fillInPlace(xbOffset, headSize, 0f); + + for (int t = 0; t <= position; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // O projection: input qDim, output dim + weights.wo[l].matmul(state.xb, state.xb2, dim, qDim); + + state.x.addInPlace(state.xb2); + + rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps()); + + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + state.hb.multiplyInPlace(state.hb2); + + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + state.x.addInPlace(state.xb); + } + + rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps()); + weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim); + + return state.logits; + } + public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) { final Qwen2Configuration config = (Qwen2Configuration) model.configuration(); final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights(); diff --git a/src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java 
b/src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java index 5ed39eda..91465939 100644 --- a/src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java +++ b/src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java @@ -35,4 +35,50 @@ public static Pair precomputeFreqsCis(int contextLength, int h assert contextLength * (headSize / 2) == n; return new Pair<>(cr, ci); } + + public static Pair precomputeFreqsCisYaRN(int contextLength, int headSize, double theta, + float factor, float betaFast, float betaSlow, float logMultiplier, int originalContextLength) { + assert headSize % 2 == 0; + float[] cr = new float[contextLength * (headSize / 2)]; + float[] ci = new float[contextLength * (headSize / 2)]; + + float freqScale = 1.0f / factor; + + // Compute correlation dimensions for ramp interpolation + float corrDim0 = yarnCorrDim(headSize, originalContextLength, betaFast, (float) theta); + float corrDim1 = yarnCorrDim(headSize, originalContextLength, betaSlow, (float) theta); + + // Compute mscale (attention scaling for extended context) + // Formula: mscale = 0.1 * logMultiplier * log(factor) + 1.0 + float mscale = logMultiplier > 0 + ? 1.0f + 0.1f * logMultiplier * (float) Math.log(1.0f / freqScale) + : 1.0f; + + int n = 0; + for (int pos = 0; pos < contextLength; ++pos) { + for (int i = 0; i < headSize; i += 2) { + float freqExtrap = (float) (1.0 / Math.pow(theta, i / (double) headSize)); + float freqInterp = freqScale * freqExtrap; + + float rampMix = yarnRamp(corrDim0, corrDim1, i / 2); + float freq = freqInterp * (1.0f - rampMix) + freqExtrap * rampMix; + + float val = pos * freq; + cr[n] = (float) Math.cos(val) * mscale; + ci[n] = (float) Math.sin(val) * mscale; + n++; + } + } + assert contextLength * (headSize / 2) == n; + return new Pair<>(cr, ci); + } + + private static float yarnCorrDim(int nDims, int nCtxOrig, float nRot, float base) { + return nDims * (float) Math.log(nCtxOrig / (nRot * 2.0f * (float) Math.PI)) / (2.0f * (float) Math.log(base)); + } + + private static float yarnRamp(float low, float high, int i0) { + float y = (i0 - low) / Math.max(0.001f, high - low); + return 1.0f - Math.min(1.0f, Math.max(0.0f, y)); + } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/inference/state/DevstralState.java b/src/main/java/org/beehive/gpullama3/inference/state/DevstralState.java new file mode 100644 index 00000000..270c591f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/state/DevstralState.java @@ -0,0 +1,77 @@ +package org.beehive.gpullama3.inference.state; + +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +import java.util.stream.Stream; + +/** + * State for Devstral 2 models where head_dim != dim/num_heads. + * Allocates Q with qDim (num_heads * head_dim) and K/V with kvDim (num_kv_heads * head_dim). 
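+ * <p>Per-layer cache sizing that follows (a sketch with a hypothetical 4096-token
+ * context and {@code kvDim = 1024}): each of {@code keyCache[l]} and {@code valueCache[l]}
+ * holds {@code contextLength * kvDim = 4096 * 1024 ≈ 4.2M} floats (~16 MB in FP32),
+ * rather than {@code contextLength * dim}.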
+ */ +public final class DevstralState extends State { + + public DevstralState(Configuration config, int batchsize) { + super(config, batchsize); + } + + @Override + protected StateFields createStateFields(Configuration config) { + DevstralConfiguration dc = (DevstralConfiguration) config; + StateFields fields = new StateFields(); + + int qDim = dc.qDim(); + int kvDim = dc.kvDim(); + + fields.x = ArrayFloatTensor.allocate(dc.dim()); + fields.xb = ArrayFloatTensor.allocate(dc.dim()); + fields.xb2 = ArrayFloatTensor.allocate(dc.dim()); + fields.hb = ArrayFloatTensor.allocate(dc.hiddenDim()); + fields.hb2 = ArrayFloatTensor.allocate(dc.hiddenDim()); + fields.q = ArrayFloatTensor.allocate(qDim); + fields.k = ArrayFloatTensor.allocate(kvDim); + fields.v = ArrayFloatTensor.allocate(kvDim); + fields.att = ArrayFloatTensor.allocate(dc.numberOfHeads(), dc.contextLength()); + fields.logits = ArrayFloatTensor.allocate(dc.vocabularySize()); + + fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(dc.contextLength(), kvDim)).limit(dc.numberOfLayers()).toArray(FloatTensor[]::new); + fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(dc.contextLength(), kvDim)).limit(dc.numberOfLayers()).toArray(FloatTensor[]::new); + + // TornadoVM wrappers + fields.wrapX = new FloatArray(dc.dim()); + fields.wrapXb = new FloatArray(dc.dim()); + fields.wrapXb2 = new FloatArray(dc.dim()); + fields.wrapHb = new FloatArray(dc.hiddenDim()); + fields.wrapHb2 = new FloatArray(dc.hiddenDim()); + + switch (dc.quantization()) { + case "FP16" -> fields.createActivationFP16(dc.dim()); + case "Q8_0" -> fields.createActivationQ8_0(dc.dim()); + default -> throw new UnsupportedOperationException("Unsupported quantization format: " + dc.quantization()); + } + fields.wrapLogits = new FloatArray(dc.vocabularySize()); + fields.wrapQ = new FloatArray(qDim); + fields.wrapK = new FloatArray(kvDim); + fields.wrapV = new FloatArray(kvDim); + + fields.wrapXFP16 = new HalfFloatArray(dc.dim()); + fields.wrapXbFP16 = new HalfFloatArray(dc.dim()); + fields.wrapKeyCache = new FloatArray(dc.contextLength() * kvDim * dc.numberOfLayers()); + fields.wrapValueCache = new FloatArray(dc.contextLength() * kvDim * dc.numberOfLayers()); + fields.wrapValueCache.init(0.f); + fields.wrapKeyCache.init(0.f); + fields.wrapAtt = new FloatArray(dc.numberOfHeads() * dc.contextLength()); + fields.positionHolder = new IntArray(1); + + fields.temp = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize)); + fields.tempFFN = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize)); + fields.tempLogits = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize)); + + return fields; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/ModelType.java b/src/main/java/org/beehive/gpullama3/model/ModelType.java index fb46ff6e..0659da7d 100644 --- a/src/main/java/org/beehive/gpullama3/model/ModelType.java +++ b/src/main/java/org/beehive/gpullama3/model/ModelType.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.model; +import org.beehive.gpullama3.model.loader.DevstralModelLoader; import org.beehive.gpullama3.model.loader.GraniteLoader; import org.beehive.gpullama3.tensor.GGUF; import org.beehive.gpullama3.model.loader.LlamaModelLoader; @@ -37,6 +38,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo } }, + DEVSTRAL_2 { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new DevstralModelLoader(fileChannel, gguf, 
contextLength, useTornadovm).loadModel(); + } + }, + QWEN_2 { @Override public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { diff --git a/src/main/java/org/beehive/gpullama3/model/devstral/Devstral.java b/src/main/java/org/beehive/gpullama3/model/devstral/Devstral.java new file mode 100644 index 00000000..68c28d69 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/devstral/Devstral.java @@ -0,0 +1,73 @@ +package org.beehive.gpullama3.model.devstral; + +import org.beehive.gpullama3.inference.InferenceCore; +import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.DevstralState; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.AbstractModel; +import org.beehive.gpullama3.model.ModelType; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.tokenizer.DevstralTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +public class Devstral extends AbstractModel { + + DevstralConfiguration configuration; + + public Devstral(DevstralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { + super(tokenizer, weights, chatFormat, null); + this.configuration = configuration; + } + + @Override + public DevstralConfiguration configuration() { + return configuration; + } + + @Override + public DevstralTokenizer tokenizer() { + return (DevstralTokenizer) tokenizer; + } + + @Override + public ModelType getModelType() { + return ModelType.DEVSTRAL_2; + } + + public State createNewState() { + State state = new DevstralState(configuration(), -1); + state.latestToken = tokenizer.getSpecialTokens().get(""); + return state; + } + + public State createNewState(int batchsize) { + State state = new DevstralState(configuration(), batchsize); + state.latestToken = tokenizer.getSpecialTokens().get(""); + return state; + } + + @Override + public void forward(State state, int token, int position) { + InferenceCore.forwardJavaDevstral(this, state, token, position); + } + + @Override + public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } + + @Override + public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/model/devstral/DevstralConfiguration.java b/src/main/java/org/beehive/gpullama3/model/devstral/DevstralConfiguration.java new file mode 100644 index 00000000..3339a042 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/devstral/DevstralConfiguration.java @@ -0,0 +1,56 @@ +package org.beehive.gpullama3.model.devstral; + +import org.beehive.gpullama3.model.Configuration; + +/** + * Configuration for 
Devstral 2 models (Mistral 3 architecture). + * Unlike standard Mistral, Devstral 2 has an independent head dimension + * (head_dim != dim / num_heads), requiring explicit key_length/value_length. + */ +// @formatter:off +public record DevstralConfiguration(String quantization, + int dim, + int hiddenDim, + int numberOfLayers, + int numberOfHeads, + int numberOfKeyValueHeads, + int headDim, + int vocabularySize, + int contextLength, + float rmsNormEps, + float ropeTheta) implements Configuration { + + @Override public String quantization() { + return quantization; + } + + /** + * Q projection output dimension = numberOfHeads * headDim. + * This differs from dim when headDim != dim/numberOfHeads. + */ + public int qDim() { + return numberOfHeads * headDim; + } + + public int kvDim() { + return numberOfKeyValueHeads * headDim; + } + + public int kvMul() { + return numberOfHeads / numberOfKeyValueHeads; + } + + @Override + public int numberOfHeadsKey() { + throw new UnsupportedOperationException("Not supported for Devstral."); + } + + @Override + public int contextLengthModel() { + throw new UnsupportedOperationException("Not supported for Devstral."); + } + + public int headSize() { + return headDim; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/devstral/package-info.java b/src/main/java/org/beehive/gpullama3/model/devstral/package-info.java new file mode 100644 index 00000000..d7baea9b --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/devstral/package-info.java @@ -0,0 +1,23 @@ +/** + * Devstral 2 model support (Mistral 3 architecture). + * + *
+ * <h2>Key architectural differences from standard Mistral</h2>
+ *
+ * <ul>
+ *   <li>Independent head dimension: {@code head_dim=128 != dim/num_heads=160}.
+ *       Q projection outputs {@code qDim = num_heads * head_dim = 4096}, not {@code dim = 5120}.
+ *       K/V projection outputs {@code kvDim = num_kv_heads * head_dim = 1024}, not {@code 1280}.</li>
+ *   <li>GGUF metadata prefix: {@code mistral3.*} (not {@code llama.*}).
+ *       Head dimension read from {@code mistral3.attention.key_length}.</li>
+ *   <li>Tekken tokenizer: GPT-2-style BPE with explicit merges list and {@code BYTE_ENCODER},
+ *       replacing Mistral's SentencePiece score-based BPE. Identified by {@code tokenizer.ggml.pre=tekken}.</li>
+ *   <li>YaRN RoPE scaling: Precomputed frequencies with {@code mscale = 1 + 0.1 * logMultiplier * ln(factor)},
+ *       required for correct positional encoding even at short contexts.</li>
+ *   <li>Non-square Q/K/V projections on GPU: Dedicated kernels ({@code fusedQKVMatmulQ8NonSquare})
+ *       that separate {@code inputDim} (5120) from {@code qDim} (4096).</li>
+ * </ul>
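+ *
+ * <p>Worked YaRN example (hypothetical scaling parameters; a sketch, not values read
+ * from any particular GGUF): with {@code factor = 4.0f} and {@code logMultiplier = 1.0f},
+ * <pre>{@code
+ * float mscale = 1.0f + 0.1f * 1.0f * (float) Math.log(4.0f); // ≈ 1.1386
+ * // cr[n] = cos(pos * freq) * mscale;  ci[n] = sin(pos * freq) * mscale;
+ * }</pre>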
+ * + * @see DevstralConfiguration + * @see Devstral + * @see Devstral Small 2 24B GGUF + */ +package org.beehive.gpullama3.model.devstral; diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index d3466a8e..827ad625 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.model.format; +import org.beehive.gpullama3.tokenizer.DevstralTokenizer; import org.beehive.gpullama3.tokenizer.GraniteTokenizer; import org.beehive.gpullama3.tokenizer.LlamaTokenizer; import org.beehive.gpullama3.tokenizer.MistralTokenizer; @@ -13,6 +14,7 @@ public interface ChatFormat { static ChatFormat create(Object tokenizer, ChatTokens chatTokens) { return switch (tokenizer) { + case DevstralTokenizer devstralTokenizer -> new DevstralChatFormat(devstralTokenizer); case GraniteTokenizer graniteTokenizer -> new GraniteChatFormat(graniteTokenizer); case LlamaTokenizer llamaTokenizer -> new LlamaChatFormat(llamaTokenizer); case MistralTokenizer mistralTokenizer -> new MistralChatFormat(mistralTokenizer); diff --git a/src/main/java/org/beehive/gpullama3/model/format/DevstralChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/DevstralChatFormat.java new file mode 100644 index 00000000..85205f65 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/DevstralChatFormat.java @@ -0,0 +1,82 @@ +package org.beehive.gpullama3.model.format; + +import org.beehive.gpullama3.tokenizer.DevstralTokenizer; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class DevstralChatFormat implements ChatFormat { + + protected final DevstralTokenizer tokenizer; + protected final int unknownToken; + protected final int beginOfText; + protected final int endOfText; + protected final int beginOfInstruction; + protected final int endOfInstruction; + protected final int toolCalls; + protected final int beginOfAvailableTools; + protected final int endOfAvailableTools; + protected final int beginOfToolResults; + protected final int endOfToolResults; + protected final int prefix; + protected final int middle; + protected final int suffix; + + public DevstralChatFormat(DevstralTokenizer tokenizer) { + this.tokenizer = tokenizer; + Map specialTokens = tokenizer.getSpecialTokens(); + this.unknownToken = specialTokens.getOrDefault("", -1); + this.beginOfText = specialTokens.get(""); + this.endOfText = specialTokens.get(""); + this.beginOfInstruction = specialTokens.get("[INST]"); + this.endOfInstruction = specialTokens.get("[/INST]"); + this.toolCalls = specialTokens.getOrDefault("[TOOL_CALLS]", unknownToken); + this.beginOfAvailableTools = specialTokens.getOrDefault("[AVAILABLE_TOOLS]", unknownToken); + this.endOfAvailableTools = specialTokens.getOrDefault("[/AVAILABLE_TOOLS]", unknownToken); + this.beginOfToolResults = specialTokens.getOrDefault("[TOOL_RESULTS]", unknownToken); + this.endOfToolResults = specialTokens.getOrDefault("[/TOOL_RESULTS]", unknownToken); + this.prefix = specialTokens.getOrDefault("[PREFIX]", unknownToken); + this.suffix = specialTokens.getOrDefault("[SUFFIX]", unknownToken); + this.middle = specialTokens.getOrDefault("[MIDDLE]", unknownToken); + } + + @Override + public int getBeginOfText() { + return beginOfText; + } + + @Override + public Set getStopTokens() { + return 
Set.of(endOfText); + } + + @Override + public List encodeHeader(Message message) { + List tokens = new ArrayList<>(); + tokens.add(beginOfInstruction); + tokens.addAll(tokenizer.encodeAsList(message.role().name())); + tokens.add(endOfInstruction); + return tokens; + } + + @Override + public List encodeMessage(Message message) { + List tokens = encodeHeader(message); + tokens.addAll(tokenizer.encodeAsList(message.content().strip())); + tokens.add(endOfInstruction); + return tokens; + } + + public List encodeFillInTheMiddle(String prefix, String suffix) { + List tokens = new ArrayList<>(); + final Set EMPTY_STRING_SET = Collections.emptySet(); + tokens.add(this.suffix); + tokens.addAll(tokenizer.encode(suffix, EMPTY_STRING_SET)); + tokens.add(this.prefix); + tokens.addAll(tokenizer.encode(prefix, EMPTY_STRING_SET)); + return tokens; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java new file mode 100644 index 00000000..8c230b2f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java @@ -0,0 +1,177 @@ +package org.beehive.gpullama3.model.loader; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.auxiliary.Pair; +import org.beehive.gpullama3.inference.operation.RoPE; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.devstral.Devstral; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tokenizer.DevstralTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +import java.nio.channels.FileChannel; +import java.util.Map; + +import static org.beehive.gpullama3.model.loader.ModelLoader.*; + +public class DevstralModelLoader extends AbstractModelLoader { + + public DevstralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); + } + + @Override + protected Vocabulary loadVocabulary(Map metadata) { + return Vocabulary.loadDevstralVocabulary(metadata); + } + + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + return new DevstralTokenizer(metadata, vocabulary); + } + + // @formatter:off + @Override + protected DevstralConfiguration createConfiguration(Map metadata) { + String prefix = "mistral3"; + + int modelContextLength = (int) metadata.get(prefix + ".context_length"); + int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength; + + int vocabSize = metadata.containsKey(prefix + ".vocab_size") ? 
(int) metadata.get(prefix + ".vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length"); + + // Devstral 2 has independent head dimension (head_dim != dim/num_heads) + int headDim = (int) metadata.get(prefix + ".attention.key_length"); + + return new DevstralConfiguration( + getModelQuantization(metadata), + (int) metadata.get(prefix + ".embedding_length"), + (int) metadata.get(prefix + ".feed_forward_length"), + (int) metadata.get(prefix + ".block_count"), + (int) metadata.get(prefix + ".attention.head_count"), + metadata.containsKey(prefix + ".attention.head_count_kv") ? + (int) metadata.get(prefix + ".attention.head_count_kv") + : (int) metadata.get(prefix + ".attention.head_count"), + headDim, + vocabSize, + finalContextLength, + (float) metadata.getOrDefault(prefix + ".attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault(prefix + ".rope.freq_base", 10000f) + ); + } + // @formatter:on + + // @formatter:off + @Override + protected Pair precomputeRopeFrequencies(DevstralConfiguration config) { + Map metadata = gguf.getMetadata(); + String prefix = "mistral3"; + + String ropeScalingType = (String) metadata.getOrDefault(prefix + ".rope.scaling.type", ""); + if ("yarn".equals(ropeScalingType)) { + float factor = (float) metadata.get(prefix + ".rope.scaling.factor"); + float betaFast = (float) metadata.get(prefix + ".rope.scaling.yarn_beta_fast"); + float betaSlow = (float) metadata.get(prefix + ".rope.scaling.yarn_beta_slow"); + float logMultiplier = (float) metadata.getOrDefault(prefix + ".rope.scaling.yarn_log_multiplier", 0.0f); + int originalContextLength = (int) metadata.get(prefix + ".rope.scaling.original_context_length"); + + return RoPE.precomputeFreqsCisYaRN( + config.contextLength(), + config.headDim(), + config.ropeTheta(), + factor, + betaFast, + betaSlow, + logMultiplier, + originalContextLength + ); + } + + return RoPE.precomputeFreqsCis( + config.contextLength(), + config.headDim(), + config.ropeTheta(), + false, + 1.0f, + 1.0f, + 1.0f, + config.contextLength() + ); + } + // @formatter:on + + @Override + protected Devstral createModel(DevstralConfiguration config, Tokenizer tokenizer, Weights weights) { + return new Devstral(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); + } + + // @formatter:off + @Override + protected Weights createStandardWeights(Map tensorEntries, DevstralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { + + final int nl = config.numberOfLayers(); + + return new LlamaStandardWeights( + loadTensor(tokenEmbeddings), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), + loadTensor(tensorEntries.get("output_norm.weight")), + new ArrayFloatTensor(ropeFreqs.first()), + new ArrayFloatTensor(ropeFreqs.second()), + loadTensor(outputWeight), + outputWeight.ggmlType()); + } + // @formatter:on + + // @formatter:off + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, DevstralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { + GGMLType ggmlType = outputWeight.ggmlType(); + + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); + } + + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + } + + final int nl = config.numberOfLayers(); + + return new LlamaTornadoWeights( + loadTornadoTensor(tokenEmbeddings), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), + loadTornadoTensor(outputWeight), + ggmlType + ); + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index 392113be..83b25987 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -50,6 +50,8 @@ private static ModelType detectModelType(Map metadata) { String lowerName = name.toLowerCase(); if (lowerName.contains("granite")) { return ModelType.GRANITE; + } else if (lowerName.contains("devstral")) { + return ModelType.DEVSTRAL_2; } else if (lowerName.contains("mistral")) { return ModelType.MISTRAL; } else if (lowerName.contains("llama")) { diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/DevstralTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/DevstralTokenizer.java new file mode 100644 index 00000000..5c74fd8a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tokenizer/DevstralTokenizer.java @@ -0,0 +1,237 @@ +package org.beehive.gpullama3.tokenizer; + +import org.beehive.gpullama3.auxiliary.Pair; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * GPT-2-style BPE tokenizer for Devstral 2 models 
(Tekken tokenizer). + *
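+ * <p>Hypothetical round trip, assuming a loaded vocabulary (a usage sketch, not part
+ * of the tokenizer contract):
+ * <pre>{@code
+ * List<Integer> ids = tokenizer.encode("hello world", Set.of()); // byte-encode, then BPE-merge
+ * String text = tokenizer.decode(ids);                           // reverses the BYTE_DECODER mapping
+ * }</pre>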
+ * Tekken is a Tiktoken-based BPE tokenizer with explicit merges list and byte-level encoding, + * unlike the score-based SentencePiece BPE used by earlier Mistral models. + */ +public class DevstralTokenizer implements Tokenizer { + static final Map BYTE_ENCODER = bytesToUnicode(); + static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + + private static final String TEKKEN_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + private final Pattern compiledPattern; + private final Vocabulary vocabulary; + private final Map, Integer> merges; + private final Map specialTokens; + private final int[] tokenTypes; + + public DevstralTokenizer(Map metadata, Vocabulary vocabulary) { + this.vocabulary = vocabulary; + this.compiledPattern = Pattern.compile(TEKKEN_PATTERN); + + // Load token types for special token detection + this.tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); + + // Build special tokens map from token types (type != 1 and type != 6 are special) + Map specialTokens = new HashMap<>(); + if (tokenTypes != null) { + for (int i = 0; i < vocabulary.size(); i++) { + if (tokenTypes[i] != 1 && tokenTypes[i] != 6) { + specialTokens.put(vocabulary.get(i), i); + } + } + } else { + // Fallback: detect known Mistral special tokens by name + for (int i = 0; i < vocabulary.size(); i++) { + String token = vocabulary.get(i); + if (token.equals("") || token.equals("") || token.equals("") + || (token.startsWith("[") && token.endsWith("]"))) { + specialTokens.put(token, i); + } + } + } + this.specialTokens = Map.copyOf(specialTokens); + + // Load merges + String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); + List> mergeList = Arrays.stream(mergeLines).map(line -> line.split(" ")) + .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); + + this.merges = new HashMap<>(); + for (Pair pair : mergeList) { + String merged = vocabulary.get(pair.first()) + vocabulary.get(pair.second()); + int mergeIndex = vocabulary.getIndex(merged).orElseThrow(); + this.merges.put(pair, mergeIndex); + } + } + + private static List findAll(Pattern pattern, String text) { + List matches = new ArrayList<>(); + Matcher matcher = pattern.matcher(text); + while (matcher.find()) { + matches.add(matcher.group()); + } + return matches; + } + + private static List merge(List ids, Pair pair, int idx) { + List newIds = new ArrayList<>(); + int i = 0; + while (i < ids.size()) { + if (i < ids.size() - 1 && ids.get(i).equals(pair.first()) && ids.get(i + 1).equals(pair.second())) { + newIds.add(idx); + i += 2; + } else { + newIds.add(ids.get(i)); + i++; + } + } + return newIds; + } + + private static Map bytesToUnicode() { + List bs = new ArrayList<>(); + IntStream.rangeClosed('!', '~').forEach(bs::add); + IntStream.rangeClosed('\u00A1', '\u00AC').forEach(bs::add); + IntStream.rangeClosed('\u00AE', '\u00FF').forEach(bs::add); + + List cs = new ArrayList<>(bs); + int n = 0; + for (int b = 0; b < 256; b++) { + if (!bs.contains(b)) { + bs.add(b); + cs.add(256 + n); + n++; + } + } + return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get)); + } + + @Override + public String regexPattern() { + return compiledPattern != null ? 
compiledPattern.pattern() : null; + } + + @Override + public Map getSpecialTokens() { + return specialTokens; + } + + @Override + public boolean isSpecialToken(int tokenIndex) { + if (tokenTypes != null && tokenIndex >= 0 && tokenIndex < tokenTypes.length) { + return tokenTypes[tokenIndex] != 1 && tokenTypes[tokenIndex] != 6; + } + return specialTokens.containsValue(tokenIndex); + } + + @Override + public boolean shouldDisplayToken(int token) { + if (tokenTypes != null && token >= 0 && token < tokenTypes.length) { + return tokenTypes[token] == 1 || tokenTypes[token] == 6; + } + return !isSpecialToken(token); + } + + public int[] encode(String text) { + StringBuilder sb = new StringBuilder(); + byte[] bytes = text.getBytes(StandardCharsets.UTF_8); + for (byte b : bytes) { + sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + } + return encodeImpl(sb.toString()); + } + + @Override + public List encodeAsList(String text) { + return Arrays.stream(encode(text)).boxed().toList(); + } + + private int[] encodeImpl(String text) { + return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); + } + + @Override + public List encode(String text, Set allowedSpecial) { + if (allowedSpecial.isEmpty()) { + return encodeOrdinary(text); + } + + assert specialTokens.keySet().containsAll(allowedSpecial); + String specialPattern = allowedSpecial.stream().map(Pattern::quote).collect(Collectors.joining("|", "(", ")")); + String[] specialChunks = text.split(specialPattern); + + List ids = new ArrayList<>(); + for (String part : specialChunks) { + if (allowedSpecial.contains(part)) { + ids.add(specialTokens.get(part)); + } else { + ids.addAll(encodeOrdinary(part)); + } + } + return ids; + } + + public List encodeOrdinary(String text) { + List textChunks = findAll(compiledPattern, text); + List ids = new ArrayList<>(); + for (String chunk : textChunks) { + ids.addAll(encodeChunk(chunk)); + } + return ids; + } + + private List encodeChunk(String chunk) { + List ids = new ArrayList<>(); + for (char c : chunk.toCharArray()) { + int tokenIndex = vocabulary.getIndex(String.valueOf(c)).orElseThrow(); + ids.add(tokenIndex); + } + + while (ids.size() >= 2) { + Map, Integer> stats = getStats(ids); + Pair pair = stats.keySet().stream().min(Comparator.comparingInt(key -> merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow(); + if (!merges.containsKey(pair)) { + break; + } + ids = merge(ids, pair, merges.get(pair)); + } + return ids; + } + + @Override + public String decode(List tokens) { + String decoded = decodeImpl(tokens); + int[] decodedBytesAsInts = decoded.codePoints().map(cp -> BYTE_DECODER.getOrDefault(cp, cp)).toArray(); + byte[] rawBytes = new byte[decodedBytesAsInts.length]; + for (int i = 0; i < decodedBytesAsInts.length; i++) { + rawBytes[i] = (byte) decodedBytesAsInts[i]; + } + return new String(rawBytes, StandardCharsets.UTF_8); + } + + private String decodeImpl(List tokens) { + StringBuilder sb = new StringBuilder(); + for (int token : tokens) { + sb.append(vocabulary.get(token)); + } + return sb.toString(); + } + + private Map, Integer> getStats(List ids) { + Map, Integer> map = new HashMap<>(); + for (int i = 0; i + 1 < ids.size(); i++) { + Pair key = new Pair<>(ids.get(i), ids.get(i + 1)); + map.merge(key, 1, Integer::sum); + } + return map; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java index 1a867569..b29f3576 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java 
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java @@ -36,6 +36,11 @@ public static Vocabulary loadQwen3Vocabulary(Map metadata) { return new Vocabulary(tokens, scores); } + public static Vocabulary loadDevstralVocabulary(Map metadata) { + String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); + return new Vocabulary(tokens, null); + } + public static Vocabulary loadPhi3Vocabulary(Map metadata) { String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); float[] scores = (float[]) metadata.get("tokenizer.ggml.scores"); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index d1803e41..11688a6b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -441,6 +441,51 @@ public static void ropeRotationWithCacheCopy(KernelContext context, IntArray pos } + /** + * RoPE rotation using precomputed frequency tables (cos/sin) instead of on-the-fly computation. + * Required for models with non-standard RoPE (e.g., YaRN scaling in Devstral 2). + */ + public static void ropeRotationWithCacheCopyPrecomputed(KernelContext context, IntArray positionHolder, FloatArray sq, + FloatArray sk, FloatArray sv, + FloatArray keyCache, FloatArray valueCache, + FloatArray freqCisReal, FloatArray freqCisImag, + int kvDim, int headSize, int layer, int contextLength) { + + int i = context.globalIdx * 2; + int pos = positionHolder.get(0); + + if (i + 1 < sq.getSize()) { + int head_dim = i % headSize; + int freqIdx = pos * (headSize / 2) + (head_dim / 2); + float fcr = freqCisReal.get(freqIdx); + float fci = freqCisImag.get(freqIdx); + + // Rotate Q + float v0q = sq.get(i); + float v1q = sq.get(i + 1); + sq.set(i, v0q * fcr - v1q * fci); + sq.set(i + 1, v0q * fci + v1q * fcr); + + // Rotate K AND write to cache + if (i + 1 < kvDim) { + float v0k = sk.get(i); + float v1k = sk.get(i + 1); + float rotated0 = v0k * fcr - v1k * fci; + float rotated1 = v0k * fci + v1k * fcr; + + sk.set(i, rotated0); + sk.set(i + 1, rotated1); + + int cacheOffset = layer * contextLength * kvDim + pos * kvDim; + keyCache.set(cacheOffset + i, rotated0); + keyCache.set(cacheOffset + i + 1, rotated1); + + valueCache.set(cacheOffset + i, sv.get(i)); + valueCache.set(cacheOffset + i + 1, sv.get(i + 1)); + } + } + } + public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) { int totalSize = dimQ + 2 * dimKV; @@ -1280,6 +1325,66 @@ public static void fusedQKVMatmulX( } } + /** + * Fused QKV matmul for FP16 models where Q output dim != input dim. 
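+ * One workgroup reduces one output row; rows are partitioned across Q, K and V
+ * (a sketch of the indexing, using the Devstral shapes documented in package-info):
+ * <pre>{@code
+ * // rowId in [0, qDim)                    -> Q row rowId              (0..4095)
+ * // rowId in [qDim, qDim + kvDim)         -> K row rowId - qDim
+ * // rowId in [qDim + kvDim, qDim+2*kvDim) -> V row rowId - qDim - kvDim
+ * // launched with (qDim + 2*kvDim) workgroups of localWorkGroupSize threads each
+ * }</pre>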
+ */ + public static void fusedQKVMatmulXNonSquare( + KernelContext context, + HalfFloatArray x, FloatArray q, FloatArray k, FloatArray v, + HalfFloatArray wq, HalfFloatArray wk, HalfFloatArray wv, + int inputDim, int qDim, int kvDim, int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < qDim) { + int rowOffset = rowId * inputDim; + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wq.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + localSum[localId] = partialSum; + context.localBarrier(); + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { localSum[localId] += localSum[localId + stride]; } + context.localBarrier(); + } + if (localId == 0) { q.set(rowId, localSum[0]); } + + } else if (rowId < qDim + kvDim) { + int kRow = rowId - qDim; + int rowOffset = kRow * inputDim; + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wk.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + localSum[localId] = partialSum; + context.localBarrier(); + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { localSum[localId] += localSum[localId + stride]; } + context.localBarrier(); + } + if (localId == 0) { k.set(kRow, localSum[0]); } + + } else if (rowId < qDim + 2 * kvDim) { + int vRow = rowId - qDim - kvDim; + int rowOffset = vRow * inputDim; + float partialSum = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partialSum += wv.get(rowOffset + j).getFloat32() * x.get(j).getFloat32(); + } + localSum[localId] = partialSum; + context.localBarrier(); + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { localSum[localId] += localSum[localId + stride]; } + context.localBarrier(); + } + if (localId == 0) { v.set(vRow, localSum[0]); } + } + } + // @formatter:off public static void matrixVectorGeneric( KernelContext context, @@ -2504,6 +2609,163 @@ public static void fusedQKVMatmulQ8(KernelContext context, FloatArray x, FloatAr } } + /** + * Fused QKV matmul for models where Q output dim != input dim (e.g., Devstral 2). + * Separates inputDim (embedding size) from qDim (num_heads * head_dim). 
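+ * <p>Q8_0 byte layout assumed by the offset arithmetic below: 34 bytes per block of
+ * 32 weights (one 2-byte FP16 scale, then 32 int8 quants). Dequantizing weight
+ * (row, j) of a projection matrix (a sketch mirroring the kernel's indexing):
+ * <pre>{@code
+ * int blocksPerRow = (inputDim + 31) / 32;
+ * int blockBase    = (row * blocksPerRow + j / 32) * 34;       // block start, in bytes
+ * float w = ((float) wq.get(blockBase + 2 + (j % 32)))         // int8 quant
+ *         * wq.getHalfFloat(blockBase).getFloat32();           // FP16 scale
+ * }</pre>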
+ */ + public static void fusedQKVMatmulQ8NonSquare(KernelContext context, FloatArray x, FloatArray q, FloatArray k, FloatArray v, + ByteArray wq, ByteArray wk, ByteArray wv, int inputDim, int qDim, int kvDim, int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + + int blockSize = 32; + final int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (inputDim + blockSize - 1) / blockSize; + + float[] localSums = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowId < qDim) { + // ========== Q projection ========== + int rowBlockOffset = rowId * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < inputDim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + + float scaleFloat = wq.getHalfFloat(blockByteOffset).getFloat32(); + + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + partialSum1 += ((float) wq.get(quantsOffset)) * scaleFloat * x.get(j); + partialSum2 += ((float) wq.get(quantsOffset + 1)) * scaleFloat * x.get(j + 1); + partialSum3 += ((float) wq.get(quantsOffset + 2)) * scaleFloat * x.get(j + 2); + partialSum4 += ((float) wq.get(quantsOffset + 3)) * scaleFloat * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + for (int j = ((inputDim / 4) * 4) + localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + float scaleFloat = wq.getHalfFloat(blockByteOffset).getFloat32(); + partialSum += ((float) wq.get(blockByteOffset + 2 + withinBlockIdx)) * scaleFloat * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + q.set(rowId, localSums[0]); + } + + } else if (rowId < qDim + kvDim) { + // ========== K projection ========== + int kRow = rowId - qDim; + int rowBlockOffset = kRow * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < inputDim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + float scaleFloat = wk.getHalfFloat(blockByteOffset).getFloat32(); + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + partialSum1 += ((float) wk.get(quantsOffset)) * scaleFloat * x.get(j); + partialSum2 += ((float) wk.get(quantsOffset + 1)) * scaleFloat * x.get(j + 1); + partialSum3 += ((float) wk.get(quantsOffset + 2)) * scaleFloat * x.get(j + 2); + partialSum4 += ((float) wk.get(quantsOffset + 3)) * scaleFloat * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + for (int j = ((inputDim / 4) * 4) + localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + float scaleFloat = wk.getHalfFloat(blockByteOffset).getFloat32(); + partialSum += ((float) wk.get(blockByteOffset + 2 + 
withinBlockIdx)) * scaleFloat * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + k.set(kRow, localSums[0]); + } + + } else if (rowId < qDim + 2 * kvDim) { + // ========== V projection ========== + int vRow = rowId - qDim - kvDim; + int rowBlockOffset = vRow * blocksPerRow; + + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + for (int j = localId * 4; j < inputDim - 3; j += localWorkGroupSize * 4) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + float scaleFloat = wv.getHalfFloat(blockByteOffset).getFloat32(); + int quantsOffset = blockByteOffset + 2 + withinBlockIdx; + partialSum1 += ((float) wv.get(quantsOffset)) * scaleFloat * x.get(j); + partialSum2 += ((float) wv.get(quantsOffset + 1)) * scaleFloat * x.get(j + 1); + partialSum3 += ((float) wv.get(quantsOffset + 2)) * scaleFloat * x.get(j + 2); + partialSum4 += ((float) wv.get(quantsOffset + 3)) * scaleFloat * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + for (int j = ((inputDim / 4) * 4) + localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlockIdx = j % blockSize; + int blockByteOffset = (rowBlockOffset + blockIdx) * Q8_0_BLOCK_BYTES; + float scaleFloat = wv.getHalfFloat(blockByteOffset).getFloat32(); + partialSum += ((float) wv.get(blockByteOffset + 2 + withinBlockIdx)) * scaleFloat * x.get(j); + } + + localSums[localId] = partialSum; + context.localBarrier(); + + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + if (localId == 0) { + v.set(vRow, localSums[0]); + } + } + } + /** * Fully fused RMS normalization + FFN W1/W3 matmul with SiLU/GLU for Q8_0 weights. * Each workgroup redundantly computes RMS scale to avoid cross-workgroup sync. 
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java index f3bc3d3a..42d2dc0c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.tornadovm.layerplanner; +import org.beehive.gpullama3.inference.state.DevstralState; import org.beehive.gpullama3.inference.state.GraniteState; import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.LlamaState; @@ -8,12 +9,14 @@ import org.beehive.gpullama3.inference.state.Qwen3State; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.DevstralFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.DevstralQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.MistralQ8_0LayerPlanner; @@ -57,6 +60,7 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) { return switch (model.getModelType()) { case LLAMA_3 -> new LlamaFP16LayerPlanner((LlamaState) state, model); case MISTRAL -> new MistralFP16LayerPlanner((LlamaState) state, model); + case DEVSTRAL_2 -> new DevstralFP16LayerPlanner((DevstralState) state, model); case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model); @@ -71,6 +75,7 @@ private static GenericLayerPlanner createQ8_0Planner(State state, Model model) { return switch (model.getModelType()) { case LLAMA_3 -> new LlamaQ8_0LayerPlanner((LlamaState) state, model); case MISTRAL -> new MistralQ8_0LayerPlanner((LlamaState) state, model); + case DEVSTRAL_2 -> new DevstralQ8_0LayerPlanner((DevstralState) state, model); case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java new file mode 100644 index 00000000..f72cfeb7 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java @@ -0,0 +1,20 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; + +import org.beehive.gpullama3.inference.state.DevstralState; +import 
org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.DevstralFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; + +public class DevstralFP16LayerPlanner extends FP16LayerPlanner { + + public DevstralFP16LayerPlanner(DevstralState state, Model model) { + super(state, model); + this.activationLayer = new Activation("activationUpdate", state, weights, config); + this.ffnLayers = new DevstralFP16FFNLayers("devstralFFN", state, weights, config, schedulerType); + this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + createTornadoInferencePlan(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java new file mode 100644 index 00000000..b6bf9634 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java @@ -0,0 +1,20 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; + +import org.beehive.gpullama3.inference.state.DevstralState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; + +public class DevstralQ8_0LayerPlanner extends Q8_0LayerPlanner { + + public DevstralQ8_0LayerPlanner(DevstralState state, Model model) { + super(state, model); + this.activationLayer = new Activation("activationUpdate", state, weights, config); + this.ffnLayers = new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType); + this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + createTornadoInferencePlan(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java new file mode 100644 index 00000000..7fcc8717 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java @@ -0,0 +1,200 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import 
uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * FP16 FFN layers for Devstral 2 models. + * Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation. + */ +public class DevstralFP16FFNLayers extends AbstractFFNLayers { + + public DevstralFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, DevstralConfiguration config, SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + setupFFNLayers(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + int fusedQKVRows = config.qDim() + 2 * config.kvDim(); + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.qDim() / 2, 512); + + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + return tornadoForwardScheduler; + } + + // @formatter:off + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, 
state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantize, + context, state.wrapXbFP16, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulXNonSquare, + context, + state.wrapXbFP16, + state.wrapQ, state.wrapK, state.wrapV, + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + config.dim(), config.qDim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Use precomputed RoPE frequencies (YaRN-scaled) + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopyPrecomputed, + context, + state.positionHolder, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + weights.freq_cis_realFlat.asFloatArray(), + weights.freq_cis_imagFlat.asFloatArray(), + config.kvDim(), config.headSize(), layerIndex, config.contextLength()); + + configureAttention(unifiedLayer, layerIndex); + + // O projection: n=qDim (input), d=dim (output) + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asHalfFloatArray(), + config.qDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + state.wrapX, state.wrapHb, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.tempFFN, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asHalfFloatArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, state.wrapXbFP16, + weights.freq_cis_realFlat.asFloatArray(), + weights.freq_cis_imagFlat.asFloatArray()); + } else { + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + 
state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, state.wrapXbFP16, + weights.freq_cis_realFlat.asFloatArray(), + weights.freq_cis_imagFlat.asFloatArray()); + } + return unifiedLayer; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java new file mode 100644 index 00000000..0c5aedd6 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java @@ -0,0 +1,198 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Q8_0 FFN layers for Devstral 2 models. + * Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation. 
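+ * The task-graph layout mirrors DevstralFP16FFNLayers (fused QKV projection over qDim + 2 * kvDim rows,
+ * precomputed-RoPE cache update, qDim-wide O projection), but all projections run on Q8_0 byte-array
+ * kernels and the RMSNorm apply step writes FP32 activations instead of a quantized FP16 buffer.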
+ */ +public class DevstralQ8_0FFNLayers extends AbstractFFNLayers { + + public DevstralQ8_0FFNLayers(String taskGraphName, State state, LlamaTornadoWeights weights, DevstralConfiguration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + setupFFNLayers(); + } + + // @formatter:off + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + weights.woLayered[layerIndex].asByteArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, state.wrapXb, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulQ8NonSquare, + context, + state.wrapXb, + state.wrapQ, state.wrapK, state.wrapV, + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + config.dim(), config.qDim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Use precomputed RoPE frequencies (YaRN-scaled) + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopyPrecomputed, + context, + state.positionHolder, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + weights.freq_cis_realFlat.asFloatArray(), + weights.freq_cis_imagFlat.asFloatArray(), + config.kvDim(), config.headSize(), layerIndex, config.contextLength()); + + configureAttention(unifiedLayer, layerIndex); + + // O projection: n=qDim (input), d=dim (output) + unifiedLayer.task("attn_output_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asByteArray(), + config.qDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8, + context, + state.wrapX, state.wrapHb, + 
weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray(), + config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asByteArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + weights.freq_cis_realFlat.asFloatArray(), + weights.freq_cis_imagFlat.asFloatArray()); + } else { + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder, + weights.freq_cis_realFlat.asFloatArray(), + weights.freq_cis_imagFlat.asFloatArray()); + } + return unifiedLayer; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int fusedQkvGlobal = (config.qDim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.qDim() / 2, 512); + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + + return tornadoForwardScheduler; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if 
(schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), + config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } + } + // @formatter:on +}
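For reference, the worker-grid sizing that both Devstral FFN classes derive from the model geometry (headSize 128, 32 query heads, 8 KV heads, hence qDim 4096 and kvDim 1024) can be checked with the short standalone sketch below. It is illustrative only and not part of the patch: the real LOCAL_WORK_GROUP_SIZE_ALLOC constant lives in AbstractFFNLayers, so the value 32 here is a hypothetical stand-in, and the class name is invented for the example.

public final class DevstralGridSizingSketch {
    public static void main(String[] args) {
        int headSize = 128;               // independent head_dim (not dim / numHeads)
        int qDim = 32 * headSize;         // 4096 = num_heads * head_dim
        int kvDim = 8 * headSize;         // 1024 = num_kv_heads * head_dim (GQA)
        int local = 32;                   // hypothetical LOCAL_WORK_GROUP_SIZE_ALLOC

        // qkv_projection: one work-group of `local` threads per output row,
        // covering Q, K, and V in a single fused matmul.
        int fusedQKVRows = qDim + 2 * kvDim;        // 4096 + 2048 = 6144 rows
        int fusedQKVGlobal = fusedQKVRows * local;  // global size passed to genericWorker(...)

        // rope_and_kv_cache: RoPE rotates (even, odd) float pairs, so qDim / 2 threads
        // cover Q; threads with index below kvDim / 2 also rotate K, and K/V are
        // copied into the caches at the current position.
        int ropeGlobal = qDim / 2;                  // 2048 threads, launched with local size 512

        System.out.printf("fusedQKV rows=%d global=%d, rope global=%d%n",
                fusedQKVRows, fusedQKVGlobal, ropeGlobal);
    }
}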