Commit 26b61db

Merge pull request #102 from orionpapadakis/feat/prefill-decode
Add Prefill–Decode Separation with Batched Prompt Ingestion and Logits Skipping
2 parents 214b58e + f65c0e8 commit 26b61db

25 files changed

Lines changed: 2729 additions & 195 deletions

.github/workflows/build-and-run.yml

Lines changed: 37 additions & 1 deletion
@@ -93,13 +93,49 @@ jobs:
           export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
           tornado --version
           ./mvnw clean package -DskipTests
-      - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf
+      - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Standard
         run: |
           cd ${{ github.workspace }}
           export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
           ./llama-tornado --gpu --${{ matrix.backend.name }} \
             --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
             --prompt "Say hello"
+      - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
+            --prompt "Say hello" \
+            --with-prefill-decode
+      - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
+            --prompt "Say hello" \
+            --with-prefill-decode --batch-prefill-size 32
+      - name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Prefill-Decode-CUDA-Graphs
+        if: matrix.backend.name == 'ptx'
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --ptx \
+            --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
+            --prompt "Say hello" \
+            --with-prefill-decode \
+            --cuda-graphs
+      - name: PTX - FP16 - Run Llama-3.2-1B-Instruct-F16.gguf - Batch-Prefill-Decode-CUDA-Graphs
+        if: matrix.backend.name == 'ptx'
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --ptx \
+            --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
+            --prompt "Say hello" \
+            --with-prefill-decode --batch-prefill-size 32 \
+            --cuda-graphs
       - name: FP16 - Run Qwen3-4B-f16.gguf
         run: |
           cd ${{ github.workspace }}

llama-tornado

Lines changed: 48 additions & 0 deletions
@@ -88,6 +88,17 @@ class LlamaRunner:
         if args.verbose_init:
             cmd.append("-Dllama.EnableTimingForTornadoVMInit=true")

+        if args.with_prefill_decode or args.batch_prefill_size is not None:
+            cmd.append("-Dllama.withPrefillDecode=true")
+
+        if args.batch_prefill_size is not None:
+            cmd.append(f"-Dllama.prefillBatchSize={args.batch_prefill_size}")
+
+        if args.cuda_graphs:
+            cmd.append("-Dllama.cudaGraphs=true")
+        elif args.no_cuda_graphs:
+            cmd.append("-Dllama.cudaGraphs=false")
+
         # Debug options
         debug_config = []

@@ -488,6 +499,43 @@ def create_parser() -> argparse.ArgumentParser:
         help="Execute the command after showing it (use with --show-command)",
     )

+    # Prefill/Decode optimization
+    prefill_group = parser.add_argument_group("Prefill/Decode Optimization")
+    prefill_group.add_argument(
+        "--with-prefill-decode",
+        dest="with_prefill_decode",
+        action="store_true",
+        help=(
+            "Enable prefill/decode separation. "
+            "Alone: sequential prefill (skip logits) + standard decode. "
+            "With --batch-prefill-size N (N>1): batched GPU prefill via TornadoVMMasterPlanWithBatchPrefillDecode."
+        ),
+    )
+    prefill_group.add_argument(
+        "--batch-prefill-size",
+        dest="batch_prefill_size",
+        type=int,
+        default=None,
+        metavar="N",
+        help=(
+            "Prefill chunk size (requires --with-prefill-decode). "
+            "N=1: sequential prefill (same as --with-prefill-decode alone). "
+            "N>1: batched prefill processing N tokens per chunk (llama.prefillBatchSize=N)."
+        ),
+    )
+    prefill_group.add_argument(
+        "--cuda-graphs",
+        dest="cuda_graphs",
+        action="store_true",
+        help="Enable CUDA graph capture/replay (llama.cudaGraphs=true); PTX backend only",
+    )
+    prefill_group.add_argument(
+        "--no-cuda-graphs",
+        dest="no_cuda_graphs",
+        action="store_true",
+        help="Disable CUDA graph capture/replay (llama.cudaGraphs=false); this is already the default",
+    )
+
     # Advanced options
     advanced_group = parser.add_argument_group("Advanced Options")
     advanced_group.add_argument(
src/main/java/org/beehive/gpullama3/Options.java

Lines changed: 15 additions & 3 deletions
@@ -5,14 +5,20 @@
 import java.nio.file.Paths;

 public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo,
-                      boolean useTornadovm) {
+                      boolean useTornadovm, boolean withPrefillDecode, int batchPrefillSize) {

     public static final int DEFAULT_MAX_TOKENS = 1024;

     public Options {
         require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
         require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
         require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
+        require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1");
+        require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode");
+        // Publish to system properties so TornadoVMMasterPlan and Llama read the right values
+        // even when the JAR is invoked directly (without the Python launcher).
+        if (withPrefillDecode) System.setProperty("llama.withPrefillDecode", "true");
+        if (batchPrefillSize > 1) System.setProperty("llama.prefillBatchSize", String.valueOf(batchPrefillSize));
     }

     static void require(boolean condition, String messageFormat, Object... args) {

@@ -44,6 +50,8 @@ public static void printUsage(PrintStream out) {
         out.println("  --max-tokens, -n <int>       number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
         out.println("  --stream <boolean>           print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
         out.println("  --echo <boolean>             print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
+        out.println("  --with-prefill-decode        enable prefill/decode separation (skip logits during prefill)");
+        out.println("  --batch-prefill-size <int>   batched prefill chunk size; requires --with-prefill-decode, must be >= 1; N > 1 enables batched CPU/GPU prefill");
         out.println();
     }

@@ -61,7 +69,7 @@ public static Options getDefaultOptions() {
         boolean echo = false;
         boolean useTornadoVM = getDefaultTornadoVM();

-        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM);
+        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false, 1);
     }

     public static Options parseOptions(String[] args) {

@@ -77,13 +85,16 @@ public static Options parseOptions(String[] args) {
         boolean stream = false;
         boolean echo = false;
         Boolean useTornadovm = null; // null means not specified via command line
+        boolean withPrefillDecode = false;
+        int batchPrefillSize = 1;

         for (int i = 0; i < args.length; i++) {
             String optionName = args[i];
             require(optionName.startsWith("-"), "Invalid option %s", optionName);
             switch (optionName) {
                 case "--interactive", "--chat", "-i" -> interactive = true;
                 case "--instruct" -> interactive = false;
+                case "--with-prefill-decode" -> withPrefillDecode = true;
                 case "--help", "-h" -> {
                     printUsage(System.out);
                     System.exit(0);

@@ -111,6 +122,7 @@ public static Options parseOptions(String[] args) {
                 case "--stream" -> stream = Boolean.parseBoolean(nextArg);
                 case "--echo" -> echo = Boolean.parseBoolean(nextArg);
                 case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg);
+                case "--batch-prefill-size" -> batchPrefillSize = Integer.parseInt(nextArg);
                 default -> require(false, "Unknown option: %s", optionName);
             }
         }

@@ -123,6 +135,6 @@ public static Options parseOptions(String[] args) {
             useTornadovm = getDefaultTornadoVM();
         }

-        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm);
+        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, withPrefillDecode, batchPrefillSize);
     }
 }
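
Taken together, the parser additions and the compact constructor give the following behavior. A small, hypothetical usage sketch; the --model flag and the model path are assumed here, only the two new flags and the record accessors come from the diff above:

    // Hypothetical invocation of the extended parser.
    Options opts = Options.parseOptions(new String[] {
            "--model", "Llama-3.2-1B-Instruct-F16.gguf",   // assumed existing flag
            "--prompt", "Say hello",
            "--with-prefill-decode",
            "--batch-prefill-size", "32"
    });
    // opts.withPrefillDecode() == true and opts.batchPrefillSize() == 32;
    // the compact constructor has also published llama.withPrefillDecode=true
    // and llama.prefillBatchSize=32 as system properties.
    //
    // Passing "--batch-prefill-size", "32" WITHOUT "--with-prefill-decode"
    // fails the second require(...) in the compact constructor.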
src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java

Lines changed: 199 additions & 0 deletions

@@ -0,0 +1,199 @@
package org.beehive.gpullama3.inference;

import org.beehive.gpullama3.auxiliary.Parallel;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

/**
 * Low-level forward passes for the batched prefill/decode inference path (Phase 3/4).
 *
 * <p>Parallel to {@link InferenceCoreWithPrefillDecode} — does NOT modify it.</p>
 *
 * <p>Provides three operations:</p>
 * <ul>
 *   <li>{@link #batchForwardJavaPrefill} — CPU batch prefill: processes a chunk of
 *       prompt tokens in one pass using batch matmul, avoiding redundant weight
 *       traversals. Only the KV cache is populated; logits are intentionally omitted.</li>
 *   <li>{@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk
 *       to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.</li>
 *   <li>{@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to
 *       {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which
 *       handles the embedding copy and runs the full decode + logits graphs.</li>
 * </ul>
 */
public final class InferenceCoreBatchPrefillDecode {

    private InferenceCoreBatchPrefillDecode() {}

    /**
     * CPU batched prefill forward pass for LLaMA (Phase 3).
     *
     * <p>Processes {@code batchSize} prompt tokens simultaneously through all
     * transformer layers. For each layer, Q/K/V projections, output projection,
     * and FFN projections are computed via batch matmul
     * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}),
     * which parallelises over both output dimension and batch simultaneously.
     * Attention reuses {@code state.att} sequentially per token (parallel per
     * head within each token), keeping memory overhead minimal.</p>
     *
     * <p>The logits layer is intentionally omitted — only the KV cache matters
     * for prefill positions.</p>
     *
     * @param model     the LLaMA model (must carry {@link StandardWeights})
     * @param state     mutable inference state (KV cache, att buffer …)
     * @param tokens    input token ids, {@code tokens[b]} at position {@code startPos+b}
     * @param startPos  sequence position of {@code tokens[0]}
     * @param batchSize number of tokens in this chunk ({@code tokens.length})
     */
    public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) {
        final Configuration config = model.configuration();
        final StandardWeights weights = (StandardWeights) model.weights();
        int dim = config.dim();
        int headSize = config.headSize();
        int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
        int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
        float sqrtHeadSize = (float) Math.sqrt(headSize);

        // ── Batch activation tensors (allocated once per chunk) ───────────────
        FloatTensor[] x = new FloatTensor[batchSize];
        FloatTensor[] xb = new FloatTensor[batchSize];
        FloatTensor[] xb2 = new FloatTensor[batchSize];
        FloatTensor[] q = new FloatTensor[batchSize];
        FloatTensor[] k = new FloatTensor[batchSize];
        FloatTensor[] v = new FloatTensor[batchSize];
        FloatTensor[] hb = new FloatTensor[batchSize];
        FloatTensor[] hb2 = new FloatTensor[batchSize];
        for (int b = 0; b < batchSize; b++) {
            x[b] = ArrayFloatTensor.allocate(dim);
            xb[b] = ArrayFloatTensor.allocate(dim);
            xb2[b] = ArrayFloatTensor.allocate(dim);
            q[b] = ArrayFloatTensor.allocate(dim);
            k[b] = ArrayFloatTensor.allocate(kvDim);
            v[b] = ArrayFloatTensor.allocate(kvDim);
            hb[b] = ArrayFloatTensor.allocate(config.hiddenDim());
            hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim());
        }

        // ── Token embeddings ──────────────────────────────────────────────────
        Parallel.parallelFor(0, batchSize, b ->
                weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim));

        // ── Transformer layers ────────────────────────────────────────────────
        for (int l = 0; l < config.numberOfLayers(); l++) {
            final int layer = l;

            Parallel.parallelFor(0, batchSize, b ->
                    InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps()));

            weights.wq[l].matmul(batchSize, xb, q, dim, dim);
            weights.wk[l].matmul(batchSize, xb, k, kvDim, dim);
            weights.wv[l].matmul(batchSize, xb, v, kvDim, dim);

            Parallel.parallelFor(0, batchSize, b -> {
                int pos = startPos + b;
                for (int i = 0; i < dim; i += 2) {
                    int head_dim = i % headSize;
                    float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2));
                    float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2));
                    int rotn = i < kvDim ? 2 : 1;
                    for (int vv = 0; vv < rotn; vv++) {
                        FloatTensor vec = vv == 0 ? q[b] : k[b];
                        float v0 = vec.getFloat(i);
                        float v1 = vec.getFloat(i + 1);
                        vec.setFloat(i, v0 * fcr - v1 * fci);
                        vec.setFloat(i + 1, v0 * fci + v1 * fcr);
                    }
                }
                k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim);
                v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim);
            });

            for (int b = 0; b < batchSize; b++) {
                final int pos_b = startPos + b;
                final int bFinal = b;
                Parallel.parallelFor(0, config.numberOfHeads(), h -> {
                    int qOffset = h * headSize;
                    int attOffset = h * config.contextLength();

                    for (int t = 0; t <= pos_b; t++) {
                        int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
                        float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize;
                        state.att.setFloat(attOffset + t, score);
                    }
                    state.att.softmaxInPlace(attOffset, pos_b + 1);

                    int xbOffset = h * headSize;
                    xb[bFinal].fillInPlace(xbOffset, headSize, 0f);
                    for (int t = 0; t <= pos_b; t++) {
                        int vOffset = t * kvDim + (h / kvMul) * headSize;
                        float a = state.att.getFloat(attOffset + t);
                        xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a);
                    }
                });
            }

            weights.wo[l].matmul(batchSize, xb, xb2, dim, dim);

            Parallel.parallelFor(0, batchSize, b -> {
                x[b].addInPlace(xb2[b]);
                InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps());
            });

            weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim);
            weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim);

            Parallel.parallelFor(0, batchSize, b -> {
                hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
                hb[b].multiplyInPlace(hb2[b]);
            });

            weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim());

            Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b]));
        }
        // Final RMSNorm and vocab projection intentionally omitted —
        // logits are not needed for any token in a prefill batch.
    }

    /**
     * GPU batched prefill forward pass (Phase 4).
     *
     * <p>Delegates the full chunk to
     * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill},
     * which handles embedding lookup and GPU execution internally.</p>
     *
     * @param model     the LLaMA model
     * @param tokens    token ids for this chunk
     * @param startPos  sequence position of {@code tokens[0]}
     * @param chunkSize number of tokens in this chunk
     * @param plan      the batched prefill/decode GPU plan
     */
    public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize,
                                                    TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize);
    }

    /**
     * GPU decode forward pass (Phase 4).
     *
     * <p>Delegates a single-token decode step to
     * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode},
     * which copies the token embedding and runs the decode + logits graphs.</p>
     *
     * @param model    the LLaMA model
     * @param token    current token id
     * @param position sequence position
     * @param plan     the batched prefill/decode GPU plan
     * @return logits array for token sampling
     */
    public static FloatArray forwardTornadoVMDecode(Model model, int token, int position,
                                                    TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        return plan.tornadoVMForwardDecode(token, position, model);
    }
}
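
For orientation, here is a minimal, hypothetical generation loop composing these operations on the Phase 4 GPU path. It assumes an int[] promptTokens, a Model model, a TornadoVMMasterPlanWithBatchPrefillDecode plan, an int maxTokens, and a sample(FloatArray) helper already exist in the caller; none of those names are defined in this commit's file:

    // Hypothetical driver: chunked prefill (no logits), then per-token decode.
    int batchSize = Integer.getInteger("llama.prefillBatchSize", 1);
    int n = promptTokens.length;
    int pos = 0;

    // Prefill every prompt token except the last in chunks of batchSize;
    // only the KV cache is filled, the logits layers are skipped.
    while (pos < n - 1) {
        int chunk = Math.min(batchSize, (n - 1) - pos);
        int[] tokens = java.util.Arrays.copyOfRange(promptTokens, pos, pos + chunk);
        InferenceCoreBatchPrefillDecode.batchForwardTornadoVMPrefill(model, tokens, pos, chunk, plan);
        pos += chunk;
    }

    // Decode: feed the last prompt token first, then each sampled token,
    // computing logits at every step.
    int token = promptTokens[n - 1];
    for (; pos < maxTokens; pos++) {
        FloatArray logits = InferenceCoreBatchPrefillDecode.forwardTornadoVMDecode(model, token, pos, plan);
        token = sample(logits); // hypothetical sampler
    }

Substituting batchForwardJavaPrefill(model, state, tokens, pos, chunk) in the prefill loop would give the Phase 3 CPU variant of the same scheme.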
