Commit aa53ebe

[prf/dec][doc] Update javadoc to reflect unified batched prefill-decode plan
1 parent 97f2d8b commit aa53ebe

3 files changed

Lines changed: 6 additions & 8 deletions


src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java

Lines changed: 2 additions & 6 deletions
@@ -9,7 +9,8 @@
 import uk.ac.manchester.tornado.api.enums.DataTransferMode;

 /**
- * Decode-path FFN layers for the Phase 4 unified plan.
+ * Decode FFN layers of the unified batched prefill-decode plan
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
  *
  * <p>Overrides data-transfer declarations so that all cross-graph boundaries use
  * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by
@@ -20,11 +21,6 @@
  * never propagated — causing either a null-pointer crash or a silent re-upload
  * from host (zeros), corrupting the hidden state and KV cache.</p>
  *
- * <p>Two boundaries are fixed here:</p>
- * <ul>
- * <li>{@code wrapX}: via {@link #predecessorGraphName} hook in the base class.</li>
- * <li>All other consumed objects: via the {@link #configureLayerDataTransfers} override.</li>
- * </ul>
  */
 public class LlamaFP16FFNLayersDecode extends LlamaFP16FFNLayers {
     public LlamaFP16FFNLayersDecode(String taskGraph, LlamaState state,
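The distinction the javadoc draws between the two forms of {@code consumeFromDevice} can be sketched roughly as follows. This is an illustrative, non-compilable fragment against TornadoVM's fluent TaskGraph API; the graph name "prevGraph" and the buffer variable are hypothetical, not taken from the commit:

    // Illustrative sketch only — names are hypothetical.
    TaskGraph ffnGraph = new TaskGraph("decodeFfn")
            // Fragile: the no-arg form resolves the producing graph
            // implicitly, which can fail across separately built
            // prefill/decode graphs (device pointer not propagated).
            // .consumeFromDevice(wrapX)
            // Robust: naming the source graph explicitly makes the
            // device pointer resolve against that graph's object state.
            .consumeFromDevice("prevGraph", wrapX);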

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@
 import uk.ac.manchester.tornado.api.TaskGraph;

 /**
- * Logits layer for the unified prefill-decode plan (Phase 4).
+ * Logits layer of the unified batched prefill-decode plan
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
  *
  * <p>Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device
  * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@
 import java.util.stream.IntStream;

 /**
- * Builds per-layer batch prefill TaskGraphs for Phase 4 GPU batched prefill.
+ * Prefill FFN layers with batching for the unified batched prefill-decode plan
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
  *
  * <p>One {@link ImmutableTaskGraph} per transformer layer, each processing
  * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.</p>
