Skip to content

BF16 dispatch chain (Phase 3/3): wire Bf16TensorData dispatch in DefaultCpuOpsJvm#614

Merged
michalharakal merged 1 commit into
developfrom
feature/bf16-dispatch
May 17, 2026
Merged

BF16 dispatch chain (Phase 3/3): wire Bf16TensorData dispatch in DefaultCpuOpsJvm#614
michalharakal merged 1 commit into
developfrom
feature/bf16-dispatch

Conversation

@michalharakal

Copy link
Copy Markdown
Contributor

Resolves #613. Final phase of the three-phase BF16 dispatch chain. Follows #610 (Bf16TensorData, merged) and #612 (loader KEEP_NATIVE policy, merged).

What

DefaultCpuOpsJvm.chooseQuantizedMatmul now recognises Bf16TensorData as a weight type and dispatches through the Bf16MatmulKernel SPI registered in #605. After this PR, a consumer's complete opt-in is one line:

val loader = SafeTensorsParametersLoader(
    sourceProvider = { ... },
    bf16Policy = Bf16LoadPolicy.KEEP_NATIVE,
)
// loader emits Bf16DenseTensorData for BF16 weights
// → ops.matmul automatically picks up native / Panama / scalar BF16 kernel

Native FFM wins (priority 100) when libskainet_kernels.so is bundled and loaded; Panama Vector (50) is the JVM-pure SIMD fallback; scalar (0) is the floor.

Implementation

change shape
private val bf16MatmulKernel: Bf16MatmulKernel by lazy { ... } non-null, ScalarBf16MatmulKernel floor. Mirrors fp32MatmulKernel, not the nullable q4kMatmulKernel / q8_0MatmulKernel — those carry a legacy JvmQuantizedVectorKernels fallback; BF16 has no such legacy because pre-this-chain BF16 weights got dequanted before reaching matmul.
is Bf16TensorData -> branch in chooseQuantizedMatmul One SGEMM call per ops.matmul invocation — no per-batch matvec loop, because the BF16 SPI is a full (m, n, k) dense SGEMM. The B byte-stride is outputDim * Bf16TensorData.BYTES_PER_ELEMENT.

Tests

3 integration tests (Bf16MatmulDispatchTest):

  • single_batch_matmul_against_bf16_weight_routes_correctly[1, 128] × [128, 64] matmul; output matches scalar reference within 1e-2 × k.
  • multi_batch_matmul_against_bf16_weight_routes_correctlym=3, exercises a 2D output (3 rows of activations against the same weight).
  • llm_typical_attention_proj_matmul_routes_correctly — 512² attention-projection shape.

Each compares ops.matmul(input, bf16Weight) against a ScalarBf16MatmulKernel reference. This is integration coverage — kernel correctness was locked down by the parity tests in #605.

Full :skainet-backends:skainet-backend-cpu:jvmTest and :skainet-backends:skainet-backend-native-cpu:jvmTest suites pass on linux-x86_64 / JDK 21 with --add-modules jdk.incubator.vector.

Chain complete

phase issue PR state
1: Bf16TensorData #609 #610 merged
2: Loader KEEP_NATIVE policy #611 #612 merged
3: Dispatch wiring (this PR) #613 this open

After this lands, the BF16 path is end-to-end reachable from a consumer model. The natural next step is a Gemma-3n smoke test in SKaiNET-transformers/llm-inference/gemma — out of scope here.

🤖 Generated with Claude Code

…613)

Final phase of the three-phase BF16 dispatch chain. Follow-ups to
#610 (Bf16TensorData) and #612 (loader KEEP_NATIVE policy) — both
merged.

After this PR, a consumer that flips `bf16Policy = KEEP_NATIVE` on
SafeTensorsParametersLoader (or constructs a `Bf16DenseTensorData`
directly) gets the SIMD-vectorised BF16 matmul path with zero other
code changes. Native FFM kernel (priority 100) wins when the bundled
libskainet_kernels.so is loaded; falls through to Panama Vector (50)
then to the scalar SPI reference (0).

Implementation:
  - New `bf16MatmulKernel: Bf16MatmulKernel` lazy in DefaultCpuOpsJvm.
    Non-null with `ScalarBf16MatmulKernel` floor — mirrors
    `fp32MatmulKernel`'s pattern rather than the nullable
    `q4kMatmulKernel` / `q8_0MatmulKernel` pattern (which exist because
    Q4_K / Q8_0 have legacy non-SPI fallbacks via
    `JvmQuantizedVectorKernels`; BF16 has no such legacy).
  - New `is Bf16TensorData ->` branch in `chooseQuantizedMatmul`'s
    `when (bData)` block. The BF16 SPI kernel is a full SGEMM
    `(m, n, k)` with byte-strides on the B operand — no per-batch
    matvec loop like Q4_K/Q8_0/Q6_K need.

3 integration tests in `Bf16MatmulDispatchTest`:
  - single-batch matmul (`[1, k] × [k, n]` BF16) matches scalar
    reference within `1e-2 * k`.
  - multi-batch matmul (`m=3, k=256, n=32`) — exercises a 2D output.
  - LLM-typical 512² attention projection.

Refs #613. Full `:skainet-backends:skainet-backend-cpu:jvmTest` and
`:skainet-backends:skainet-backend-native-cpu:jvmTest` suites pass on
linux-x86_64 / JDK 21 with `--add-modules jdk.incubator.vector`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit aadb8b2 into develop May 17, 2026
6 of 8 checks passed
@michalharakal michalharakal deleted the feature/bf16-dispatch branch May 17, 2026 10:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BF16 dispatch chain (Phase 3/3): wire Bf16TensorData dispatch in DefaultCpuOpsJvm.chooseQuantizedMatmul

1 participant