Copy-in embeddings in reduced precision and handle precision conversion during inference #73
mikepapadim wants to merge 3 commits into main from
Conversation
…n in TornadoVM acceleration.
…el loaders for consistent tensor loading.
Pull request overview
This PR introduces precision-aware embedding handling by storing embeddings in FP16 format and performing on-the-fly conversion to FP32 during GPU inference, reducing memory footprint while maintaining computational accuracy.
- Embeddings are now stored in reduced precision (FP16) and converted to FP32 on GPU during inference
- Introduced new TornadoVM conversion kernels (`convertFP16toFP32` and `convertFP32toFP16`) for efficient precision conversion
- Refactored tensor loading to support both standard and TornadoVM-compatible memory layouts with proper header handling
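The per-element widening that the new kernels perform can be illustrated on the host side. The sketch below is purely illustrative: the class and method names are mine, and the PR's real conversion runs inside a TornadoVM GPU kernel over a `HalfFloatArray`, not as scalar Java.

```java
// Illustrative host-side sketch (NOT the PR's GPU kernel): the same binary16 -> binary32
// widening that convertFP16toFP32 performs per element on the GPU.
public class Fp16Demo {

    // Decode an IEEE 754 binary16 value stored in a short into a float.
    static float fp16ToFp32(short h) {
        int bits = h & 0xFFFF;
        int sign = (bits >>> 15) & 0x1;
        int exp  = (bits >>> 10) & 0x1F;
        int mant = bits & 0x3FF;
        float val;
        if (exp == 0x1F) {                       // all-ones exponent: Inf / NaN
            val = (mant == 0) ? Float.POSITIVE_INFINITY : Float.NaN;
        } else if (exp == 0) {                   // zero / subnormal
            val = (float) (mant * Math.pow(2, -24));
        } else {                                 // normal: (1 + mant/1024) * 2^(exp-15)
            val = (float) ((1.0 + mant / 1024.0) * Math.pow(2, exp - 15));
        }
        return sign == 1 ? -val : val;
    }

    public static void main(String[] args) {
        System.out.println(fp16ToFp32((short) 0x3C00)); // 1.0
        System.out.println(fp16ToFp32((short) 0xC000)); // -2.0
    }
}
```

Widening FP16 to FP32 is exact for every finite half-precision value, which is why storing embeddings in FP16 and widening on the GPU loses no accuracy relative to the values as stored.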
Reviewed changes
Copilot reviewed 23 out of 23 changed files in this pull request and generated 24 comments.
| File | Description |
|---|---|
| src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java | Updated activation layer to use FP16→FP32 conversion kernel instead of empty copy task |
| src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java | Added GPU kernels for FP16↔FP32 precision conversion |
| src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java | Removed size tracking from base class (moved to subclasses), added spacing to documentation |
| src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java | Moved size field from parent to this class, added TODO for Q8_0 loading fix |
| src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java | Renamed field to tornadoNativeArray, added factory method for memory segment loading |
| src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java | Renamed field to tornadoNativeArray, added factory method for memory segment loading |
| src/main/java/org/beehive/gpullama3/tensor/GGUF.java | Split tensor loading into loadTensorsStandard and loadTensorsTornado with proper memory layout handling |
| src/main/java/org/beehive/gpullama3/model/loader/*.java | Updated model loaders to remove loadWeights parameter and use loadTornadoTensor instead of loadTornadoTensorAsFP32 |
| src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java | Updated to always load weights, replaced manual FP32 conversion with GPU-based approach |
| src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java | Removed loadWeights parameter, added logic to select tensor loading method based on TornadoVM usage |
| src/main/java/org/beehive/gpullama3/model/ModelType.java | Updated all model type loaders to remove loadWeights parameter |
| src/main/java/org/beehive/gpullama3/inference/state/*.java | Added embeddingX field (HalfFloatArray) to all state classes for FP16 embedding storage |
| src/main/java/org/beehive/gpullama3/inference/InferenceCore.java | Modified embedding copy to use FP16 format (Short.BYTES) targeting state.embeddingX |
| set_paths | Commented out TornadoVM path configuration variables |
| .github/workflows/build-and-run.yml | Disabled Spotless formatting check, switched TornadoVM clone to develop branch |
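The memory saving from storing embeddings in FP16 is simple arithmetic: two bytes per element instead of four. The sizes below are illustrative only; the PR does not state concrete model dimensions.

```java
public class EmbeddingFootprint {
    // Bytes needed for a vocabSize x dim embedding table at the given element width.
    static long footprintBytes(long vocabSize, long dim, long bytesPerElement) {
        return vocabSize * dim * bytesPerElement;
    }

    public static void main(String[] args) {
        // Hypothetical sizes (typical of a Llama-class model), not taken from this PR.
        long vocab = 32_000, dim = 4_096;
        long fp32 = footprintBytes(vocab, dim, Float.BYTES); // 4 bytes/element
        long fp16 = footprintBytes(vocab, dim, Short.BYTES); // 2 bytes/element
        System.out.println(fp32 / (1 << 20) + " MiB -> " + fp16 / (1 << 20) + " MiB");
    }
}
```

For these assumed sizes the embedding table drops from 500 MiB to 250 MiB, at the cost of one widening kernel launch during inference.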
```java
System.out.println("[GGUF] fileChannel = FileChannel.open(modelPath, READ, WRITE);");
fileChannel = FileChannel.open(modelPath, READ, WRITE);
```
File opened with WRITE permission unnecessarily. The file is opened with READ and WRITE permissions, but based on the method name and usage, only read access should be required for loading metadata. Opening with write permissions creates an unnecessary security risk. Consider using only READ permission unless write access is actually needed.
```diff
- System.out.println("[GGUF] fileChannel = FileChannel.open(modelPath, READ, WRITE);");
- fileChannel = FileChannel.open(modelPath, READ, WRITE);
+ System.out.println("[GGUF] fileChannel = FileChannel.open(modelPath, READ);");
+ fileChannel = FileChannel.open(modelPath, READ);
```
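The least-privilege point can be demonstrated directly: a `FileChannel` opened with only `READ` still reads normally but rejects any write attempt with `NonWritableChannelException`. The sketch below is self-contained and uses a hypothetical temp file rather than a real GGUF model.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.NonWritableChannelException;
import java.nio.file.Files;
import java.nio.file.Path;
import static java.nio.file.StandardOpenOption.READ;

public class ReadOnlyDemo {
    // Opens the file read-only, reads from it, then shows writes are rejected.
    static boolean writeRejected(Path p) throws IOException {
        try (FileChannel ch = FileChannel.open(p, READ)) {
            ch.read(ByteBuffer.allocate(4));              // reading still works
            try {
                ch.write(ByteBuffer.wrap(new byte[]{9})); // must fail on a read-only channel
                return false;
            } catch (NonWritableChannelException e) {
                return true;
            }
        }
    }

    public static void main(String[] args) throws IOException {
        Path p = Files.createTempFile("gguf-demo", ".bin"); // stand-in for a model file
        Files.write(p, new byte[]{1, 2, 3, 4});
        System.out.println(writeRejected(p)); // true
        Files.deleteIfExists(p);
    }
}
```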
```java
 * Loads GGUF tensor data using a TornadoVM-compatible memory layout.
 *
 * <p>This method parses the GGUF tensor list and memory-maps each tensor
 * in {@link TornadoNativeArray} layout directly from the underlying{@link FileChannel}.
```
Missing space in JavaDoc. There should be a space between 'underlying' and the opening brace.
```diff
- * in {@link TornadoNativeArray} layout directly from the underlying{@link FileChannel}.
+ * in {@link TornadoNativeArray} layout directly from the underlying {@link FileChannel}.
```
```java
);
}
// @formatter:off
```
Incorrect formatter comment. This should be @formatter:on to match the opening @formatter:off comment, not @formatter:off.
```diff
- // @formatter:off
+ // @formatter:on
```
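For context on why the pairing matters: an unmatched `@formatter:off` disables formatting for the rest of the file, not just the guarded block. A minimal correctly paired example (illustrative class and field names, not from the PR):

```java
public class FormatterDemo {
    // @formatter:off  -- preserve the hand-aligned matrix layout below
    static final int[][] TABLE = {
        { 1, 0 },
        { 0, 1 },
    };
    // @formatter:on   -- formatting resumes for everything after this point

    public static void main(String[] args) {
        System.out.println(TABLE[0][0]); // 1
    }
}
```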
```java
: new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
    return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
}
// @formatter:off
```
Incorrect formatter comment. This should be @formatter:on to match the opening @formatter:off comment, not @formatter:off.
```diff
- // @formatter:off
+ // @formatter:on
```
```java
    );
    return config;
}
// @formatter:off
```
Incorrect formatter comment. This should be @formatter:on to match the opening @formatter:off comment, not @formatter:off.
```diff
- // @formatter:off
+ // @formatter:on
```
```diff
 run: |
   cd ${{ github.workspace }}
-  ./mvnw -T12C -Pspotless spotless:check
+  #./mvnw -T12C -Pspotless spotless:check
```
Code formatting check is commented out. The Spotless code formatting check has been commented out, which means code style violations will not be caught in CI. This should be re-enabled to maintain code quality standards, or removed entirely if no longer needed.
```diff
- #./mvnw -T12C -Pspotless spotless:check
+ ./mvnw -T12C -Pspotless spotless:check
```
```java
 * Load a tensor and manually convert to FP32 (FloatArray).
 * Used for embeddings that currently are treated as FP32.
 * TODO: it is ultra-slow and should be removed
```
Outdated documentation. The JavaDoc states "Used for embeddings that currently are treated as FP32" and includes a TODO about removal, but based on the PR changes, embeddings are now handled as FP16 with GPU-based conversion. The documentation should be updated to clarify the current usage or deprecation status of this method.
```diff
- * Load a tensor and manually convert to FP32 (FloatArray).
- * Used for embeddings that currently are treated as FP32.
- * TODO: it is ultra-slow and should be removed
+ * [DEPRECATED] Loads a tensor and manually converts it to FP32 (FloatArray).
+ * Embeddings are now handled as FP16 with GPU-based conversion; this method remains only for legacy compatibility.
+ * This method is ultra-slow and should not be used in new code. Scheduled for removal.
```
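Beyond updating the JavaDoc, the deprecation can also be made machine-checkable with `@Deprecated(forRemoval = true)`, so callers get a compiler warning. A minimal sketch (the method name and body here are hypothetical stand-ins, not the repository's actual `loadTornadoTensorAsFP32`):

```java
import java.lang.reflect.Method;

public class DeprecationDemo {
    // Hypothetical legacy method standing in for the slow FP32 loading path.
    @Deprecated(since = "0.1", forRemoval = true)
    static float legacyConvert(float v) {
        return v; // placeholder body; the real method did a CPU-side FP32 conversion
    }

    public static void main(String[] args) throws Exception {
        // Show that the removal intent is visible at runtime via reflection.
        Method m = DeprecationDemo.class.getDeclaredMethod("legacyConvert", float.class);
        Deprecated d = m.getAnnotation(Deprecated.class);
        System.out.println(d != null && d.forRemoval()); // true
    }
}
```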
```diff
 public class FP32TornadoTensor extends TornadoTensor {
-    private final FloatArray values;
+    private final FloatArray tornadoNativeArray;
```
Renamed field could be more descriptive. The field was renamed from values to tornadoNativeArray, which does make it consistent with the sibling class FP16TornadoTensor. However, a more descriptive name such as data or array would be clearer, since the "Tornado" context is already evident from the class name.
```java
fields.wrapK = new FloatArray(nEmbdKGqa);
fields.wrapV = new FloatArray(nEmbdKGqa);

fields.embeddingX = new HalfFloatArray(config.dim());
```
Duplicate field initialization. The field fields.embeddingX is initialized twice - once at line 70 and again at line 80 with the same value. The second initialization should be removed.
```diff
- fields.embeddingX = new HalfFloatArray(config.dim());
```
```java
    wrapX.set(i, x.get(i).getFloat32());
}

public static void convertFP32toFP16(KernelContext context,  FloatArray wrapX, HalfFloatArray x) {
```
Extra whitespace after parameter. There are two spaces between the comma and FloatArray - should be one space.
```diff
- public static void convertFP32toFP16(KernelContext context,  FloatArray wrapX, HalfFloatArray x) {
+ public static void convertFP32toFP16(KernelContext context, FloatArray wrapX, HalfFloatArray x) {
```
Force-pushed from 20b0d4c to 937408f.
Closed in favour of #78.
No description provided.