Skip to content

Commit 9aff199

Browse files
[prf/dec] Provide distinct support for standard, prefill-decode and batched-prefill-decode execution paths for both CPU and GPU
1 parent 04dcd8e commit 9aff199

6 files changed

Lines changed: 93 additions & 42 deletions

File tree

llama-tornado

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ class LlamaRunner:
8787
if args.verbose_init:
8888
cmd.append("-Dllama.EnableTimingForTornadoVMInit=true")
8989

90-
if args.batched_prefill:
91-
cmd.append("-Dllama.batchedPrefill=true")
90+
if args.with_prefill_decode or args.batch_prefill_size is not None:
91+
cmd.append("-Dllama.withPrefillDecode=true")
9292

93-
if args.prefill_batch_size is not None:
94-
cmd.append(f"-Dllama.prefillBatchSize={args.prefill_batch_size}")
93+
if args.batch_prefill_size is not None:
94+
cmd.append(f"-Dllama.prefillBatchSize={args.batch_prefill_size}")
9595

9696
if args.no_cuda_graphs:
9797
cmd.append("-Dllama.cudaGraphs=false")
@@ -484,17 +484,26 @@ def create_parser() -> argparse.ArgumentParser:
484484
# Prefill/Decode optimization
485485
prefill_group = parser.add_argument_group("Prefill/Decode Optimization")
486486
prefill_group.add_argument(
487-
"--batched-prefill",
488-
dest="batched_prefill",
487+
"--with-prefill-decode",
488+
dest="with_prefill_decode",
489489
action="store_true",
490-
help="Enable batched prefill/decode separation (llama.batchedPrefill=true)",
490+
help=(
491+
"Enable prefill/decode separation. "
492+
"Alone: sequential prefill (skip logits) + standard decode. "
493+
"With --batch-prefill-size N (N>1): batched GPU prefill via TornadoVMMasterPlanWithBatchPrefillDecode."
494+
),
491495
)
492496
prefill_group.add_argument(
493-
"--prefill-batch-size",
494-
dest="prefill_batch_size",
497+
"--batch-prefill-size",
498+
dest="batch_prefill_size",
495499
type=int,
496500
default=None,
497-
help="Prefill chunk/batch size (llama.prefillBatchSize=N, default: 32)",
501+
metavar="N",
502+
help=(
503+
"Prefill chunk size (requires --with-prefill-decode). "
504+
"N=1: sequential prefill (same as --with-prefill-decode alone). "
505+
"N>1: batched prefill processing N tokens per chunk (llama.prefillBatchSize=N)."
506+
),
498507
)
499508
prefill_group.add_argument(
500509
"--no-cuda-graphs",

src/main/java/org/beehive/gpullama3/Options.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,20 @@
55
import java.nio.file.Paths;
66

77
public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo,
8-
boolean useTornadovm) {
8+
boolean useTornadovm, boolean withPrefillDecode, int batchPrefillSize) {
99

1010
public static final int DEFAULT_MAX_TOKENS = 1024;
1111

1212
public Options {
1313
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
1414
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
1515
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
16+
require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1");
17+
require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode");
18+
// Publish to system properties so TornadoVMMasterPlan and Llama read the right values
19+
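// even when the JAR is invoked directly (without the Python launcher). This must run before those types are initialized, because they cache the properties in static constants.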
20+
if (withPrefillDecode) System.setProperty("llama.withPrefillDecode", "true");
21+
if (batchPrefillSize > 1) System.setProperty("llama.prefillBatchSize", String.valueOf(batchPrefillSize));
1622
}
1723

1824
static void require(boolean condition, String messageFormat, Object... args) {
@@ -44,6 +50,8 @@ public static void printUsage(PrintStream out) {
4450
out.println(" --max-tokens, -n <int> number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
4551
out.println(" --stream <boolean> print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
4652
out.println(" --echo <boolean> print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
53+
out.println(" --with-prefill-decode enable prefill/decode separation (skip logits during prefill)");
54+
out.println(" --batch-prefill-size <int> batched prefill chunk size; requires --with-prefill-decode, must be > 1, enables batched CPU/GPU prefill");
4755
out.println();
4856
}
4957

@@ -61,7 +69,7 @@ public static Options getDefaultOptions() {
6169
boolean echo = false;
6270
boolean useTornadoVM = getDefaultTornadoVM();
6371

64-
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM);
72+
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false, 1);
6573
}
6674

6775
public static Options parseOptions(String[] args) {
@@ -77,13 +85,16 @@ public static Options parseOptions(String[] args) {
7785
boolean stream = false;
7886
boolean echo = false;
7987
Boolean useTornadovm = null; // null means not specified via command line
88+
boolean withPrefillDecode = false;
89+
int batchPrefillSize = 1;
8090

8191
for (int i = 0; i < args.length; i++) {
8292
String optionName = args[i];
8393
require(optionName.startsWith("-"), "Invalid option %s", optionName);
8494
switch (optionName) {
8595
case "--interactive", "--chat", "-i" -> interactive = true;
8696
case "--instruct" -> interactive = false;
97+
case "--with-prefill-decode" -> withPrefillDecode = true;
8798
case "--help", "-h" -> {
8899
printUsage(System.out);
89100
System.exit(0);
@@ -111,6 +122,7 @@ public static Options parseOptions(String[] args) {
111122
case "--stream" -> stream = Boolean.parseBoolean(nextArg);
112123
case "--echo" -> echo = Boolean.parseBoolean(nextArg);
113124
case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg);
125+
case "--batch-prefill-size" -> batchPrefillSize = Integer.parseInt(nextArg);
114126
default -> require(false, "Unknown option: %s", optionName);
115127
}
116128
}
@@ -123,6 +135,6 @@ public static Options parseOptions(String[] args) {
123135
useTornadovm = getDefaultTornadoVM();
124136
}
125137

126-
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm);
138+
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, withPrefillDecode, batchPrefillSize);
127139
}
128140
}

src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.beehive.gpullama3.tokenizer.Tokenizer;
1212
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
1313
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
14-
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanStandard;
1514
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
1615

1716
import java.util.ArrayList;
@@ -36,8 +35,8 @@
3635
* Behaviour is identical to the baseline decode path.</li>
3736
* </ol>
3837
*
39-
* <p>Activated by {@code -Dllama.batchedPrefill=true} (set via
40-
* {@code --batched-prefill} in the Python launcher).</p>
38+
* <p>Activated by {@code -Dllama.withPrefillDecode=true} (set via
39+
* {@code --with-prefill-decode} in the Python launcher).</p>
4140
*/
4241
public final class InferenceEngineWithPrefillDecode {
4342

@@ -269,11 +268,10 @@ public static List<Integer> generateTokensGPULlama(
269268
} else {
270269
// ── Phase 2: Sequential GPU Prefill + Decode ─────────────────────────
271270

272-
// Thin wrapper: no new TornadoVM plan created, just holds the reference
273-
// Plan is a TornadoVMMasterPlanStandard when PREFILL_BATCH_SIZE == 1.
271+
// Plan was initialized by TornadoVMMasterPlan.initializeTornadoVMPlan as
272+
// TornadoVMMasterPlanWithPrefillDecode when WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE == 1.
274273
TornadoVMMasterPlanWithPrefillDecode prefillPlan =
275-
new TornadoVMMasterPlanWithPrefillDecode(
276-
(TornadoVMMasterPlanStandard) tornadoVMPlan, state, model);
274+
(TornadoVMMasterPlanWithPrefillDecode) tornadoVMPlan;
277275

278276
// ── Phase 1: Prefill (GPU, no logits) ────────────────────────────────
279277
for (int promptIndex = 0; promptIndex < promptTokens.size() && pos < actualMaxTokens; promptIndex++) {

src/main/java/org/beehive/gpullama3/model/llama/Llama.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
public class Llama extends AbstractModel {
2222

23-
static final boolean BATCHED_PREFILL = Boolean.getBoolean("llama.batchedPrefill");
23+
static final boolean WITH_PREFILL_DECODE = Boolean.getBoolean("llama.withPrefillDecode");
2424

2525
LlamaConfiguration configuration;
2626

@@ -66,7 +66,7 @@ public void forward(State state, int token, int position) {
6666
@Override
6767
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
6868
IntConsumer onTokenGenerated) {
69-
if (BATCHED_PREFILL) {
69+
if (WITH_PREFILL_DECODE) {
7070
return InferenceEngineWithPrefillDecode.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
7171
}
7272
return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
@@ -75,7 +75,7 @@ public List<Integer> generateTokens(State state, int startPosition, List<Integer
7575
@Override
7676
public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
7777
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
78-
if (BATCHED_PREFILL) {
78+
if (WITH_PREFILL_DECODE) {
7979
return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
8080
}
8181
return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,41 +30,52 @@ public interface TornadoVMMasterPlan {
3030
boolean CUDA_GRAPHS = Boolean.parseBoolean(
3131
System.getProperty("llama.cudaGraphs", "true"));
3232

33-
/**
34-
* Single-token forward pass returning output logits.
35-
*
36-
* <p>Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM})
37-
* and the Phase 2 sequential decode path. Not applicable to
38-
* {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.</p>
39-
*
40-
* @param position sequence position of the current token
41-
* @return logits array for token sampling
42-
*/
43-
FloatArray tornadoVMForwardExecuteLayered(int position);
33+
boolean WITH_PREFILL_DECODE = Boolean.getBoolean("llama.withPrefillDecode");
4434

45-
/** Releases all device memory held by this plan. */
46-
void freeTornadoExecutionPlan();
35+
int PREFILL_BATCH_SIZE = Integer.getInteger("llama.prefillBatchSize", 1);
4736

4837
/**
4938
* Factory: creates, JIT-compiles, and warms up the appropriate plan.
5039
*
51-
* <p>When {@code llama.prefillBatchSize > 1} a {@link TornadoVMMasterPlanWithBatchPrefillDecode}
52-
* is returned; otherwise a {@link TornadoVMMasterPlanStandard} is returned.</p>
40+
* <p>When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1},
41+
* a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned.
42+
* With {@code llama.withPrefillDecode=true} and a batch size of 1, a {@link TornadoVMMasterPlanWithPrefillDecode}
43+
* is returned (sequential prefill/decode path); otherwise a {@link TornadoVMMasterPlanStandard} is returned (baseline path).</p>
5344
*
5445
* @param state the model state (must be {@link LlamaState} when batch size {@code > 1})
5546
* @param model the model instance
5647
* @return the initialized plan, also stored via {@link Model#setTornadoVMPlan}
5748
*/
5849
static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
59-
int batchSize = Integer.getInteger("llama.prefillBatchSize", 1);
6050
TornadoVMMasterPlan plan;
61-
if (batchSize > 1) {
51+
52+
if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) {
53+
// GPU path with batched prefill/decode
6254
plan = TornadoVMMasterPlanWithBatchPrefillDecode.initializeUnifiedPlan(
63-
(LlamaState) state, model, batchSize);
55+
(LlamaState) state, model, PREFILL_BATCH_SIZE);
56+
} else if (WITH_PREFILL_DECODE) {
57+
// GPU path with simple prefill/decode
58+
plan = TornadoVMMasterPlanWithPrefillDecode.initialize(state, model);
6459
} else {
60+
// GPU path with no prefill/decode
6561
plan = TornadoVMMasterPlanStandard.initialize(state, model);
6662
}
6763
model.setTornadoVMPlan(plan);
6864
return plan;
6965
}
66+
67+
/**
68+
* Single-token forward pass returning output logits.
69+
*
70+
* <p>Used by the standard GPU path ({@link org.beehive.gpullama3.inference.InferenceCore#forwardTornadoVM})
71+
* and the Phase 2 sequential decode path. Not applicable to
72+
* {@link TornadoVMMasterPlanWithBatchPrefillDecode} — that plan uses its own typed methods.</p>
73+
*
74+
* @param position sequence position of the current token
75+
* @return logits array for token sampling
76+
*/
77+
FloatArray tornadoVMForwardExecuteLayered(int position);
78+
79+
/** Releases all device memory held by this plan. */
80+
void freeTornadoExecutionPlan();
7081
}

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@
2525
* <p>For decode, {@link #tornadoVMForwardDecode} delegates to the wrapped
2626
* plan's {@code tornadoVMForwardExecuteLayered}, preserving identical behaviour
2727
* to the baseline GPU path.</p>
28+
*
29+
* <p>Implements {@link TornadoVMMasterPlan} so it can be returned by the factory
30+
* and stored in the model; {@link #tornadoVMForwardExecuteLayered} delegates to
31+
* {@link #tornadoVMForwardDecode}.</p>
2832
*/
29-
public class TornadoVMMasterPlanWithPrefillDecode {
33+
public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan {
3034

3135
private final TornadoVMMasterPlanStandard plan;
3236
private final State state;
@@ -38,6 +42,12 @@ public TornadoVMMasterPlanWithPrefillDecode(TornadoVMMasterPlanStandard plan, St
3842
this.config = model.configuration();
3943
}
4044

45+
/** Factory: initializes the inner standard plan then wraps it. */
46+
public static TornadoVMMasterPlanWithPrefillDecode initialize(State state, Model model) {
47+
TornadoVMMasterPlanStandard inner = TornadoVMMasterPlanStandard.initialize(state, model);
48+
return new TornadoVMMasterPlanWithPrefillDecode(inner, state, model);
49+
}
50+
4151
/**
4252
* GPU prefill forward: runs preprocessing + all transformer layers, skips logits.
4353
*
@@ -76,4 +86,15 @@ public void tornadoVMForwardPrefill(int position) {
7686
public FloatArray tornadoVMForwardDecode(int position) {
7787
return plan.tornadoVMForwardExecuteLayered(position);
7888
}
89+
90+
/** Delegates to the wrapped plan's full forward pass (used by the standard decode path). */
91+
@Override
92+
public FloatArray tornadoVMForwardExecuteLayered(int position) {
93+
return tornadoVMForwardDecode(position);
94+
}
95+
96+
@Override
97+
public void freeTornadoExecutionPlan() {
98+
plan.freeTornadoExecutionPlan();
99+
}
79100
}

0 commit comments

Comments
 (0)