90 changes: 90 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
@@ -12,6 +12,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
@@ -179,6 +180,95 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
return state.logits;
}

/**
* Forward pass for Devstral 2 models where head_dim != dim/num_heads.
* Q projection outputs qDim (num_heads * head_dim) instead of dim.
*/
public static FloatTensor forwardJavaDevstral(Model model, State state, int token, int position) {
final DevstralConfiguration config = (DevstralConfiguration) model.configuration();
final StandardWeights weights = (StandardWeights) model.weights();
int dim = config.dim();
int headSize = config.headSize(); // 128 (independent head_dim)
int qDim = config.qDim(); // 4096 = 32 * 128
int kvDim = config.kvDim(); // 1024 = 8 * 128
int kvMul = config.kvMul();
float sqrtHeadSize = (float) Math.sqrt(headSize);

weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);

for (int l = 0; l < config.numberOfLayers(); l++) {
rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());

weights.wq[l].matmul(state.xb, state.q, qDim, dim);
weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
weights.wv[l].matmul(state.xb, state.v, kvDim, dim);

// RoPE over qDim (not dim)
for (int i = 0; i < qDim; i += 2) {
int head_dim = i % headSize;
float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2));
float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2));
int rotn = i < kvDim ? 2 : 1;
for (int v = 0; v < rotn; v++) {
FloatTensor vec = v == 0 ? state.q : state.k;
float v0 = vec.getFloat(i);
float v1 = vec.getFloat(i + 1);
vec.setFloat(i, v0 * fcr - v1 * fci);
vec.setFloat(i + 1, v0 * fci + v1 * fcr);
}
}

state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);

int curLayer = l;

Parallel.parallelFor(0, config.numberOfHeads(), h -> {
int qOffset = h * headSize;
int attOffset = h * config.contextLength();

for (int t = 0; t <= position; t++) {
int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
score /= sqrtHeadSize;
state.att.setFloat(attOffset + t, score);
}

state.att.softmaxInPlace(attOffset, position + 1);

int xbOffset = h * headSize;
state.xb.fillInPlace(xbOffset, headSize, 0f);

for (int t = 0; t <= position; t++) {
int vOffset = t * kvDim + (h / kvMul) * headSize;
float a = state.att.getFloat(attOffset + t);
state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
}
});

// O projection: input qDim, output dim
weights.wo[l].matmul(state.xb, state.xb2, dim, qDim);

state.x.addInPlace(state.xb2);

rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());

weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);

state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
state.hb.multiplyInPlace(state.hb2);

weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());
state.x.addInPlace(state.xb);
}

rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());
weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

return state.logits;
}

public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) {
final Qwen2Configuration config = (Qwen2Configuration) model.configuration();
final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights();
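The attention loop in `forwardJavaDevstral` maps each query head onto a shared KV head via `h / kvMul` and strides the caches by `kvDim` per timestep. A minimal sketch of that offset arithmetic, using the illustrative dims from the comments above (32 heads, 8 KV heads, head_dim 128) rather than values read from a real model file:

```java
// Sketch only: mirrors the qOffset / keyCacheOffset computations in
// forwardJavaDevstral with the illustrative dims from its comments.
public final class DevstralOffsetsSketch {
    public static void main(String[] args) {
        int numHeads = 32, numKvHeads = 8, headSize = 128;
        int qDim = numHeads * headSize;        // 4096: width of the Q projection output
        int kvDim = numKvHeads * headSize;     // 1024: width of K/V projections and of each cache row
        int kvMul = numHeads / numKvHeads;     // 4: query heads sharing one KV head

        int position = 10;                     // example timestep
        for (int h = 0; h < numHeads; h += 8) {
            int qOffset = h * headSize;                                 // where head h's query starts in state.q
            int kvHead = h / kvMul;                                     // KV head this query head attends with
            int keyCacheOffset = position * kvDim + kvHead * headSize;  // row + column into keyCache[layer]
            System.out.printf("q head %2d -> kv head %d, qOffset %4d, keyCacheOffset %5d%n",
                    h, kvHead, qOffset, keyCacheOffset);
        }
        System.out.println("qDim=" + qDim + " kvDim=" + kvDim + " kvMul=" + kvMul);
    }
}
```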
46 changes: 46 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java
@@ -35,4 +35,50 @@ public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int h
assert contextLength * (headSize / 2) == n;
return new Pair<>(cr, ci);
}

public static Pair<float[], float[]> precomputeFreqsCisYaRN(int contextLength, int headSize, double theta,
float factor, float betaFast, float betaSlow, float logMultiplier, int originalContextLength) {
assert headSize % 2 == 0;
float[] cr = new float[contextLength * (headSize / 2)];
float[] ci = new float[contextLength * (headSize / 2)];

float freqScale = 1.0f / factor;

// Compute correction dimensions (start/end of the YaRN interpolation ramp)
float corrDim0 = yarnCorrDim(headSize, originalContextLength, betaFast, (float) theta);
float corrDim1 = yarnCorrDim(headSize, originalContextLength, betaSlow, (float) theta);

// Compute mscale (attention scaling for extended context)
// Formula: mscale = 0.1 * logMultiplier * log(factor) + 1.0
float mscale = logMultiplier > 0
? 1.0f + 0.1f * logMultiplier * (float) Math.log(1.0f / freqScale)
: 1.0f;

int n = 0;
for (int pos = 0; pos < contextLength; ++pos) {
for (int i = 0; i < headSize; i += 2) {
float freqExtrap = (float) (1.0 / Math.pow(theta, i / (double) headSize));
float freqInterp = freqScale * freqExtrap;

float rampMix = yarnRamp(corrDim0, corrDim1, i / 2);
float freq = freqInterp * (1.0f - rampMix) + freqExtrap * rampMix;

float val = pos * freq;
cr[n] = (float) Math.cos(val) * mscale;
ci[n] = (float) Math.sin(val) * mscale;
n++;
}
}
assert contextLength * (headSize / 2) == n;
return new Pair<>(cr, ci);
}

private static float yarnCorrDim(int nDims, int nCtxOrig, float nRot, float base) {
return nDims * (float) Math.log(nCtxOrig / (nRot * 2.0f * (float) Math.PI)) / (2.0f * (float) Math.log(base));
}

private static float yarnRamp(float low, float high, int i0) {
float y = (i0 - low) / Math.max(0.001f, high - low);
return 1.0f - Math.min(1.0f, Math.max(0.0f, y));
}
}
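The table returned by `precomputeFreqsCisYaRN` has the same layout as the standard `precomputeFreqsCis` output: `contextLength * (headSize / 2)` entries, with entry (pos, pairIndex) at `pos * (headSize / 2) + pairIndex`, which is the indexing `forwardJavaDevstral` relies on. A minimal sketch of consuming the cr/ci arrays, written against plain `float[]`s so no assumption is made about the Pair accessor names; note that the YaRN mscale is already folded into cos/sin above, so applying the table also applies the attention scaling:

```java
// Sketch only: rotation mirrors the RoPE loop in InferenceCore.forwardJavaDevstral.
final class YarnRopeSketch {
    static void applyRope(float[] vec, float[] cr, float[] ci, int position, int headSize) {
        for (int i = 0; i + 1 < vec.length; i += 2) {
            int pairIndex = (i % headSize) / 2;                    // pair index within the head
            float fcr = cr[position * (headSize / 2) + pairIndex]; // cos(pos * freq) * mscale
            float fci = ci[position * (headSize / 2) + pairIndex]; // sin(pos * freq) * mscale
            float v0 = vec[i];
            float v1 = vec[i + 1];
            vec[i]     = v0 * fcr - v1 * fci;
            vec[i + 1] = v0 * fci + v1 * fcr;
        }
    }
}
```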
77 changes: 77 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/state/DevstralState.java
@@ -0,0 +1,77 @@
package org.beehive.gpullama3.inference.state;

import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;

/**
* State for Devstral 2 models where head_dim != dim/num_heads.
* Allocates Q with qDim (num_heads * head_dim) and K/V with kvDim (num_kv_heads * head_dim).
*/
public final class DevstralState extends State {

public DevstralState(Configuration config, int batchsize) {
super(config, batchsize);
}

@Override
protected StateFields createStateFields(Configuration config) {
DevstralConfiguration dc = (DevstralConfiguration) config;
StateFields fields = new StateFields();

int qDim = dc.qDim();
int kvDim = dc.kvDim();

fields.x = ArrayFloatTensor.allocate(dc.dim());
fields.xb = ArrayFloatTensor.allocate(dc.dim());
fields.xb2 = ArrayFloatTensor.allocate(dc.dim());
fields.hb = ArrayFloatTensor.allocate(dc.hiddenDim());
fields.hb2 = ArrayFloatTensor.allocate(dc.hiddenDim());
fields.q = ArrayFloatTensor.allocate(qDim);
fields.k = ArrayFloatTensor.allocate(kvDim);
fields.v = ArrayFloatTensor.allocate(kvDim);
fields.att = ArrayFloatTensor.allocate(dc.numberOfHeads(), dc.contextLength());
fields.logits = ArrayFloatTensor.allocate(dc.vocabularySize());

fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(dc.contextLength(), kvDim)).limit(dc.numberOfLayers()).toArray(FloatTensor[]::new);
fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(dc.contextLength(), kvDim)).limit(dc.numberOfLayers()).toArray(FloatTensor[]::new);

// TornadoVM wrappers
fields.wrapX = new FloatArray(dc.dim());
fields.wrapXb = new FloatArray(dc.dim());
fields.wrapXb2 = new FloatArray(dc.dim());
fields.wrapHb = new FloatArray(dc.hiddenDim());
fields.wrapHb2 = new FloatArray(dc.hiddenDim());

switch (dc.quantization()) {
case "FP16" -> fields.createActivationFP16(dc.dim());
case "Q8_0" -> fields.createActivationQ8_0(dc.dim());
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + dc.quantization());
}
fields.wrapLogits = new FloatArray(dc.vocabularySize());
fields.wrapQ = new FloatArray(qDim);
fields.wrapK = new FloatArray(kvDim);
fields.wrapV = new FloatArray(kvDim);

fields.wrapXFP16 = new HalfFloatArray(dc.dim());
fields.wrapXbFP16 = new HalfFloatArray(dc.dim());
fields.wrapKeyCache = new FloatArray(dc.contextLength() * kvDim * dc.numberOfLayers());
fields.wrapValueCache = new FloatArray(dc.contextLength() * kvDim * dc.numberOfLayers());
fields.wrapValueCache.init(0.f);
fields.wrapKeyCache.init(0.f);
fields.wrapAtt = new FloatArray(dc.numberOfHeads() * dc.contextLength());
fields.positionHolder = new IntArray(1);

fields.temp = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize));
fields.tempFFN = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize));
fields.tempLogits = new FloatArray(1 + ((dc.dim() + localSize - 1) / localSize));

return fields;
}
}
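Because K and V are allocated with `kvDim` rather than `dim`, the cache footprint of this state is set by the KV heads alone, matching the `contextLength * kvDim * numberOfLayers` wrapper allocation above. A back-of-the-envelope sketch; the layer count and context length are placeholders for illustration, not values read from a model file:

```java
// Rough size arithmetic for the FP32 key/value caches allocated in DevstralState.
public final class DevstralCacheSizeSketch {
    public static void main(String[] args) {
        int numberOfLayers = 40;          // placeholder layer count
        int contextLength = 32768;        // placeholder context window
        int kvDim = 8 * 128;              // numberOfKeyValueHeads * headDim = 1024

        long floatsPerCache = (long) contextLength * kvDim * numberOfLayers;
        long bytes = floatsPerCache * Float.BYTES;
        System.out.printf("key cache:   %,d floats (%,d MiB)%n", floatsPerCache, bytes >> 20);
        System.out.printf("key + value: %,d MiB total%n", (2 * bytes) >> 20);
    }
}
```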
8 changes: 8 additions & 0 deletions src/main/java/org/beehive/gpullama3/model/ModelType.java
@@ -1,5 +1,6 @@
package org.beehive.gpullama3.model;

import org.beehive.gpullama3.model.loader.DevstralModelLoader;
import org.beehive.gpullama3.model.loader.GraniteLoader;
import org.beehive.gpullama3.tensor.GGUF;
import org.beehive.gpullama3.model.loader.LlamaModelLoader;
@@ -37,6 +38,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo
}
},

DEVSTRAL_2 {
@Override
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
return new DevstralModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel();
}
},

QWEN_2 {
@Override
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
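For completeness, a hedged sketch of how the new constant is meant to be used: given an already-opened FileChannel and parsed GGUF metadata (obtaining those is out of scope here), `DEVSTRAL_2` dispatches to `DevstralModelLoader` through the same `loadModel` entry point as the neighbouring constants:

```java
import java.nio.channels.FileChannel;

import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.tensor.GGUF;

// Sketch only: fileChannel and gguf are assumed to come from the existing GGUF loading path.
final class DevstralDispatchSketch {
    static Model loadDevstral(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
        return ModelType.DEVSTRAL_2.loadModel(fileChannel, gguf, contextLength, useTornadovm);
    }
}
```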
73 changes: 73 additions & 0 deletions src/main/java/org/beehive/gpullama3/model/devstral/Devstral.java
@@ -0,0 +1,73 @@
package org.beehive.gpullama3.model.devstral;

import org.beehive.gpullama3.inference.InferenceCore;
import org.beehive.gpullama3.inference.InferenceEngine;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.inference.state.DevstralState;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.AbstractModel;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.tokenizer.DevstralTokenizer;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;

import java.util.List;
import java.util.Set;
import java.util.function.IntConsumer;

public class Devstral extends AbstractModel {

DevstralConfiguration configuration;

public Devstral(DevstralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
super(tokenizer, weights, chatFormat, null);
this.configuration = configuration;
}

@Override
public DevstralConfiguration configuration() {
return configuration;
}

@Override
public DevstralTokenizer tokenizer() {
return (DevstralTokenizer) tokenizer;
}

@Override
public ModelType getModelType() {
return ModelType.DEVSTRAL_2;
}

public State createNewState() {
State state = new DevstralState(configuration(), -1);
state.latestToken = tokenizer.getSpecialTokens().get("<s>");
return state;
}

public State createNewState(int batchsize) {
State state = new DevstralState(configuration(), batchsize);
state.latestToken = tokenizer.getSpecialTokens().get("<s>");
return state;
}

@Override
public void forward(State state, int token, int position) {
InferenceCore.forwardJavaDevstral(this, state, token, position);
}

@Override
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
}

@Override
public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
}

}
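The model class wires CPU inference to `forwardJavaDevstral` and reuses the Llama generation loop, with `createNewState()` pre-seeding `latestToken` with the `<s>` token. A hedged sketch of driving a single prompt; how `model`, `promptTokens` and `sampler` are produced is assumed (model loader, tokenizer, sampler factory), and only methods declared in the class above are called:

```java
import java.util.List;
import java.util.Set;

import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.model.devstral.Devstral;

final class DevstralGenerationSketch {
    static List<Integer> runPrompt(Devstral model, List<Integer> promptTokens, Set<Integer> stopTokens,
                                   int maxTokens, Sampler sampler) {
        State state = model.createNewState();        // latestToken pre-seeded with <s>
        return model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler,
                /* echo */ false,
                tokenId -> System.out.println("generated token id: " + tokenId));
    }
}
```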
56 changes: 56 additions & 0 deletions src/main/java/org/beehive/gpullama3/model/devstral/DevstralConfiguration.java
@@ -0,0 +1,56 @@
package org.beehive.gpullama3.model.devstral;

import org.beehive.gpullama3.model.Configuration;

/**
* Configuration for Devstral 2 models (Mistral 3 architecture).
* Unlike standard Mistral, Devstral 2 has an independent head dimension
* (head_dim != dim / num_heads), requiring explicit key_length/value_length.
*/
// @formatter:off
public record DevstralConfiguration(String quantization,
int dim,
int hiddenDim,
int numberOfLayers,
int numberOfHeads,
int numberOfKeyValueHeads,
int headDim,
int vocabularySize,
int contextLength,
float rmsNormEps,
float ropeTheta) implements Configuration {

@Override public String quantization() {
return quantization;
}

/**
* Q projection output dimension = numberOfHeads * headDim.
* This differs from dim when headDim != dim/numberOfHeads.
*/
public int qDim() {
return numberOfHeads * headDim;
}

public int kvDim() {
return numberOfKeyValueHeads * headDim;
}

public int kvMul() {
return numberOfHeads / numberOfKeyValueHeads;
}

@Override
public int numberOfHeadsKey() {
throw new UnsupportedOperationException("Not supported for Devstral.");
}

@Override
public int contextLengthModel() {
throw new UnsupportedOperationException("Not supported for Devstral.");
}

public int headSize() {
return headDim;
}
}
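The derived accessors are the point of this record: with an independent `headDim`, `qDim()` and `kvDim()` no longer fall out of `dim`. A worked example using the same 32 x 128 / 8 x 128 head layout quoted in the `forwardJavaDevstral` comments; the remaining constructor arguments are illustrative placeholders, not real Devstral metadata:

```java
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;

final class DevstralConfigSketch {
    public static void main(String[] args) {
        DevstralConfiguration config = new DevstralConfiguration(
                "FP16",      // quantization
                5120,        // dim (placeholder residual-stream width, larger than qDim)
                32768,       // hiddenDim (placeholder FFN width)
                40,          // numberOfLayers (placeholder)
                32,          // numberOfHeads
                8,           // numberOfKeyValueHeads
                128,         // headDim (independent of dim / numberOfHeads)
                131072,      // vocabularySize (placeholder)
                32768,       // contextLength (placeholder)
                1e-5f,       // rmsNormEps
                1_000_000f); // ropeTheta (placeholder)

        System.out.println(config.qDim());   // 32 * 128 = 4096, not equal to dim
        System.out.println(config.kvDim());  //  8 * 128 = 1024
        System.out.println(config.kvMul());  // 32 / 8   = 4
    }
}
```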