BF16 dispatch chain (Phase 3/3): wire Bf16TensorData dispatch in DefaultCpuOpsJvm#614
Merged
Merged
Conversation
…613) Final phase of the three-phase BF16 dispatch chain. Follow-ups to #610 (Bf16TensorData) and #612 (loader KEEP_NATIVE policy) — both merged. After this PR, a consumer that flips `bf16Policy = KEEP_NATIVE` on SafeTensorsParametersLoader (or constructs a `Bf16DenseTensorData` directly) gets the SIMD-vectorised BF16 matmul path with zero other code changes. Native FFM kernel (priority 100) wins when the bundled libskainet_kernels.so is loaded; falls through to Panama Vector (50) then to the scalar SPI reference (0). Implementation: - New `bf16MatmulKernel: Bf16MatmulKernel` lazy in DefaultCpuOpsJvm. Non-null with `ScalarBf16MatmulKernel` floor — mirrors `fp32MatmulKernel`'s pattern rather than the nullable `q4kMatmulKernel` / `q8_0MatmulKernel` pattern (which exist because Q4_K / Q8_0 have legacy non-SPI fallbacks via `JvmQuantizedVectorKernels`; BF16 has no such legacy). - New `is Bf16TensorData ->` branch in `chooseQuantizedMatmul`'s `when (bData)` block. The BF16 SPI kernel is a full SGEMM `(m, n, k)` with byte-strides on the B operand — no per-batch matvec loop like Q4_K/Q8_0/Q6_K need. 3 integration tests in `Bf16MatmulDispatchTest`: - single-batch matmul (`[1, k] × [k, n]` BF16) matches scalar reference within `1e-2 * k`. - multi-batch matmul (`m=3, k=256, n=32`) — exercises a 2D output. - LLM-typical 512² attention projection. Refs #613. Full `:skainet-backends:skainet-backend-cpu:jvmTest` and `:skainet-backends:skainet-backend-native-cpu:jvmTest` suites pass on linux-x86_64 / JDK 21 with `--add-modules jdk.incubator.vector`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Resolves #613. Final phase of the three-phase BF16 dispatch chain. Follows #610 (Bf16TensorData, merged) and #612 (loader KEEP_NATIVE policy, merged).
What
DefaultCpuOpsJvm.chooseQuantizedMatmulnow recognisesBf16TensorDataas a weight type and dispatches through theBf16MatmulKernelSPI registered in #605. After this PR, a consumer's complete opt-in is one line:Native FFM wins (priority 100) when
libskainet_kernels.sois bundled and loaded; Panama Vector (50) is the JVM-pure SIMD fallback; scalar (0) is the floor.Implementation
private val bf16MatmulKernel: Bf16MatmulKernel by lazy { ... }ScalarBf16MatmulKernelfloor. Mirrorsfp32MatmulKernel, not the nullableq4kMatmulKernel/q8_0MatmulKernel— those carry a legacyJvmQuantizedVectorKernelsfallback; BF16 has no such legacy because pre-this-chain BF16 weights got dequanted before reachingmatmul.is Bf16TensorData ->branch inchooseQuantizedMatmulops.matmulinvocation — no per-batch matvec loop, because the BF16 SPI is a full(m, n, k)dense SGEMM. The B byte-stride isoutputDim * Bf16TensorData.BYTES_PER_ELEMENT.Tests
3 integration tests (
Bf16MatmulDispatchTest):single_batch_matmul_against_bf16_weight_routes_correctly—[1, 128] × [128, 64]matmul; output matches scalar reference within1e-2 × k.multi_batch_matmul_against_bf16_weight_routes_correctly—m=3, exercises a 2D output (3 rows of activations against the same weight).llm_typical_attention_proj_matmul_routes_correctly— 512² attention-projection shape.Each compares
ops.matmul(input, bf16Weight)against aScalarBf16MatmulKernelreference. This is integration coverage — kernel correctness was locked down by the parity tests in #605.Full
:skainet-backends:skainet-backend-cpu:jvmTestand:skainet-backends:skainet-backend-native-cpu:jvmTestsuites pass on linux-x86_64 / JDK 21 with--add-modules jdk.incubator.vector.Chain complete
After this lands, the BF16 path is end-to-end reachable from a consumer model. The natural next step is a Gemma-3n smoke test in
SKaiNET-transformers/llm-inference/gemma— out of scope here.🤖 Generated with Claude Code