11package org .beehive .gpullama3 .tornadovm ;
22
33import org .beehive .gpullama3 .inference .state .LlamaState ;
4+ import org .beehive .gpullama3 .inference .state .State ;
45import org .beehive .gpullama3 .inference .weights .tornado .LlamaTornadoWeights ;
56import org .beehive .gpullama3 .model .Model ;
67import org .beehive .gpullama3 .model .llama .LlamaConfiguration ;
2526import java .util .List ;
2627
2728/**
28- * Unified GPU execution plan for Phase 4: batched prefill + single-token decode.
29+ * GPU execution plan for batched prefill + single-token decode.
2930 *
3031 * <p>A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache
3132 * ({@code wrapKeyCache}, {@code wrapValueCache}) is shared on device via
5051 */
5152public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan {
5253
53- private static final boolean ENABLE_TIMING =
54- Boolean .parseBoolean (System .getProperty ("llama.EnableTimingForTornadoVMInit" , "False" ));
55-
5654 private final LlamaState state ;
55+ private final Model model ;
5756 private final LlamaConfiguration config ;
5857 private final int batchSize ;
5958 private final int N ; // numberOfLayers
@@ -68,55 +67,44 @@ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMaste
6867 private int logitsIdx () { return 2 * N + 2 ; }
6968
7069 // ── Construction ─────────────────────────────────────────────────────────
71- private TornadoVMMasterPlanWithBatchPrefillDecode (LlamaState state , Model model , int batchSize ) {
72- this .state = state ;
70+ TornadoVMMasterPlanWithBatchPrefillDecode (State initialState , Model model ) {
71+ long startTime = System .nanoTime ();
72+ long planCreationTime = 0 ;
73+ long warmupTime = 0 ;
74+
75+ if (ENABLE_TORNADOVM_INIT_TIME ) {
76+ System .err .println ("\n Starting TornadoVM initialization..." );
77+ }
78+
79+ this .state = (LlamaState ) initialState ; // only LlamaFP16 supports batched prefill for now
80+ this .model = model ;
7381 this .config = (LlamaConfiguration ) model .configuration ();
74- this .batchSize = batchSize ;
82+ this .batchSize = PREFILL_BATCH_SIZE ;
7583 this .N = config .numberOfLayers ();
7684
77- LlamaTornadoWeights weights = (LlamaTornadoWeights ) model .weights ();
78- SchedulerType schedulerType = SchedulerDetectionService .determineSchedulerType (model );
79-
80- List <ImmutableTaskGraph > all = new ArrayList <>(2 * N + 3 );
81- GridScheduler scheduler = new GridScheduler ();
85+ this .gridScheduler = new GridScheduler ();
86+ this .executionPlan = createExecutionPlan ();
8287
83- // [0] Batch prefill activation ────────────────────────────────────────────────
84- KernelContext batchActCtx = new KernelContext ();
85- all .add (buildBatchPrefillActivationGraph (batchActCtx ).snapshot ());
86- scheduler .addWorkerGrid ("batchActivation.batchUpdateX" ,
87- WorkerGridFactory .genericWorker (batchSize * config .dim (), 128 ));
88+ if (ENABLE_TORNADOVM_INIT_TIME ) {
89+ planCreationTime = System .nanoTime ();
90+ System .err .printf ("TornadoVM GPU batched prefill/decode execution plan creation: %.2f ms\n " , (planCreationTime - startTime ) / 1_000_000.0 );
91+ }
8892
89- // [1..N] Batch prefill layer graphs ───────────────────────────────────────────
90- LlamaFP16LayersBatchPrefill batchLayers =
91- new LlamaFP16LayersBatchPrefill (state , weights , config , batchSize );
92- all .addAll (batchLayers .getLayerImmutableTaskGraphs ());
93- batchLayers .updateGridScheduler (scheduler );
93+ if (CUDA_GRAPHS ) executionPlan .withAllGraphs ().withCUDAGraph ();
94+ executionPlan .withPreCompilation ();
9495
95- // [N+1] Decode activation (with KV-cache pass-through) ────────────────
96- KernelContext decodeActCtx = new KernelContext ();
97- all .add (buildDecodeActivationGraph (decodeActCtx , batchLayers .getLastLayerTaskGraphID ()).snapshot ());
98- scheduler .addWorkerGrid ("decodeActivationUpdate.updateX" ,
99- WorkerGridFactory .genericWorker (config .dim (), 128 ));
96+ if (ENABLE_TORNADOVM_INIT_TIME ) {
97+ warmupTime = System .nanoTime ();
98+ System .err .printf ("Java to GPU JIT compiler warmup: %.2f ms\n " , (warmupTime - planCreationTime ) / 1_000_000.0 );
99+ }
100100
101- // [N+2..2N+1] Decode layer graphs ────────────────────────────────────
102- // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload).
103- LlamaFP16FFNLayersDecode decodeLayers =
104- new LlamaFP16FFNLayersDecode (
105- "llamaFFNDecode" , state , weights , config , schedulerType );
106- all .addAll (decodeLayers .getFFNLayerImmutableTaskGraphs ());
107- decodeLayers .updateGridScheduler (scheduler );
101+ forceCopyInReadOnlyData ();
108102
109- // [2N+2] Logits ───────────────────────────────────────────────────────
110- // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache)
111- // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the
112- // KV-cache pointer survives the logits → decode-activation boundary across tokens.
113- LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode ("logits" , state , weights , config ,
114- decodeLayers .getLastFFNLayerTaskGraphID (), schedulerType );
115- all .add (logitsLayer .getImmutableTaskGraph ());
116- logitsLayer .updateGridScheduler (scheduler );
117-
118- this .gridScheduler = scheduler ;
119- this .executionPlan = new TornadoExecutionPlan (all .toArray (new ImmutableTaskGraph [0 ]));
103+ if (ENABLE_TORNADOVM_INIT_TIME ) {
104+ long copyTime = System .nanoTime ();
105+ System .err .printf ("Transfer read-only weights to GPU: %.2f ms\n " , (copyTime - warmupTime ) / 1_000_000.0 );
106+ System .err .printf ("Finished TornadoVM initialization...\n \n " );
107+ }
120108 }
121109
122110 // ── Batch Prefill Activation graphs ─────────────────────────────────────────────────────
@@ -164,41 +152,58 @@ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatch
164152 .persistOnDevice (state .wrapX , state .wrapKeyCache , state .wrapValueCache );
165153 }
166154
167- // ── Static factory ────────────────────────────────────────────────────────
168-
169155 /**
170- * Creates, JIT-compiles, and warms up the unified plan.
171- * Mirrors {@link TornadoVMMasterPlan#initializeTornadoVMPlan}.
156+ * Creates the {@link TornadoExecutionPlan} for the forward pass, with prefill executed in batches and decode executed as a separate single-token phase.
172157 */
173- public static TornadoVMMasterPlanWithBatchPrefillDecode initializeUnifiedPlan (
174- LlamaState state , Model model , int batchSize ) {
158+ @ Override
159+ public TornadoExecutionPlan createExecutionPlan () {
160+ LlamaTornadoWeights weights = (LlamaTornadoWeights ) model .weights ();
161+ SchedulerType schedulerType = SchedulerDetectionService .determineSchedulerType (model );
175162
176- long t0 = System .nanoTime ();
177- TornadoVMMasterPlanWithBatchPrefillDecode plan =
178- new TornadoVMMasterPlanWithBatchPrefillDecode (state , model , batchSize );
163+ List <ImmutableTaskGraph > all = new ArrayList <>(2 * N + 3 );
179164
180- if (ENABLE_TIMING )
181- System .err .printf ("[BatchPlan] Graph construction: %.2f ms%n" ,
182- (System .nanoTime () - t0 ) / 1e6 );
165+ // [0] Batch prefill activation ────────────────────────────────────────────────
166+ KernelContext batchActCtx = new KernelContext ();
167+ all .add (buildBatchPrefillActivationGraph (batchActCtx ).snapshot ());
168+ gridScheduler .addWorkerGrid ("batchActivation.batchUpdateX" ,
169+ WorkerGridFactory .genericWorker (batchSize * config .dim (), 128 ));
183170
184- if (CUDA_GRAPHS ) plan .executionPlan .withAllGraphs ().withCUDAGraph ();
185- plan .executionPlan .withPreCompilation ();
171+ // [1..N] Batch prefill layer graphs ───────────────────────────────────────────
172+ LlamaFP16LayersBatchPrefill batchLayers =
173+ new LlamaFP16LayersBatchPrefill (state , weights , config , batchSize );
174+ all .addAll (batchLayers .getLayerImmutableTaskGraphs ());
175+ batchLayers .updateGridScheduler (gridScheduler );
186176
187- if (ENABLE_TIMING )
188- System .err .printf ("[BatchPlan] JIT compilation: %.2f ms%n" ,
189- (System .nanoTime () - t0 ) / 1e6 );
177+ // [N+1] Decode activation (with KV-cache pass-through) ────────────────
178+ KernelContext decodeActCtx = new KernelContext ();
179+ all .add (buildDecodeActivationGraph (decodeActCtx , batchLayers .getLastLayerTaskGraphID ()).snapshot ());
180+ gridScheduler .addWorkerGrid ("decodeActivationUpdate.updateX" ,
181+ WorkerGridFactory .genericWorker (config .dim (), 128 ));
190182
191- plan .forceCopyInReadOnlyData ();
183+ // [N+2..2N+1] Decode layer graphs ────────────────────────────────────
184+ // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload).
185+ LlamaFP16FFNLayersDecode decodeLayers =
186+ new LlamaFP16FFNLayersDecode (
187+ "llamaFFNDecode" , state , weights , config , schedulerType );
188+ all .addAll (decodeLayers .getFFNLayerImmutableTaskGraphs ());
189+ decodeLayers .updateGridScheduler (gridScheduler );
192190
193- if (ENABLE_TIMING )
194- System .err .printf ("[BatchPlan] Init complete: %.2f ms%n" ,
195- (System .nanoTime () - t0 ) / 1e6 );
191+ // [2N+2] Logits ───────────────────────────────────────────────────────
192+ // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache)
193+ // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the
194+ // KV-cache pointer survives the logits → decode-activation boundary across tokens.
195+ LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode ("logits" , state , weights , config ,
196+ decodeLayers .getLastFFNLayerTaskGraphID (), schedulerType );
197+ all .add (logitsLayer .getImmutableTaskGraph ());
198+ logitsLayer .updateGridScheduler (gridScheduler );
196199
197- return plan ;
200+ return new TornadoExecutionPlan ( all . toArray ( new ImmutableTaskGraph [ 0 ])) ;
198201 }
199202
203+
200204 /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */
201- private void forceCopyInReadOnlyData () {
205+ @ Override
206+ public void forceCopyInReadOnlyData () {
202207 state .wrapXBatch .clear ();
203208 state .wrapX .clear ();
204209 state .positionHolder .init (0 );
0 commit comments