Skip to content

Commit a73a5ff

Browse files
authored
Merge pull request #92 from beehive-lab/feat/models/ibm-granite
[models] Support for IBM Granite Models 3.2, 3.3 & 4.0 with FP16 and Q8
2 parents 5383440 + 59eb425 commit a73a5ff

26 files changed

+2834
-8
lines changed

.github/workflows/build-and-run.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ jobs:
128128
./llama-tornado --gpu --${{ matrix.backend.name }} \
129129
--model /$MODELS_DIR/Phi-3-mini-4k-instruct-fp16.gguf \
130130
--prompt "Say hello"
131+
- name: FP16 - Run Granite-3.2-2b-instruct-f16.gguf
132+
run: |
133+
cd ${{ github.workspace }}
134+
export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
135+
./llama-tornado --gpu --${{ matrix.backend.name }} \
136+
--model /$MODELS_DIR/granite-3.2-2b-instruct-f16.gguf \
137+
--prompt "Say hello"
138+
- name: FP16 - Run Granite-4.0-1b-F16.gguf
139+
run: |
140+
cd ${{ github.workspace }}
141+
export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
142+
./llama-tornado --gpu --${{ matrix.backend.name }} \
143+
--model /$MODELS_DIR/granite-4.0-1b-F16.gguf \
144+
--prompt "Say hello"
131145
- name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf
132146
run: |
133147
cd ${{ github.workspace }}
@@ -163,3 +177,18 @@ jobs:
163177
./llama-tornado --gpu --${{ matrix.backend.name }} \
164178
--model $MODELS_DIR/Mistral-7B-Instruct-v0.3.Q8_0.gguf \
165179
--prompt "Say hello"
180+
- name: Q8 - Run Granite-3.2-2b-instruct-Q8_0.gguf
181+
run: |
182+
cd ${{ github.workspace }}
183+
export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
184+
./llama-tornado --gpu --${{ matrix.backend.name }} \
185+
--model /$MODELS_DIR/granite-3.2-2b-instruct-Q8_0.gguf \
186+
--prompt "Say hello"
187+
- name: Q8 - Run Granite-4.0-1b-Q8_0.gguf
188+
run: |
189+
cd ${{ github.workspace }}
190+
export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
191+
./llama-tornado --gpu --${{ matrix.backend.name }} \
192+
--model /$MODELS_DIR/granite-4.0-1b-Q8_0.gguf \
193+
--prompt "Say hello"
194+

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ clean:
1515
$(MVN) clean
1616

1717
install:
18-
$(MVN) install -DskipTests
18+
$(MVN) install -DskipTests
1919

2020
# Package the project without running tests
2121
package:

README.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
<strong>Llama3</strong> models written in <strong>native Java</strong> automatically accelerated on GPUs with <a href="https://github.com/beehive-lab/TornadoVM" target="_blank"><strong>TornadoVM</strong></a>.
2020
Runs Llama3 inference efficiently using TornadoVM's GPU acceleration.
2121
<br><br>
22-
Currently, supports <strong>Llama3</strong>, <strong>Mistral</strong>, <strong>Qwen2.5</strong>, <strong>Qwen3</strong> and <strong>Phi3</strong> models in the GGUF format.
22+
Currently, supports <strong>Llama3</strong>, <strong>Mistral</strong>, <strong>Qwen2.5</strong>, <strong>Qwen3</strong>, <strong>Phi-3</strong>, <strong>IBM Granite 3.2+</strong> and <strong>IBM Granite 4.0</strong> models in the GGUF format.
2323
Also, it is used as GPU inference engine in
2424
<a href="https://docs.quarkiverse.io/quarkus-langchain4j/dev/gpullama3-chat-model.html" target="_blank">Quarkus</a>
2525
and
@@ -89,7 +89,7 @@ All pre-built SDKs are available on the TornadoVM [Releases Page](https://github
8989
wget https://github.com/beehive-lab/TornadoVM/releases/download/v2.1.0/tornadovm-2.1.0-opencl-linux-amd64.zip
9090
unzip tornadovm-2.1.0-opencl-linux-amd64.zip
9191
# Replace <path-to-sdk> manually with the absolute path of the extracted folder
92-
export TORNADO_SDK="<path-to-sdk>/tornadovm-2.1.0-opencl"
92+
export TORNADOVM_HOME="<path-to-sdk>/tornadovm-2.1.0-opencl"
9393
export PATH=$TORNADOVM_HOME/bin:$PATH
9494

9595
tornado --devices
@@ -102,7 +102,7 @@ tornado --version
102102
wget https://github.com/beehive-lab/TornadoVM/releases/download/v2.1.0/tornadovm-2.1.0-opencl-mac-aarch64.zip
103103
unzip tornadovm-2.1.0-opencl-mac-aarch64.zip
104104
# Replace <path-to-sdk> manually with the absolute path of the extracted folder
105-
export TORNADO_SDK="<path-to-sdk>/tornadovm-2.1.0-opencl"
105+
export TORNADOVM_HOME="<path-to-sdk>/tornadovm-2.1.0-opencl"
106106
export PATH=$TORNADOVM_HOME/bin:$PATH
107107

108108
tornado --devices
@@ -251,7 +251,7 @@ You can run llama-tornado as a pure Java script using [JBang](https://www.jbang.
251251
### Prerequisites for JBang
252252

253253
1. **Install JBang**: Follow the [JBang installation guide](https://www.jbang.dev/download/)
254-
2. **TornadoVM SDK**: You still need TornadoVM installed and `TORNADO_SDK` environment variable set (see Setup section above)
254+
2. **TornadoVM SDK**: You still need TornadoVM installed and `TORNADOVM_HOME` environment variable set (see Setup section above)
255255

256256
### Quick Start with JBang
257257

@@ -295,6 +295,13 @@ jbang LlamaTornadoCli.java -m beehive-llama-3.2-1b-instruct-fp16.gguf \
295295
### Llama3.2 Collection
296296
[https://huggingface.co/collections/beehive-lab/llama3-gpullama3java](https://huggingface.co/collections/beehive-lab/llama3-gpullama3java)
297297

298+
### IBM Granite 4.0 Collection
299+
[https://huggingface.co/collections/beehive-lab/granite-40-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-40-language-models-gpullama3java)
300+
301+
302+
### IBM Granite 3.3 Collection
303+
[https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java)
304+
298305
### Qwen 2.5 Collection
299306
[https://huggingface.co/collections/beehive-lab/qwen-25-gpullama3java](https://huggingface.co/collections/beehive-lab/qwen-25-gpullama3java)
300307

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
1212
import org.beehive.gpullama3.model.Configuration;
1313
import org.beehive.gpullama3.model.Model;
14+
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
1415
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
1516
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
1617
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
@@ -546,6 +547,127 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
546547
return state.logits;
547548
}
548549

550+
/**
551+
* Forward pass for Granite models with µP scaling factors applied.
552+
* <p>
553+
* Granite uses the same transformer architecture as Llama but with maximal update parameterization (µP)
554+
* scaling factors applied at specific points:
555+
* <ul>
556+
* <li>Embedding scaling: multiply embeddings after lookup</li>
557+
* <li>Attention scaling: use custom multiplier instead of 1/sqrt(headDim)</li>
558+
* <li>Residual scaling: multiply residual connections</li>
559+
* <li>Logit scaling: divide logits by the scaling factor</li>
560+
* </ul>
561+
*/
562+
public static FloatTensor forwardGranite(Model model, State state, int token, int position) {
563+
final GraniteConfiguration config = (GraniteConfiguration) model.configuration();
564+
final StandardWeights weights = (StandardWeights) model.weights();
565+
int dim = config.dim();
566+
int headSize = config.headSize();
567+
int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
568+
int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
569+
float attentionScale = config.attentionScale();
570+
float residualScale = config.residualScale();
571+
float embeddingScale = config.embeddingScale();
572+
float logitScale = config.logitScale();
573+
574+
// copy the token embedding into x
575+
weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
576+
// Apply Granite embedding scaling
577+
state.x.mapInPlace(v -> v * embeddingScale);
578+
579+
// forward all the layers
580+
for (int l = 0; l < config.numberOfLayers(); l++) {
581+
// attention rmsnorm
582+
rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());
583+
584+
// qkv matmuls for this position
585+
weights.wq[l].matmul(state.xb, state.q, dim, dim);
586+
weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
587+
weights.wv[l].matmul(state.xb, state.v, kvDim, dim);
588+
589+
// RoPE relative positional encoding
590+
for (int i = 0; i < dim; i += 2) {
591+
int head_dim = i % headSize;
592+
float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2));
593+
float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2));
594+
int rotn = i < kvDim ? 2 : 1;
595+
for (int v = 0; v < rotn; v++) {
596+
FloatTensor vec = v == 0 ? state.q : state.k;
597+
float v0 = vec.getFloat(i);
598+
float v1 = vec.getFloat(i + 1);
599+
vec.setFloat(i, v0 * fcr - v1 * fci);
600+
vec.setFloat(i + 1, v0 * fci + v1 * fcr);
601+
}
602+
}
603+
604+
// save key,value at this time step to kv cache
605+
state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
606+
state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);
607+
608+
int curLayer = l;
609+
610+
// multihead attention with Granite attention scaling
611+
Parallel.parallelFor(0, config.numberOfHeads(), h -> {
612+
int qOffset = h * headSize;
613+
int attOffset = h * config.contextLength();
614+
615+
for (int t = 0; t <= position; t++) {
616+
int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
617+
float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
618+
// Granite uses custom attention multiplier instead of 1/sqrt(headSize)
619+
score *= attentionScale;
620+
state.att.setFloat(attOffset + t, score);
621+
}
622+
623+
state.att.softmaxInPlace(attOffset, position + 1);
624+
625+
int xbOffset = h * headSize;
626+
state.xb.fillInPlace(xbOffset, headSize, 0f);
627+
628+
for (int t = 0; t <= position; t++) {
629+
int vOffset = t * kvDim + (h / kvMul) * headSize;
630+
float a = state.att.getFloat(attOffset + t);
631+
state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
632+
}
633+
});
634+
635+
// final matmul to get the output of the attention
636+
weights.wo[l].matmul(state.xb, state.xb2, dim, dim);
637+
638+
// residual connection with Granite scaling
639+
state.xb2.mapInPlace(v -> v * residualScale);
640+
state.x.addInPlace(state.xb2);
641+
642+
// ffn rmsnorm
643+
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());
644+
645+
// FFN: self.w2(F.silu(self.w1(x)) * self.w3(x))
646+
weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
647+
weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);
648+
649+
// SwiGLU non-linearity
650+
state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
651+
state.hb.multiplyInPlace(state.hb2);
652+
653+
// final matmul to get the output of the ffn
654+
weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());
655+
656+
// residual connection with Granite scaling
657+
state.xb.mapInPlace(v -> v * residualScale);
658+
state.x.addInPlace(state.xb);
659+
}
660+
661+
rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());
662+
663+
weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);
664+
665+
// Apply Granite logit scaling (divide by the scaling factor)
666+
state.logits.mapInPlace(v -> v * logitScale);
667+
668+
return state.logits;
669+
}
670+
549671
static void copyChunk(FloatTensor in, FloatTensor out, int dim1In, int dim1Out, int nChunks, int chunkNo) {
550672
assert (dim1In == dim1Out * nChunks);
551673
final int startOffsetInDim1 = chunkNo * dim1Out;

src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,4 +531,138 @@ public static List<Integer> generateTokensGPUPhi3(Model model, State state, int
531531

532532
return generatedTokens;
533533
}
534+
535+
/**
536+
* Generates tokens using the Granite model with CPU inference.
537+
* Identical pattern to generateTokensLlama but calls forwardGranite.
538+
*/
539+
public static List<Integer> generateTokensGranite(Model model, State state, int startPosition,
540+
List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
541+
IntConsumer onTokenGenerated) {
542+
long startNanos = System.nanoTime();
543+
long inferenceStartNanos = 0;
544+
545+
Object logits;
546+
if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) {
547+
maxTokens = model.configuration().contextLength();
548+
}
549+
550+
List<Integer> generatedTokens = new ArrayList<>();
551+
552+
int currentToken = state.latestToken;
553+
int nextToken;
554+
int promptIndex = 0;
555+
int pos = startPosition;
556+
557+
while (pos < maxTokens) {
558+
// Call Granite-specific forward pass
559+
logits = InferenceCore.forwardGranite(model, state, currentToken, pos);
560+
561+
if (promptIndex < promptTokens.size()) {
562+
nextToken = promptTokens.get(promptIndex++);
563+
if (echo) {
564+
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
565+
}
566+
} else {
567+
if (inferenceStartNanos == 0) {
568+
inferenceStartNanos = System.nanoTime();
569+
}
570+
571+
nextToken = sampler.sampleToken(logits);
572+
573+
if (echo) {
574+
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
575+
}
576+
577+
generatedTokens.add(nextToken);
578+
579+
if (onTokenGenerated != null) {
580+
onTokenGenerated.accept(nextToken);
581+
}
582+
583+
if (stopTokens.contains(nextToken)) {
584+
break;
585+
}
586+
}
587+
588+
currentToken = nextToken;
589+
state.latestToken = currentToken;
590+
pos++;
591+
}
592+
593+
long endNanos = System.nanoTime();
594+
double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0;
595+
int totalTokens = promptIndex + generatedTokens.size();
596+
597+
LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds);
598+
599+
return generatedTokens;
600+
}
601+
602+
/**
603+
* Generates tokens using the Granite model with GPU (TornadoVM) inference.
604+
* Identical pattern to generateTokensGPULlama.
605+
*/
606+
public static List<Integer> generateTokensGPUGranite(Model model, State state, int startPosition,
607+
List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
608+
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMMasterPlan) {
609+
long startNanos = System.nanoTime();
610+
long inferenceStartNanos = 0;
611+
612+
Object logits;
613+
if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) {
614+
maxTokens = model.configuration().contextLength();
615+
}
616+
617+
List<Integer> generatedTokens = new ArrayList<>();
618+
619+
int currentToken = state.latestToken;
620+
int nextToken;
621+
int promptIndex = 0;
622+
int pos = startPosition;
623+
624+
while (pos < maxTokens) {
625+
// Call TornadoVM forward pass (same as Llama for now)
626+
logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMMasterPlan);
627+
628+
if (promptIndex < promptTokens.size()) {
629+
nextToken = promptTokens.get(promptIndex++);
630+
if (echo) {
631+
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
632+
}
633+
} else {
634+
if (inferenceStartNanos == 0) {
635+
inferenceStartNanos = System.nanoTime();
636+
}
637+
638+
nextToken = sampler.sampleToken(logits);
639+
640+
if (echo) {
641+
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
642+
}
643+
644+
generatedTokens.add(nextToken);
645+
646+
if (onTokenGenerated != null) {
647+
onTokenGenerated.accept(nextToken);
648+
}
649+
650+
if (stopTokens.contains(nextToken)) {
651+
break;
652+
}
653+
}
654+
655+
currentToken = nextToken;
656+
state.latestToken = currentToken;
657+
pos++;
658+
}
659+
660+
long endNanos = System.nanoTime();
661+
double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0;
662+
int totalTokens = promptIndex + generatedTokens.size();
663+
664+
LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds);
665+
666+
return generatedTokens;
667+
}
534668
}

0 commit comments

Comments
 (0)