11package org .beehive .gpullama3 .tornadovm ;
22
3+ import org .beehive .gpullama3 .inference .state .LlamaState ;
34import org .beehive .gpullama3 .inference .state .State ;
4- import org .beehive .gpullama3 .model . Configuration ;
5+ import org .beehive .gpullama3 .inference . weights . tornado . LlamaTornadoWeights ;
56import org .beehive .gpullama3 .model .Model ;
6- import org .beehive .gpullama3 .tensor .GGMLType ;
7- import org .beehive .gpullama3 .tornadovm .layerplanner .GenericLayerPlanner ;
8- import org .beehive .gpullama3 .tornadovm .layerplanner .QuantizationPlannerFactory ;
7+ import org .beehive .gpullama3 .model .llama .LlamaConfiguration ;
8+ import org .beehive .gpullama3 .tornadovm .kernels .TransformerComputeKernels ;
9+ import org .beehive .gpullama3 .tornadovm .layerplanner .WorkerGridFactory ;
10+ import org .beehive .gpullama3 .tornadovm .layerplanner .strategy .SchedulerDetectionService ;
11+ import org .beehive .gpullama3 .tornadovm .layerplanner .strategy .SchedulerType ;
12+ import org .beehive .gpullama3 .tornadovm .layers .type .fp16 .decode .LlamaFP16FFNLayersPrefillDecode ;
13+ import org .beehive .gpullama3 .tornadovm .layers .type .fp16 .decode .LogitsFP16LayerDecode ;
14+ import uk .ac .manchester .tornado .api .GridScheduler ;
915import uk .ac .manchester .tornado .api .ImmutableTaskGraph ;
16+ import uk .ac .manchester .tornado .api .KernelContext ;
17+ import uk .ac .manchester .tornado .api .TaskGraph ;
1018import uk .ac .manchester .tornado .api .TornadoExecutionPlan ;
19+ import uk .ac .manchester .tornado .api .enums .DataTransferMode ;
1120import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
21+ import uk .ac .manchester .tornado .api .types .arrays .HalfFloatArray ;
22+
23+ import java .util .ArrayList ;
24+ import java .util .List ;
1225
1326/**
1427 * GPU execution plan for single-token prefill/decode separation.
1528 *
16- * <p>Uses the same single-token execution plan as {@link TornadoVMMasterPlanStandard}
17- * but exposes two distinct forward passes:</p>
18- * <ul>
19- * <li>{@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips the logits graph.
20- * Called for each prompt token; KV cache is populated but logits are discarded.</li>
21- * <li>{@link #tornadoVMForwardDecode} — full execution including logits.
22- * Called for each generated token.</li>
23- * </ul>
29+ * <p>Uses dedicated layer classes that carry correct cross-graph
30+ * {@code consumeFromDevice} source names for both CUDA-graph and interpreter
31+ * (no-CUDA-graph) mode. All graphs are owned by this plan and built from scratch —
32+ * no reuse of the standard execution path.</p>
2433 *
25- * <p>Graph layout (same as {@link TornadoVMMasterPlanStandard} ):</p>
34+ * <p>Graph layout (N+2 graphs total ):</p>
2635 * <pre>
27- * graph 0 : preprocessing (embedding setup)
28- * graphs 1..N : transformer layers
29- * graph N+1 : logits projection ( final RMSNorm + wcls matmul)
36+ * [0] decodeActivation single-token FP16 → FP32; KV-cache allocated on first execution
37+ * [ 1..N] layer_0..layer_N-1 transformer layers (attention + FFN)
38+ * [ N+1] logits final RMSNorm + wcls matmul
3039 * </pre>
40+ *
41+ * <p>Two distinct forward passes:</p>
42+ * <ul>
43+ * <li>{@link #tornadoVMForwardPrefill} — runs graphs 0..N, skips logits.
44+ * KV cache is populated for each prompt token; logits are discarded.</li>
45+ * <li>{@link #tornadoVMForwardDecode} — full pass including logits.
46+ * Called for each generated token.</li>
47+ * </ul>
3148 */
3249public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan {
3350
34- private final State state ;
35- private final Model model ;
36- private final Configuration config ;
51+ private final LlamaState state ;
52+ private final Model model ;
53+ private final LlamaConfiguration config ;
54+ private final int N ; // numberOfLayers
55+ private final TornadoExecutionPlan executionPlan ;
56+ private final GridScheduler gridScheduler ;
3757
38- GenericLayerPlanner tornadoVMLayerPlanner ;
39- public TornadoExecutionPlan executionPlan ;
58+ // ── Graph-index helpers ───────────────────────────────────────────────────
59+ private int activationIdx () { return 0 ; }
60+ private int layerIdx (int i ) { return 1 + i ; }
61+ private int logitsIdx () { return N + 1 ; }
4062
41- public TornadoVMMasterPlanWithPrefillDecode (State state , Model model ) {
63+ // ── Construction ─────────────────────────────────────────────────────────
64+ TornadoVMMasterPlanWithPrefillDecode (State initialState , Model model ) {
4265 long startTime = System .nanoTime ();
4366 long planCreationTime = 0 ;
4467 long warmupTime = 0 ;
@@ -47,81 +70,127 @@ public TornadoVMMasterPlanWithPrefillDecode(State state, Model model) {
4770 System .err .println ("\n Starting TornadoVM initialization..." );
4871 }
4972
50- this .state = state ;
51- this .model = model ;
52- this .config = model .configuration ();
73+ this .state = (LlamaState ) initialState ;
74+ this .model = model ;
75+ this .config = (LlamaConfiguration ) model .configuration ();
76+ this .N = config .numberOfLayers ();
77+ this .gridScheduler = new GridScheduler ();
5378 this .executionPlan = createExecutionPlan ();
5479
5580 if (ENABLE_TORNADOVM_INIT_TIME ) {
5681 planCreationTime = System .nanoTime ();
57- System .err .printf ("TornadoVM GPU single-token prefill/decode execution plan creation: %.2f ms\n " , (planCreationTime - startTime ) / 1_000_000.0 );
82+ System .err .printf ("TornadoVM GPU single-token prefill/decode execution plan creation: %.2f ms\n " ,
83+ (planCreationTime - startTime ) / 1_000_000.0 );
5884 }
5985
6086 if (CUDA_GRAPHS ) executionPlan .withAllGraphs ().withCUDAGraph ();
6187 executionPlan .withPreCompilation ();
6288
6389 if (ENABLE_TORNADOVM_INIT_TIME ) {
6490 warmupTime = System .nanoTime ();
65- System .err .printf ("Java to GPU JIT compiler warmup: %.2f ms\n " , (warmupTime - planCreationTime ) / 1_000_000.0 );
91+ System .err .printf ("Java to GPU JIT compiler warmup: %.2f ms\n " ,
92+ (warmupTime - planCreationTime ) / 1_000_000.0 );
6693 }
6794
6895 forceCopyInReadOnlyData ();
6996
7097 if (ENABLE_TORNADOVM_INIT_TIME ) {
7198 long copyTime = System .nanoTime ();
72- System .err .printf ("Transfer read-only weights to GPU: %.2f ms\n " , (copyTime - warmupTime ) / 1_000_000.0 );
99+ System .err .printf ("Transfer read-only weights to GPU: %.2f ms\n " ,
100+ (copyTime - warmupTime ) / 1_000_000.0 );
73101 System .err .printf ("Finished TornadoVM initialization...\n \n " );
74102 }
75103 }
76104
105+ // ── Activation graph ─────────────────────────────────────────────────────
106+
77107 /**
78- * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*.
79- * Prefill is token-by-token but does not compute logits.
108+ * Graph 0: single-token FP16 → FP32.
109+ *
110+ * <p>Outputs {@code wrapX} (FP32 hidden state) and persists it on device so that
111+ * decode layer 0 can pick it up via {@code consumeFromDevice("decodeActivation", wrapX)}.
112+ * The KV cache is <em>not</em> managed here — it is allocated on the first forward pass
113+ * by decode layer 0 via {@code FIRST_EXECUTION}.</p>
80114 */
115+ private TaskGraph buildActivationGraph (KernelContext ctx ) {
116+ return new TaskGraph ("decodeActivation" )
117+ .transferToDevice (DataTransferMode .EVERY_EXECUTION , state .embeddingX )
118+ .task ("updateX" , TransformerComputeKernels ::convertFP16toFP32 ,
119+ ctx , (HalfFloatArray ) state .embeddingX , state .wrapX )
120+ .persistOnDevice (state .wrapX );
121+ }
122+
123+ // ── Plan construction ─────────────────────────────────────────────────────
124+
81125 @ Override
82126 public TornadoExecutionPlan createExecutionPlan () {
83- GGMLType weightType = model .weights ().getWeightType ();
84- this .tornadoVMLayerPlanner = QuantizationPlannerFactory .create (weightType , state , model );
85- var taskGraphs = tornadoVMLayerPlanner .getImmutableTaskGraphs ();
86- var taskGraphArray = taskGraphs .toArray (new ImmutableTaskGraph [taskGraphs .size ()]);
87- return new TornadoExecutionPlan (taskGraphArray );
127+ LlamaTornadoWeights weights = (LlamaTornadoWeights ) model .weights ();
128+ SchedulerType schedulerType = SchedulerDetectionService .determineSchedulerType (model );
129+
130+ List <ImmutableTaskGraph > all = new ArrayList <>(N + 2 );
131+
132+ // [0] Activation ──────────────────────────────────────────────────────
133+ KernelContext actCtx = new KernelContext ();
134+ all .add (buildActivationGraph (actCtx ).snapshot ());
135+ gridScheduler .addWorkerGrid ("decodeActivation.updateX" ,
136+ WorkerGridFactory .genericWorker (config .dim (), 128 ));
137+
138+ // [1..N] Decode layer graphs ──────────────────────────────────────────
139+ // Layer 0: FIRST_EXECUTION for KV cache + consumeFromDevice("decodeActivation", wrapX).
140+ // Layers 1+: consumeFromDevice with explicit predecessor names for interpreter mode.
141+ LlamaFP16FFNLayersPrefillDecode decodeLayers =
142+ new LlamaFP16FFNLayersPrefillDecode ("decode" , state , weights , config , schedulerType );
143+ all .addAll (decodeLayers .getFFNLayerImmutableTaskGraphs ());
144+ decodeLayers .updateGridScheduler (gridScheduler );
145+
146+ // [N+1] Logits ────────────────────────────────────────────────────────
147+ // LogitsFP16LayerDecode re-persists the KV cache so the pointer survives
148+ // the logits → layer_0 KV-cache FIRST_EXECUTION boundary across decode tokens.
149+ LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode (
150+ "logits" , state , weights , config ,
151+ decodeLayers .getLastFFNLayerTaskGraphID (), schedulerType );
152+ all .add (logitsLayer .getImmutableTaskGraph ());
153+ logitsLayer .updateGridScheduler (gridScheduler );
154+
155+ return new TornadoExecutionPlan (all .toArray (new ImmutableTaskGraph [0 ]));
88156 }
89157
158+ // ── Initialisation ────────────────────────────────────────────────────────
159+
160+ /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */
90161 @ Override
91162 public void forceCopyInReadOnlyData () {
92163 state .wrapX .clear ();
93164 state .positionHolder .init (0 );
94165
95- executionPlan . withGraph ( 0 ). withGridScheduler ( tornadoVMLayerPlanner . getGridScheduler ()). execute ();
96-
97- for ( int layer = 0 ; layer < config . numberOfLayers (); layer ++) {
98- executionPlan . withGraph ( layer + 1 ). withGridScheduler ( tornadoVMLayerPlanner . getGridScheduler ()) .execute ();
166+ for ( int i = 0 ; i <= logitsIdx (); i ++) {
167+ var g = executionPlan . withGraph ( i ). withGridScheduler ( gridScheduler );
168+ if ( CUDA_GRAPHS ) g . withCUDAGraph ();
169+ g .execute ();
99170 }
100-
101- executionPlan .withGraph (config .numberOfLayers () + 1 ).withGridScheduler (tornadoVMLayerPlanner .getGridScheduler ()).execute ();
102171 }
103172
173+ // ── Forward passes ────────────────────────────────────────────────────────
174+
104175 /**
105- * GPU prefill forward: runs preprocessing + all transformer layers, skips logits.
106- * KV cache is populated; logits projection is intentionally omitted .
176+ * GPU prefill forward: activation + all transformer layers, logits skipped .
177+ * KV cache is populated for each prompt token .
107178 *
108179 * @param position sequence position being processed
109180 */
110181 public void tornadoVMForwardPrefill (int position ) {
111- // Graph 0: preprocessing
112- executionPlan .withGraph (0 )
113- .withGridScheduler (tornadoVMLayerPlanner .getGridScheduler ())
114- .execute ();
182+ var prefillActivation = executionPlan .withGraph (activationIdx ()).withGridScheduler (gridScheduler );
183+ if (CUDA_GRAPHS ) prefillActivation .withCUDAGraph ();
184+ prefillActivation .execute ();
115185
116186 state .positionHolder .set (0 , position );
117187 state .temp .clear ();
118188 state .tempFFN .clear ();
119189
120- // Graphs 1..N: transformer layers (logits graph N+1 intentionally skipped)
121- for (int layer = 1 ; layer <= config .numberOfLayers (); layer ++) {
122- executionPlan .withGraph (layer )
123- .withGridScheduler (tornadoVMLayerPlanner .getGridScheduler ())
124- .execute ();
190+ for (int layer = 0 ; layer < N ; layer ++) {
191+ var prefillLayer = executionPlan .withGraph (layerIdx (layer )).withGridScheduler (gridScheduler );
192+ if (CUDA_GRAPHS ) prefillLayer .withCUDAGraph ();
193+ prefillLayer .execute ();
125194 }
126195 }
127196
@@ -137,27 +206,25 @@ public FloatArray tornadoVMForwardDecode(int position) {
137206
138207 @ Override
139208 public FloatArray tornadoVMForwardExecuteLayered (int position ) {
140- var preGraph = executionPlan .withGraph (0 )
141- .withGridScheduler (tornadoVMLayerPlanner .getGridScheduler ());
142- if (CUDA_GRAPHS ) preGraph .withCUDAGraph ();
143- preGraph .execute ();
209+ var act = executionPlan .withGraph (activationIdx ()).withGridScheduler (gridScheduler );
210+ if (CUDA_GRAPHS ) act .withCUDAGraph ();
211+ act .execute ();
144212
145213 state .positionHolder .set (0 , position );
146214 state .temp .clear ();
147215 state .tempFFN .clear ();
148216
149- for (int layer = 0 ; layer < config . numberOfLayers () ; layer ++) {
150- executionPlan .withGraph (1 + layer )
151- . withGridScheduler ( tornadoVMLayerPlanner . getGridScheduler ())
152- .execute ();
217+ for (int layer = 0 ; layer < N ; layer ++) {
218+ var l = executionPlan .withGraph (layerIdx ( layer )). withGridScheduler ( gridScheduler );
219+ if ( CUDA_GRAPHS ) l . withCUDAGraph ();
220+ l .execute ();
153221 }
154222
155223 state .tempLogits .clear ();
156224 state .wrapLogits .clear ();
157- var logitsGraph = executionPlan .withGraph (config .numberOfLayers () + 1 )
158- .withGridScheduler (tornadoVMLayerPlanner .getGridScheduler ());
159- if (CUDA_GRAPHS ) logitsGraph .withCUDAGraph ();
160- logitsGraph .execute ();
225+ var logits = executionPlan .withGraph (logitsIdx ()).withGridScheduler (gridScheduler );
226+ if (CUDA_GRAPHS ) logits .withCUDAGraph ();
227+ logits .execute ();
161228
162229 return state .wrapLogits ;
163230 }
0 commit comments