BF16 dispatch chain (Phase 2/3): opt-in SafeTensors loader policy to keep BF16 native#612

Merged
michalharakal merged 1 commit into develop from feature/bf16-loader-policy
May 17, 2026

@michalharakal

Resolves #611. Phase 2 of the BF16 dispatch chain. Follows #610 (Bf16TensorData, merged).

What

SafeTensorsParametersLoader gains a bf16Policy: Bf16LoadPolicy = DEQUANT_TO_FP32 constructor parameter. Default preserves today's behavior verbatim; opt-in KEEP_NATIVE skips the dequant pass and emits Bf16DenseTensorData (added in #610) instead.

| policy | what the loader produces for a BF16 tensor | memory cost | matmul path |
| --- | --- | --- | --- |
| DEQUANT_TO_FP32 (default) | FloatArrayTensorData<FP32> with dequanted values | 2× on-disk | regular FP32 matmul |
| KEEP_NATIVE (new) | Bf16DenseTensorData with packed bytes | 1× on-disk | (after Phase 3) routes via Bf16MatmulKernel SPI |
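
For readers who want the shape of the two rows above in code, here is a minimal sketch. Everything suffixed `Sketch` is a hypothetical stand-in and not the real SKaiNET type; only the enum name and its two cases come from this PR. It assumes the SafeTensors convention of little-endian BF16 payloads.

```kotlin
// Enum name and cases mirror the PR; all *Sketch types are illustrative stand-ins.
enum class Bf16LoadPolicy { DEQUANT_TO_FP32, KEEP_NATIVE }

sealed interface SketchTensorData
class SketchFloatArrayData(val values: FloatArray) : SketchTensorData
class SketchBf16DenseData(val packedData: ByteArray) : SketchTensorData

// Widen each little-endian BF16 pair to FP32 by placing its 16 bits in the high half.
fun dequantBf16Sketch(bytes: ByteArray): FloatArray =
    FloatArray(bytes.size / 2) { i ->
        val lo = bytes[2 * i].toInt() and 0xFF
        val hi = bytes[2 * i + 1].toInt() and 0xFF
        Float.fromBits(((hi shl 8) or lo) shl 16)
    }

// The default argument keeps existing call sites on the dequant path.
fun loadBf16Sketch(
    bytes: ByteArray,
    policy: Bf16LoadPolicy = Bf16LoadPolicy.DEQUANT_TO_FP32,
): SketchTensorData = when (policy) {
    // 4 bytes per element in memory: twice the on-disk size.
    Bf16LoadPolicy.DEQUANT_TO_FP32 -> SketchFloatArrayData(dequantBf16Sketch(bytes))
    // 2 bytes per element in memory: same as on disk, no conversion pass.
    Bf16LoadPolicy.KEEP_NATIVE -> SketchBf16DenseData(bytes)
}
```

Calling `loadBf16Sketch(bytes)` with no second argument takes the dequant branch, which is the source-compatibility property the default parameter value is meant to provide.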

Why this is safe

  • No existing consumer calls the loader with a non-default policy. The default parameter value preserves source + bytecode compat for every current Kotlin caller. No Java consumers of this class in-tree.
  • The consumer-visible dtype stays `FP32::class` for both policies, the same convention as Q4_K / Q8_0 tensors, which have FP32 dtype but quantised storage. Pattern-matching with `tensor.data is Bf16TensorData` is the recognition surface for the upcoming dispatch.
  • BF16 → FP32 conversion is bit-shift only (`float_bits = (bf16_bits & 0xFFFF) << 16`); the two policies apply that math at different points (load-time vs read-time), but the resulting FP32 values are bit-identical. Locked down by a dedicated test.
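
The bit-shift claim in the last bullet can be sketched as follows. The function names are hypothetical. The narrowing direction shown here truncates, which is only one possible choice (converters often round to nearest even instead), and this PR does not say which one produced the on-disk data.

```kotlin
// Widening BF16 -> FP32 is exact: BF16 is the top 16 bits of an IEEE-754 binary32.
fun bf16ToFloatSketch(bf16Bits: Int): Float =
    Float.fromBits((bf16Bits and 0xFFFF) shl 16)

// Narrowing FP32 -> BF16 by truncation (assumption: shown only to make the
// round trip visible; it drops the low 16 mantissa bits without rounding).
fun floatToBf16TruncatingSketch(value: Float): Int =
    (value.toRawBits() ushr 16) and 0xFFFF
```

Because widening only appends sixteen zero bits, it does not matter whether it runs once at load time or on every read: `bf16ToFloatSketch(0x3F80)` is `1.0f` either way.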

Tests

4 new tests in commonTest (SafeTensorsParametersLoaderBf16PolicyTest):

  • bf16_default_policy_dequants_to_fp32_floatArray — confirms the default path still produces FloatArrayTensorData with correct values (within BF16 precision, 1e-2 abs).
  • bf16_keep_native_policy_emits_bf16DenseTensorData — confirms KEEP_NATIVE produces Bf16DenseTensorData, packedData matches the on-disk bytes byte-for-byte, and get()-decoded values still match the FP32 source within BF16 precision.
  • bf16_keep_native_decoded_values_match_dequant_path_exactly — bit-identity check across the two policies. Same bf16_bits << 16 math applied at different times; outputs must agree to the raw bits.
  • mixed_bf16_fp32_file_keep_native_only_affects_bf16 — loads a SafeTensors file with one BF16 and one FP32 tensor under KEEP_NATIVE; confirms only the BF16 path is affected (FP32 tensor still arrives as FloatArrayTensorData with exact values).
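
The bit-identity check in the third test could look roughly like the sketch below. It is not the code from `SafeTensorsParametersLoaderBf16PolicyTest`; `PackedBf16Sketch` and both helper functions are hypothetical stand-ins, and little-endian byte order is assumed.

```kotlin
// Hypothetical packed buffer that decodes on read, standing in for the KEEP_NATIVE result.
class PackedBf16Sketch(val packedData: ByteArray) {
    val size: Int get() = packedData.size / 2

    operator fun get(i: Int): Float {
        val lo = packedData[2 * i].toInt() and 0xFF
        val hi = packedData[2 * i + 1].toInt() and 0xFF
        return Float.fromBits(((hi shl 8) or lo) shl 16)
    }
}

// Eager decode, standing in for the DEQUANT_TO_FP32 result.
fun dequantAtLoadSketch(bytes: ByteArray): FloatArray {
    val packed = PackedBf16Sketch(bytes)
    return FloatArray(packed.size) { packed[it] }
}

// Compare raw bit patterns, not float values, so that -0.0 vs 0.0 differences are caught.
fun policiesAgreeBitwiseSketch(bytes: ByteArray): Boolean {
    val atLoad = dequantAtLoadSketch(bytes)
    val onRead = PackedBf16Sketch(bytes)
    return (0 until onRead.size).all { atLoad[it].toRawBits() == onRead[it].toRawBits() }
}
```

Comparing via `toRawBits()` is the design choice that matters here: `==` on floats treats `0.0f` and `-0.0f` as equal, so a value comparison would be weaker than the raw-bit agreement the test name promises.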

Full :skainet-io:skainet-io-safetensors:jvmTest suite passes on linux-x86_64 / JDK 21.

Phase 3 (next follow-up)

DefaultCpuOpsJvm.chooseQuantizedMatmul will dispatch on `is Bf16TensorData ->` and route the call through the Bf16MatmulKernel SPI. Mirrors #608's Q8_0 wiring. After Phase 3 lands, flipping `bf16Policy = KEEP_NATIVE` is the only consumer-side change needed to opt into the SIMD-vectorised BF16 matmul.
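
As a rough illustration of that dispatch shape, the sketch below branches on the storage type and hands packed BF16 weights to a pluggable kernel. Every name in it is a hypothetical stand-in: it does not reproduce `DefaultCpuOpsJvm`, the real `Bf16MatmulKernel` SPI, or any SIMD code, and it uses a scalar dot product as the smallest unit of a matmul.

```kotlin
// Illustrative stand-ins only; the real types and SPI live in SKaiNET.
sealed interface WeightDataSketch
class Fp32WeightsSketch(val values: FloatArray) : WeightDataSketch
class Bf16WeightsSketch(val packedData: ByteArray) : WeightDataSketch

// Hypothetical kernel seam: an implementation could be swapped for a vectorised one.
fun interface Bf16DotKernelSketch {
    fun dot(packedWeights: ByteArray, input: FloatArray): Float
}

// Scalar reference kernel: decode each little-endian BF16 weight and accumulate in FP32.
val scalarBf16DotSketch = Bf16DotKernelSketch { w, x ->
    var acc = 0f
    for (i in x.indices) {
        val bits = ((w[2 * i + 1].toInt() and 0xFF) shl 8) or (w[2 * i].toInt() and 0xFF)
        acc += Float.fromBits(bits shl 16) * x[i]
    }
    acc
}

// Dispatch on storage type; the caller-facing signature is the same for both branches.
fun dotSketch(
    weights: WeightDataSketch,
    input: FloatArray,
    kernel: Bf16DotKernelSketch = scalarBf16DotSketch,
): Float = when (weights) {
    is Bf16WeightsSketch -> kernel.dot(weights.packedData, input)
    is Fp32WeightsSketch -> {
        var acc = 0f
        for (i in input.indices) acc += weights.values[i] * input[i]
        acc
    }
}
```

The caller passes the same arguments for either storage type, which is why switching the loader policy can be the only consumer-side change once the dispatch exists.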

🤖 Generated with Claude Code

Phase 2 of the three-phase BF16 dispatch chain. Follow-up to #610
(Bf16TensorData merged).

Today's loader unconditionally dequants BF16 → FP32 at load
(`SafeTensorsParametersLoader.kt` line 87, `dequantBF16(bytes)`),
which means even with the new Bf16TensorData type in place no
SafeTensors-loaded weight ever reaches it. This PR adds an opt-in
policy so consumers that want native BF16 (Gemma-3n is the obvious
first one) get a Bf16DenseTensorData-backed tensor; everyone else
stays on the dequant path with zero behavioural change.

Surface:
  - new `Bf16LoadPolicy` enum (commonMain) with `DEQUANT_TO_FP32`
    (default) and `KEEP_NATIVE` cases. Documents the trade-off:
    memory halved vs. per-element decode cost on non-matmul ops.
  - new constructor parameter `bf16Policy: Bf16LoadPolicy =
    Bf16LoadPolicy.DEQUANT_TO_FP32`. Default preserves source +
    bytecode compat for every existing Kotlin caller.
  - new branch in the `DataType.BFLOAT16` case: when policy is
    `KEEP_NATIVE`, wrap the on-disk bytes in `Bf16DenseTensorData`
    and emit via `ctx.fromData(...)`. The consumer-visible dtype
    stays `FP32::class` (same pattern as Q4_K / Q8_0 tensors —
    quantised storage, FP32 dtype tag); only `tensor.data` differs.

4 new tests in `commonTest`:
  - DEQUANT_TO_FP32 path produces FloatArrayTensorData with values
    within BF16 precision.
  - KEEP_NATIVE path produces Bf16DenseTensorData whose packedData
    byte array matches the on-disk bytes verbatim.
  - Decoded values from both paths are bit-identical (both apply
    the same `bf16_bits << 16` math; only WHEN differs).
  - Mixed BF16+FP32 file under KEEP_NATIVE — BF16 becomes
    Bf16DenseTensorData, FP32 stays FloatArrayTensorData.

Refs #611. Full `:skainet-io:skainet-io-safetensors:jvmTest` suite
passes on linux-x86_64 / JDK 21.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit 4e55ded into develop May 17, 2026
5 of 6 checks passed
@michalharakal michalharakal deleted the feature/bf16-loader-policy branch May 17, 2026 10:28