
Commit 1fdc3c0

[prf/dec] Update GPU execution plans to clarify prefill/decode structure and KV cache handling
1 parent d74a228 commit 1fdc3c0

4 files changed

Lines changed: 42 additions & 32 deletions


src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 14 additions & 9 deletions
@@ -8,18 +8,23 @@
 /**
  * Common contract for all TornadoVM GPU execution plans.
  *
- * <p>Two concrete implementations exist:</p>
+ * <p>Three concrete implementations exist:</p>
  * <ul>
- * <li>{@link TornadoVMMasterPlanStandard} — single-token forward pass; used for the
- * baseline GPU path and Phase 2 sequential prefill/decode.</li>
- * <li>{@link TornadoVMMasterPlanWithBatchPrefillDecode} — unified plan for Phase 4 batched
- * prefill + single-token decode within one {@code TornadoExecutionPlan}.</li>
+ * <li>{@link TornadoVMMasterPlanStandard} — baseline single-token forward pass
+ * (preprocessing + N layers + logits).</li>
+ * <li>{@link TornadoVMMasterPlanWithPrefillDecode} — sequential prefill/decode separation;
+ * reuses the same N layer graphs for both phases, skipping logits during prefill.</li>
+ * <li>{@link TornadoVMMasterPlanWithBatchPrefillDecode} — batched prefill + single-token
+ * decode; holds 2N+3 graphs in one plan to keep the KV cache on device across phases.</li>
  * </ul>
  *
- * <p>The {@link #initializeTornadoVMPlan} factory selects the appropriate implementation
- * based on {@code llama.prefillBatchSize}: if {@code > 1}, returns a
- * {@link TornadoVMMasterPlanWithBatchPrefillDecode}; otherwise returns a
- * {@link TornadoVMMasterPlanStandard}.</p>
+ * <p>The {@link #initializeTornadoVMPlan} factory selects the implementation based on
+ * {@code llama.withPrefillDecode} and {@code llama.prefillBatchSize}:</p>
+ * <ul>
+ * <li>{@code withPrefillDecode=false} → {@link TornadoVMMasterPlanStandard}</li>
+ * <li>{@code withPrefillDecode=true}, {@code prefillBatchSize=1} → {@link TornadoVMMasterPlanWithPrefillDecode}</li>
+ * <li>{@code withPrefillDecode=true}, {@code prefillBatchSize>1} → {@link TornadoVMMasterPlanWithBatchPrefillDecode}</li>
+ * </ul>
  */
 public interface TornadoVMMasterPlan {

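The three-way dispatch documented in this hunk can be sketched in plain Java. This is a minimal sketch: `PlanKind` and `selectPlanKind` are illustrative names, not the repository's actual factory code; only the three selection rules mirror the Javadoc above.

```java
// Sketch of the documented initializeTornadoVMPlan dispatch rules.
// PlanKind and selectPlanKind are illustrative, not gpullama3 API.
public class PlanSelectionSketch {

    enum PlanKind { STANDARD, PREFILL_DECODE, BATCH_PREFILL_DECODE }

    static PlanKind selectPlanKind(boolean withPrefillDecode, int prefillBatchSize) {
        if (!withPrefillDecode) {
            return PlanKind.STANDARD;           // baseline single-token plan
        }
        return prefillBatchSize > 1
                ? PlanKind.BATCH_PREFILL_DECODE // batched prefill + decode
                : PlanKind.PREFILL_DECODE;      // sequential prefill/decode
    }

    public static void main(String[] args) {
        System.out.println(selectPlanKind(false, 8));  // STANDARD
        System.out.println(selectPlanKind(true, 1));   // PREFILL_DECODE
        System.out.println(selectPlanKind(true, 16));  // BATCH_PREFILL_DECODE
    }
}
```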
src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java

Lines changed: 2 additions & 2 deletions
@@ -14,8 +14,7 @@
  * Standard (single-token) GPU execution plan.
  *
  * <p>Processes one token at a time through preprocessing + N transformer layers +
- * logits projection. Used for both the baseline GPU path and the Phase 2
- * sequential prefill/decode path.</p>
+ * logits projection.</p>
  */
 public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan {

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java

Lines changed: 14 additions & 9 deletions
@@ -29,20 +29,25 @@
 /**
  * GPU execution plan for batched prefill + single-token decode.
  *
- * <p>A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache
- * ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via
- * {@code persistOnDevice}/{@code consumeFromDevice}. Two separate plans would
- * allocate independent device buffers and lose the prefill KV state.</p>
+ * <p>A single {@link TornadoExecutionPlan} holds all {@link TaskGraph}s for the
+ * batched prefill and single-token decode phases, with the following structure:</p>
  *
- * <p>Graph layout (2N+3 graphs total):</p>
+ * <p>TaskGraph layout (2N+3 TaskGraphs total):</p>
  * <pre>
- * [0]          batch activation      B×dim FP16 → FP32
- * [1..N]       batch layer graphs    B tokens, all transformer ops
- * [N+1]        decode activation     single-token FP16 → FP32 + KV-cache pass-through
- * [N+2..2N+1]  decode layer graphs   single-token, standard kernels
+ * [0]          prefill batch activation    B×dim FP16 → FP32
+ * [1..N]       prefill batch layer graphs  B tokens, all transformer ops
+ * [N+1]        decode activation           single-token FP16 → FP32 + KV-cache pass-through
+ * [N+2..2N+1]  decode layer graphs         single-token, standard kernels
 * [2N+2]       logits graph
 * </pre>
 *
+ * <p>
+ * Incorporating the cross-phase {@link TaskGraph}s within a single {@link TornadoExecutionPlan}
+ * is necessary to enable KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) sharing
+ * across the prefill and decode phases. The KV cache pointers are chained across
+ * {@link TaskGraph}s via the {@code persistOnDevice}/{@code consumeFromDevice} API.
+ * </p>
+ *
 * <p>KV cache pointer chain across phases:</p>
 * <pre>
 * batchLayer[N-1] --persistOnDevice(wrapKeyCache)-→
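The 2N+3 layout documented above is plain index arithmetic. The sketch below makes the ranges explicit for a model with N transformer layers; the class and helper names are illustrative, not part of the repository.

```java
// Index arithmetic for the documented 2N+3 TaskGraph layout.
// Class and method names are illustrative, not repository code.
public class BatchPlanLayout {

    final int n; // number of transformer layers

    BatchPlanLayout(int n) { this.n = n; }

    int prefillActivation()      { return 0; }              // [0]
    int prefillLayer(int layer)  { return 1 + layer; }      // [1..N]
    int decodeActivation()       { return n + 1; }          // [N+1]
    int decodeLayer(int layer)   { return n + 2 + layer; }  // [N+2..2N+1]
    int logits()                 { return 2 * n + 2; }      // [2N+2]
    int totalGraphs()            { return 2 * n + 3; }

    public static void main(String[] args) {
        BatchPlanLayout p = new BatchPlanLayout(32); // e.g. a 32-layer model
        System.out.println(p.decodeActivation());    // 33
        System.out.println(p.logits());              // 66
        System.out.println(p.totalGraphs());         // 67
    }
}
```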

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java

Lines changed: 12 additions & 12 deletions
@@ -25,26 +25,26 @@
 import java.util.List;
 
 /**
- * GPU execution plan for single-token prefill/decode separation.
+ * GPU execution plan for sequential (single-token) prefill/decode separation.
  *
- * <p>Uses dedicated layer classes that carry correct cross-graph
- * {@code consumeFromDevice} source names for both CUDA-graph and interpreter
- * (no-CUDA-graph) mode. All graphs are owned by this plan and built from scratch —
- * no reuse of the standard execution path.</p>
+ * <p>A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache
+ * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on
+ * device across both phases. Prefill and decode reuse the same N layer graphs;
+ * only the logits graph is skipped during prefill.</p>
  *
  * <p>Graph layout (N+2 graphs total):</p>
  * <pre>
- * [0]     decodeActivation    single-token FP16 → FP32; KV-cache allocated on first execution
- * [1..N]  layer_0..layer_N-1  transformer layers (attention + FFN)
- * [N+1]   logits              final RMSNorm + wcls matmul
+ * [0]          decodeActivation    single-token FP16 → FP32; KV-cache allocated on first execution
+ * [1..N]       layer_0..layer_N-1  transformer layers (attention + FFN)
+ * [N+1]        logits              final RMSNorm + wcls matmul
  * </pre>
  *
- * <p>Two distinct forward passes:</p>
+ * <p>Two forward passes:</p>
  * <ul>
- * <li>{@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips logits.
- * KV cache is populated for each prompt token; logits are discarded.</li>
+ * <li>{@link #tornadoVMForwardPrefill} — graphs 0..N (activation + layers), logits skipped.
+ * Called once per prompt token; populates the KV cache.</li>
  * <li>{@link #tornadoVMForwardDecode} — full pass including logits.
- * Called for each generated token.</li>
+ * Called once per generated token; returns logits for sampling.</li>
  * </ul>
  */
 public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan {
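Under the N+2 layout documented in this hunk, the two forward passes differ only in whether the final logits graph runs. A plain-Java sketch of the executed graph ranges (method names are illustrative, not the plan's actual API):

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of which graph indices each documented forward pass executes.
// Method names are illustrative, not repository API.
public class PrefillDecodeRanges {

    // Prefill: graphs 0..N (activation + N layers); logits graph skipped.
    static List<Integer> prefillGraphs(int n) {
        List<Integer> ids = new ArrayList<>();
        for (int g = 0; g <= n; g++) ids.add(g);
        return ids;
    }

    // Decode: graphs 0..N+1, i.e. the full pass including logits.
    static List<Integer> decodeGraphs(int n) {
        List<Integer> ids = new ArrayList<>();
        for (int g = 0; g <= n + 1; g++) ids.add(g);
        return ids;
    }

    public static void main(String[] args) {
        System.out.println(prefillGraphs(4)); // [0, 1, 2, 3, 4]
        System.out.println(decodeGraphs(4));  // [0, 1, 2, 3, 4, 5]
    }
}
```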
