
BF16 dispatch chain (Phase 2/3): opt-in SafeTensors loader policy to keep BF16 native #611

@michalharakal

Description


Phase 2 of the BF16 dispatch chain. Follow-up to #610 (Bf16TensorData merged).

Why

Bf16TensorData / Bf16DenseTensorData exist now, but no SafeTensors-loaded tensor uses them. Today SafeTensorsParametersLoader (lines 83-89 in skainet-io-safetensors/.../SafeTensorsParametersLoader.kt) unconditionally dequantizes BF16 bytes to FP32 via dequantBF16(bytes) and wraps the result in a FloatArrayTensorData. So even with #610 in develop, BF16 weights still arrive as FP32, and the new SIMD Bf16MatmulKernel (#605) is unreachable for SafeTensors-loaded models.
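For reference, the FP32 dequant path boils down to widening each BF16 value. BF16 is the upper 16 bits of an IEEE-754 float32, so the conversion is an exact 16-bit left shift. The sketch below assumes little-endian 2-byte BF16 elements as stored in SafeTensors; dequantBf16Sketch is a hypothetical name and is not the loader's actual dequantBF16(bytes).

```kotlin
// Hypothetical sketch of the existing BF16 -> FP32 dequant step.
// Assumes little-endian 2-byte BF16 elements, as stored in SafeTensors.

/** Widens every BF16 element in [bytes] to an FP32 value (exact, no rounding). */
fun dequantBf16Sketch(bytes: ByteArray): FloatArray {
    require(bytes.size % 2 == 0) { "BF16 payload must have an even byte count" }
    return FloatArray(bytes.size / 2) { i ->
        val lo = bytes[2 * i].toInt() and 0xFF
        val hi = bytes[2 * i + 1].toInt() and 0xFF
        val bf16Bits = (hi shl 8) or lo
        // BF16 is the top half of an IEEE-754 float32, so shifting it back is the whole conversion.
        Float.fromBits(bf16Bits shl 16)
    }
}

// Example: 1.0f is 0x3F800000 -> BF16 0x3F80 -> bytes [0x80, 0x3F];
//         -2.0f is 0xC0000000 -> BF16 0xC000 -> bytes [0x00, 0xC0].
val exampleFloats = dequantBf16Sketch(byteArrayOf(0x80.toByte(), 0x3F, 0x00, 0xC0.toByte())) // [1.0, -2.0]
```

Because the widening is exact, the FP32 copy carries no extra information; it only doubles the memory footprint (4 bytes per weight instead of 2) and routes matmuls away from any BF16-specific kernel.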

This PR adds an opt-in knob so consumers that want native BF16 (Gemma-3n is the obvious first one) get a Bf16DenseTensorData-backed tensor; everyone else stays on the existing FP32 dequant path with zero behavioral change.

Scope

  1. New Bf16LoadPolicy enum (commonMain in skainet-io-safetensors):
    • DEQUANT_TO_FP32 (default) — current behavior.
    • KEEP_NATIVE — wrap the raw bytes in Bf16DenseTensorData.
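  2. New constructor parameter bf16Policy: Bf16LoadPolicy = DEQUANT_TO_FP32 on SafeTensorsParametersLoader. The default keeps existing Kotlin call sites source-compatible. On its own it does not preserve binary compatibility or give Java callers a 2-arg constructor: Kotlin compiles a defaulted parameter into one full-arity constructor (plus a synthetic one), so keeping the old 2-arg signature in bytecode requires @JvmOverloads on the constructor. There are no Java consumers of this class in-tree.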
  3. Branch in the DataType.BFLOAT16 case: when policy is KEEP_NATIVE, skip dequantBF16(bytes) and emit Bf16DenseTensorData(shape, bytes) wrapped via ctx.fromData(...).
  4. The require(dtype == FP32::class) check stays: the consumer-visible dtype is still FP32 (Bf16TensorData : TensorData<DType, Float>), and only the underlying storage is BF16. This is the same pattern as the Q4_K / Q8_0 tensors, which also expose an FP32 dtype over quantized storage.
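A minimal sketch of items 1 to 3, using simplified stand-in types. Fp32StorageSketch, Bf16StorageSketch and LoaderSketch are hypothetical and only stand in for FloatArrayTensorData, Bf16DenseTensorData and SafeTensorsParametersLoader; the real types, ctx.fromData(...) and SafeTensors header parsing are not modeled. The point is the shape of the branch: the default policy dequantizes eagerly, while KEEP_NATIVE stores the bytes untouched and decodes per element on demand.

```kotlin
// Hypothetical, simplified stand-ins: the real SKaiNET TensorData types,
// ctx.fromData(...) and SafeTensors header parsing are not modeled here.

enum class Bf16LoadPolicy { DEQUANT_TO_FP32, KEEP_NATIVE }

sealed interface StorageSketch { val shape: IntArray }

/** FP32 storage, standing in for FloatArrayTensorData (the current dequant path). */
class Fp32StorageSketch(override val shape: IntArray, val values: FloatArray) : StorageSketch

/** BF16 storage, standing in for Bf16DenseTensorData: raw bytes kept untouched. */
class Bf16StorageSketch(override val shape: IntArray, val packedData: ByteArray) : StorageSketch {
    init {
        val elements = shape.fold(1) { acc, dim -> acc * dim }
        require(packedData.size == 2 * elements) { "expected ${2 * elements} bytes, got ${packedData.size}" }
    }

    /** Decodes one element by flat index; callers still see a Float. */
    fun getFlat(index: Int): Float = bf16ToFloat(packedData, index)
}

/** Widens the index-th little-endian BF16 value in [bytes] to FP32 (exact). */
fun bf16ToFloat(bytes: ByteArray, index: Int): Float {
    val bits = ((bytes[2 * index + 1].toInt() and 0xFF) shl 8) or (bytes[2 * index].toInt() and 0xFF)
    return Float.fromBits(bits shl 16)
}

/** Hypothetical loader; [bf16Tensors] maps a tensor name to (shape, raw BF16 bytes). */
class LoaderSketch @JvmOverloads constructor(
    private val bf16Tensors: Map<String, Pair<IntArray, ByteArray>>,
    private val bf16Policy: Bf16LoadPolicy = Bf16LoadPolicy.DEQUANT_TO_FP32,
) {
    fun load(name: String): StorageSketch {
        val (shape, bytes) = bf16Tensors.getValue(name)
        return when (bf16Policy) {
            // Default: eager dequant to FP32, same as today.
            Bf16LoadPolicy.DEQUANT_TO_FP32 ->
                Fp32StorageSketch(shape, FloatArray(bytes.size / 2) { bf16ToFloat(bytes, it) })
            // Opt-in: keep the BF16 bytes so a BF16 kernel can consume them directly.
            Bf16LoadPolicy.KEEP_NATIVE -> Bf16StorageSketch(shape, bytes)
        }
    }
}
```

The @JvmOverloads annotation is included in the sketch so that Java callers and previously compiled code still find a constructor without the policy argument; whether the real loader needs it depends on its existing constructor signature.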

Tests

New tests in commonTest:

  • Round-trip a small BF16 SafeTensors blob through the loader under both policies.
  • DEQUANT_TO_FP32 path produces a tensor whose tensor.data is FloatArrayTensorData and values match a hand-computed FP32 reference within 1e-2.
  • KEEP_NATIVE path produces a tensor whose tensor.data is Bf16DenseTensorData, whose packedData byte array matches the original SafeTensors bytes verbatim, and whose get(*indices) decodes to the same values as the DEQUANT path within 1e-3. In practice the values should be identical: both paths widen with the same bf16_bits << 16 shift, which is exact, so no FMA or rounding noise is involved.
  • Mixed-dtype SafeTensors blob (FP32 + BF16 in the same file) loaded with KEEP_NATIVE policy — FP32 tensors stay FloatArrayTensorData, BF16 tensors become Bf16DenseTensorData. Confirms the policy only affects the BF16 branch.
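The round-trip and mixed-dtype tests need a small SafeTensors blob. One way to build one in memory is sketched below, assuming the tests construct fixtures by hand rather than reading files. The SafeTensors layout is an 8-byte little-endian u64 header length, a UTF-8 JSON header mapping each tensor name to dtype, shape and data_offsets (relative to the start of the data section), then the raw tensor bytes. RawTensor, buildSafeTensors and dataSectionStart are hypothetical test helpers, not part of the PR.

```kotlin
// Hypothetical test helper (not part of the PR): builds a minimal SafeTensors blob.
// Layout: 8-byte little-endian u64 header length, UTF-8 JSON header, then raw data.

class RawTensor(val name: String, val dtype: String, val shape: IntArray, val bytes: ByteArray)

fun buildSafeTensors(tensors: List<RawTensor>): ByteArray {
    var offset = 0
    // data_offsets are [begin, end) relative to the start of the data section.
    val entries = tensors.map { t ->
        val begin = offset
        offset += t.bytes.size
        val shapeJson = t.shape.joinToString(",", "[", "]")
        "\"${t.name}\":{\"dtype\":\"${t.dtype}\",\"shape\":$shapeJson,\"data_offsets\":[$begin,$offset]}"
    }
    val header = entries.joinToString(",", "{", "}").encodeToByteArray()
    val blob = ByteArray(8 + header.size + offset)
    // Write the header length as a little-endian u64.
    var len = header.size.toLong()
    for (i in 0 until 8) { blob[i] = (len and 0xFF).toByte(); len = len ushr 8 }
    header.copyInto(blob, 8)
    var pos = 8 + header.size
    for (t in tensors) { t.bytes.copyInto(blob, pos); pos += t.bytes.size }
    return blob
}

/** Reads the u64 header length back and returns the index where tensor data begins. */
fun dataSectionStart(blob: ByteArray): Int {
    var len = 0L
    for (i in 7 downTo 0) len = (len shl 8) or (blob[i].toLong() and 0xFF)
    return 8 + len.toInt()
}
```

Feeding such a blob through the loader under each policy, then checking the runtime type of tensor.data and comparing packedData byte-for-byte against the BF16 slice of the data section, covers the cases listed above.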

Phase 3 (next follow-up)

Teach DefaultCpuOpsJvm.chooseQuantizedMatmul to dispatch Bf16TensorData weights (an is Bf16TensorData -> branch) to Bf16MatmulKernel via the SPI, mirroring #608's Q8_0 wiring. After Phase 3 lands, a Gemma-3n consumer can switch to Bf16LoadPolicy.KEEP_NATIVE and immediately get the SIMD-vectorized BF16 matmul path with no other code changes.
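The Phase 3 dispatch can be pictured as one more branch in a when over the weight's storage type. This is a hypothetical sketch: WeightSketch, Bf16MatvecKernelSketch and matvecSketch are stand-ins (the real DefaultCpuOpsJvm / Bf16MatmulKernel API is not shown in this issue), the operation is simplified to a matrix-vector product, and the scalar kernel here takes the place of the SIMD one that would actually be discovered via the SPI.

```kotlin
// Hypothetical sketch of the Phase 3 dispatch: names are stand-ins, not the real
// DefaultCpuOpsJvm.chooseQuantizedMatmul / Bf16MatmulKernel API.

sealed interface WeightSketch { val rows: Int; val cols: Int }
class Fp32WeightSketch(override val rows: Int, override val cols: Int, val values: FloatArray) : WeightSketch
class Bf16WeightSketch(override val rows: Int, override val cols: Int, val packed: ByteArray) : WeightSketch

/** Stand-in for the SPI-provided kernel; a real implementation would be SIMD-vectorized. */
fun interface Bf16MatvecKernelSketch {
    fun matvec(w: Bf16WeightSketch, x: FloatArray): FloatArray
}

/** Scalar reference: y[r] = sum over c of widen(w[r, c]) * x[c], decoding BF16 on the fly. */
val scalarBf16Kernel = Bf16MatvecKernelSketch { w, x ->
    FloatArray(w.rows) { r ->
        var acc = 0f
        for (c in 0 until w.cols) {
            val i = r * w.cols + c
            val bits = ((w.packed[2 * i + 1].toInt() and 0xFF) shl 8) or (w.packed[2 * i].toInt() and 0xFF)
            acc += Float.fromBits(bits shl 16) * x[c]
        }
        acc
    }
}

/** Dispatch on storage type; the Bf16WeightSketch branch is the Phase 3 addition. */
fun matvecSketch(
    w: WeightSketch,
    x: FloatArray,
    bf16Kernel: Bf16MatvecKernelSketch = scalarBf16Kernel,
): FloatArray = when (w) {
    is Fp32WeightSketch -> FloatArray(w.rows) { r ->
        var acc = 0f
        for (c in 0 until w.cols) acc += w.values[r * w.cols + c] * x[c]
        acc
    }
    is Bf16WeightSketch -> bf16Kernel.matvec(w, x)
}
```

Keeping the kernel behind an interface parameter is what lets a SIMD implementation replace the scalar reference without touching the dispatch itself.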

Branch

feature/bf16-loader-policy off develop.

Metadata

Assignees

No one assigned

    Labels

    enhancement (New feature or request)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
