Phase 2 of the BF16 dispatch chain. Follow-up to #610 (Bf16TensorData merged).
Why
Bf16TensorData / Bf16DenseTensorData exist now, but no SafeTensors-loaded tensor uses them. Today SafeTensorsParametersLoader (lines 83-89 in skainet-io-safetensors/.../SafeTensorsParametersLoader.kt) unconditionally dequantizes BF16 bytes to FP32 via dequantBF16(bytes) and wraps the result as a FloatArrayTensorData. So even with #610 in develop, BF16 weights still arrive as FP32 and the new SIMD Bf16MatmulKernel (#605) can't be reached.
This PR adds an opt-in knob so consumers that want native BF16 (Gemma-3n is the obvious first one) get a Bf16DenseTensorData-backed tensor; everyone else stays on the existing FP32 dequant path with zero behavioral change.
Scope
- New Bf16LoadPolicy enum (commonMain in skainet-io-safetensors):
  - DEQUANT_TO_FP32 (default): current behavior.
  - KEEP_NATIVE: wrap the raw bytes in Bf16DenseTensorData.
- New constructor parameter bf16Policy: Bf16LoadPolicy = DEQUANT_TO_FP32 on SafeTensorsParametersLoader. Default preserves source + bytecode compat for existing Kotlin callers; Java callers can still hit the 2-arg constructor via positional args (no Java consumer of this class in-tree).
- Branch in the DataType.BFLOAT16 case: when policy is KEEP_NATIVE, skip dequantBF16(bytes) and emit Bf16DenseTensorData(shape, bytes) wrapped via ctx.fromData(...).
- The require(dtype == FP32::class) check stays: the consumer-visible dtype is still FP32 (Bf16TensorData : TensorData<DType, Float>); only the underlying storage is BF16. Same pattern as Q4_K / Q8_0 tensors, which have FP32 dtype with quantized storage.
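The policy branch described in this list could look roughly like the sketch below. It is not the loader's actual code: the *Sketch classes, dequantBF16Sketch and loadBf16 are hypothetical stand-ins, and the little-endian, two-bytes-per-element layout is an assumption based on how SafeTensors stores BF16.

```kotlin
// Hypothetical stand-ins for the real SKaiNET types. Only the names
// Bf16LoadPolicy, DEQUANT_TO_FP32 and KEEP_NATIVE come from this PR.
enum class Bf16LoadPolicy { DEQUANT_TO_FP32, KEEP_NATIVE }

sealed interface TensorDataSketch
class FloatArrayTensorDataSketch(val shape: IntArray, val values: FloatArray) : TensorDataSketch
class Bf16DenseTensorDataSketch(val shape: IntArray, val packedData: ByteArray) : TensorDataSketch

// Assumed layout: little-endian, 2 bytes per BF16 element.
// A BF16 value is the top 16 bits of the matching FP32 bit pattern.
fun dequantBF16Sketch(bytes: ByteArray): FloatArray {
    require(bytes.size % 2 == 0) { "BF16 payload must have an even byte count" }
    return FloatArray(bytes.size / 2) { i ->
        val lo = bytes[2 * i].toInt() and 0xFF
        val hi = bytes[2 * i + 1].toInt() and 0xFF
        Float.fromBits(((hi shl 8) or lo) shl 16)
    }
}

// The default argument keeps existing call sites on the FP32 path.
fun loadBf16(
    shape: IntArray,
    bytes: ByteArray,
    policy: Bf16LoadPolicy = Bf16LoadPolicy.DEQUANT_TO_FP32,
): TensorDataSketch = when (policy) {
    // Current behavior: widen every element to FP32 up front.
    Bf16LoadPolicy.DEQUANT_TO_FP32 -> FloatArrayTensorDataSketch(shape, dequantBF16Sketch(bytes))
    // Opt-in: keep the raw bytes and defer decoding to the consumer.
    Bf16LoadPolicy.KEEP_NATIVE -> Bf16DenseTensorDataSketch(shape, bytes)
}
```

Calling loadBf16(shape, bytes) without a policy yields the FP32-backed variant, which is the "zero behavioral change" property the default is meant to preserve in this sketch.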
Tests
New tests in commonTest:
- Round-trip a small BF16 SafeTensors blob through the loader under both policies.
  - DEQUANT_TO_FP32 path produces a tensor whose tensor.data is FloatArrayTensorData and values match a hand-computed FP32 reference within 1e-2.
  - KEEP_NATIVE path produces a tensor whose tensor.data is Bf16DenseTensorData, the packedData byte array matches the original SafeTensors bytes verbatim, and get(*indices) decodes to the same values as the DEQUANT path within 1e-3 (just FMA noise; both go through the same bf16_bits << 16 math).
- Mixed-dtype SafeTensors blob (FP32 + BF16 in the same file) loaded with KEEP_NATIVE policy: FP32 tensors stay FloatArrayTensorData, BF16 tensors become Bf16DenseTensorData. Confirms the policy only affects the BF16 branch.
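One way to build the BF16 fixture and check the tolerances is sketched below. The helper names (encodeBf16Truncating, bf16At, maxAbsError) are hypothetical and not part of the repository. The encoder truncates rather than rounds, which is an assumption made for simplicity; a real fixture may use round-to-nearest-even.

```kotlin
import kotlin.math.abs

// Hypothetical test helpers; none of these names exist in the repository.

// Encode FP32 values as little-endian BF16 by dropping the low 16 bits
// of each FP32 bit pattern (truncation, not round-to-nearest).
fun encodeBf16Truncating(values: FloatArray): ByteArray {
    val out = ByteArray(values.size * 2)
    for (i in values.indices) {
        val top = values[i].toRawBits() ushr 16
        out[2 * i] = (top and 0xFF).toByte()
        out[2 * i + 1] = ((top ushr 8) and 0xFF).toByte()
    }
    return out
}

// Decode one element straight from the packed bytes using the same
// bf16_bits << 16 step the description refers to.
fun bf16At(packed: ByteArray, index: Int): Float {
    val lo = packed[2 * index].toInt() and 0xFF
    val hi = packed[2 * index + 1].toInt() and 0xFF
    return Float.fromBits(((hi shl 8) or lo) shl 16)
}

// Largest absolute difference between the FP32 reference and the
// values decoded from the packed BF16 bytes.
fun maxAbsError(reference: FloatArray, packed: ByteArray): Float =
    reference.indices.maxOf { abs(reference[it] - bf16At(packed, it)) }
```

BF16 keeps only 7 explicit mantissa bits, so the 1e-2 bound against a hand-computed FP32 reference only holds for small-magnitude fixture values; values exactly representable in BF16 (such as 1.0 or -0.75) round-trip with no error at all.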
Phase 3 (next follow-up)
DefaultCpuOpsJvm.chooseQuantizedMatmul gains an is Bf16TensorData -> branch that dispatches to the SPI Bf16MatmulKernel. Mirrors #608's Q8_0 wiring. After Phase 3 lands, a Gemma-3n consumer can flip Bf16LoadPolicy.KEEP_NATIVE and immediately get the SIMD-vectorized BF16 matmul path with zero other code changes.
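As a rough illustration of that dispatch shape, the sketch below selects a kernel by storage type. Every name in it is hypothetical (WeightStorage, KernelChoice, chooseMatmulSketch), and falling back to the FP32 path when no BF16 kernel is registered is an assumption, not something this PR specifies.

```kotlin
// Hypothetical storage hierarchy standing in for the real TensorData types.
sealed interface WeightStorage
class Fp32Storage(val values: FloatArray) : WeightStorage
class Q8Storage(val packed: ByteArray) : WeightStorage
class Bf16Storage(val packed: ByteArray) : WeightStorage

// Hypothetical labels for the kernel that would be picked.
enum class KernelChoice { FP32_DEFAULT, Q8_0_SPI, BF16_SPI }

// Pick a matmul kernel from the storage type alone. The consumer-visible
// dtype stays FP32 in every branch, so callers do not change.
fun chooseMatmulSketch(weights: WeightStorage, bf16KernelAvailable: Boolean): KernelChoice =
    when (weights) {
        // Assumed fallback: use the FP32 path if no BF16 kernel is registered.
        is Bf16Storage -> if (bf16KernelAvailable) KernelChoice.BF16_SPI else KernelChoice.FP32_DEFAULT
        is Q8Storage -> KernelChoice.Q8_0_SPI
        is Fp32Storage -> KernelChoice.FP32_DEFAULT
    }
```

Because the selection keys off the storage class, a loader that starts emitting BF16-backed tensors changes the kernel choice without any change at the call site, which is the effect described for Gemma-3n.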
Branch
feature/bf16-loader-policy off develop.