Skip to content

Commit 045cb20

Browse files
michalharakalclaude
andcommitted
Add SafeTensorsParametersLoader.withPolicy(DTypePolicy) (W0b of #615)
Companion factory + adapter that lets callers drive the loader by the generalised DTypePolicy from W1 instead of the BF16-specific Bf16LoadPolicy enum. Existing callers keep working unchanged — the factory is additive.

Policy → behaviour mapping covers all four DTypePolicy arms and follows the RFC semantics:
- Any -> DEQUANT_TO_FP32 (adaptive default)
- Require(BF16) -> KEEP_NATIVE
- Require(FP32) -> DEQUANT_TO_FP32
- Require(FP16) -> throws (no Fp16DenseTensorData backing yet)
- Require(other) -> throws (loader can't fabricate dtypes)
- Prefer(BF16) -> KEEP_NATIVE; Prefer(other) -> DEQUANT (soft fall-through)
- OneOf containing BF16 -> KEEP_NATIVE; otherwise DEQUANT

The Require error messages name the failing target explicitly, suggest the working alternative (Require(FP32) / Any), and point at the missing backing class for the FP16 case — fail-fast with diagnostic detail, per the RFC.

Round-trip test guards the parity between Bf16LoadPolicy.toDTypePolicy() (W2) and this mapper: every enum arm survives the round trip through DTypePolicy intact.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0d12193 commit 045cb20

2 files changed

Lines changed: 177 additions & 0 deletions

File tree

skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import sk.ainet.lang.tensor.Shape
88
import sk.ainet.lang.tensor.Tensor
99
import sk.ainet.lang.tensor.data.Bf16DenseTensorData
1010
import sk.ainet.lang.tensor.data.TensorData
11+
import sk.ainet.lang.types.BF16
1112
import sk.ainet.lang.types.DType
13+
import sk.ainet.lang.types.DTypePolicy
14+
import sk.ainet.lang.types.FP16
1215
import sk.ainet.lang.types.FP32
1316
import sk.ainet.lang.types.Int32
1417
import sk.ainet.lang.types.Int8
@@ -286,4 +289,64 @@ class SafeTensorsParametersLoader(
286289
}
287290
}
288291
}
292+
293+
companion object {
294+
295+
/**
 * Factory that builds a [SafeTensorsParametersLoader] from a generalised
 * [DTypePolicy] rather than the BF16-specific [Bf16LoadPolicy]. It is the
 * bridge for the policy-driven loader path of the dtype-policy RFC (#615).
 *
 * The policy is translated by [mapPolicyToBf16] and the result is handed to
 * the regular constructor as `bf16Policy`. The mapping concerns BF16 source
 * tensors; other source dtypes are handled by [load]:
 * - [DTypePolicy.Any] → DEQUANT_TO_FP32 (the existing adaptive default).
 * - [DTypePolicy.Require] with target `BF16` → KEEP_NATIVE.
 * - [DTypePolicy.Require] with target `FP32` → DEQUANT_TO_FP32.
 * - [DTypePolicy.Require] with target `FP16` → throws. F16 KEEP_NATIVE is a
 *   follow-up (no `Fp16DenseTensorData` yet); use `Require(FP32)` if you
 *   want F16 dequanted, or `Any` to inherit the adaptive default.
 * - [DTypePolicy.Require] with any other target → throws, because
 *   SafeTensors can't fabricate dtypes the file doesn't carry.
 * - [DTypePolicy.Prefer] with target `BF16` → KEEP_NATIVE; any other
 *   preferred target softly falls through to DEQUANT_TO_FP32.
 * - [DTypePolicy.OneOf] containing `BF16` → KEEP_NATIVE; otherwise
 *   DEQUANT_TO_FP32.
 *
 * @param sourceProvider forwarded unchanged to the loader constructor.
 * @param policy the dtype policy to translate.
 * @param onProgress forwarded unchanged to the loader constructor; the
 *   default is a no-op callback.
 * @return a loader configured with the translated [Bf16LoadPolicy].
 * @throws IllegalArgumentException when [policy] is a [DTypePolicy.Require]
 *   whose target is neither `BF16` nor `FP32`.
 */
fun withPolicy(
    sourceProvider: () -> RandomAccessSource,
    policy: DTypePolicy,
    onProgress: (current: Long, total: Long, message: String?) -> Unit = { _, _, _ -> },
): SafeTensorsParametersLoader {
    // Resolve the policy up front: an unsupported Require target throws
    // here, before any loader instance is constructed.
    val resolvedBf16Policy = mapPolicyToBf16(policy)
    return SafeTensorsParametersLoader(
        sourceProvider = sourceProvider,
        onProgress = onProgress,
        bf16Policy = resolvedBf16Policy,
    )
}
329+
330+
/**
 * Translates a generalised [DTypePolicy] into the [Bf16LoadPolicy] arm the
 * loader constructor understands.
 *
 * Only two outcomes exist, so each policy arm is first reduced to the
 * question "should BF16 be kept native?" and converted to the enum once at
 * the end.
 *
 * @param policy the policy to translate.
 * @return [Bf16LoadPolicy.KEEP_NATIVE] when the policy selects `BF16`,
 *   otherwise [Bf16LoadPolicy.DEQUANT_TO_FP32].
 * @throws IllegalArgumentException for [DTypePolicy.Require] with a target
 *   other than `BF16` or `FP32`.
 */
internal fun mapPolicyToBf16(policy: DTypePolicy): Bf16LoadPolicy {
    val keepNative: Boolean = when (policy) {
        DTypePolicy.Any -> false
        // Require is the hard arm: anything other than BF16 / FP32 fails fast.
        is DTypePolicy.Require -> when (val required = policy.target) {
            BF16 -> true
            FP32 -> false
            FP16 -> throw IllegalArgumentException(
                "SafeTensorsParametersLoader: Require(FP16) is not supported — " +
                    "F16 KEEP_NATIVE has no Fp16DenseTensorData backing yet. " +
                    "Use Require(FP32) to dequant F16 sources, or Any to inherit the adaptive default.",
            )
            else -> throw IllegalArgumentException(
                "SafeTensorsParametersLoader: Require(${required.name}) is not satisfiable — " +
                    "the loader produces FP32 / BF16 / Int32 / Int8 tensors depending on source dtype; " +
                    "it cannot fabricate ${required.name} from arbitrary sources.",
            )
        }
        // Prefer is the soft arm: a non-BF16 preference falls through without throwing.
        is DTypePolicy.Prefer -> policy.target == BF16
        // NOTE(review): a OneOf set holding neither BF16 nor FP32 also lands on
        // DEQUANT_TO_FP32 instead of failing like Require does — confirm this
        // matches the RFC's intent for OneOf.
        is DTypePolicy.OneOf -> policy.allowed.contains(BF16)
    }
    return if (keepNative) Bf16LoadPolicy.KEEP_NATIVE else Bf16LoadPolicy.DEQUANT_TO_FP32
}
351+
}
289352
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package sk.ainet.io.safetensors
2+
3+
import kotlin.test.Test
4+
import kotlin.test.assertEquals
5+
import kotlin.test.assertFailsWith
6+
import sk.ainet.lang.types.BF16
7+
import sk.ainet.lang.types.DTypePolicy
8+
import sk.ainet.lang.types.FP16
9+
import sk.ainet.lang.types.FP32
10+
import sk.ainet.lang.types.Int8
11+
12+
/**
13+
* Unit tests for the `DTypePolicy` → `Bf16LoadPolicy` adapter in
14+
* [SafeTensorsParametersLoader.mapPolicyToBf16]. The `withPolicy`
15+
* factory is a thin wrapper over this mapper plus the existing
16+
* constructor; testing the mapper covers the routing logic without
17+
* needing a real SafeTensors fixture.
18+
*/
19+
class SafeTensorsParametersLoaderPolicyTest {
20+
21+
@Test
22+
fun any_maps_to_dequant_to_fp32() {
23+
assertEquals(
24+
Bf16LoadPolicy.DEQUANT_TO_FP32,
25+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Any),
26+
)
27+
}
28+
29+
@Test
30+
fun require_bf16_maps_to_keep_native() {
31+
assertEquals(
32+
Bf16LoadPolicy.KEEP_NATIVE,
33+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Require(BF16)),
34+
)
35+
}
36+
37+
@Test
38+
fun require_fp32_maps_to_dequant() {
39+
assertEquals(
40+
Bf16LoadPolicy.DEQUANT_TO_FP32,
41+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Require(FP32)),
42+
)
43+
}
44+
45+
@Test
46+
fun require_fp16_fails_with_explicit_message() {
47+
val ex = assertFailsWith<IllegalArgumentException> {
48+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Require(FP16))
49+
}
50+
// The error message must point the operator at the alternative —
51+
// RFC says "fail-fast with clear diagnostics," not just throw.
52+
val msg = ex.message ?: ""
53+
assertEquals(true, msg.contains("Require(FP16)"), "msg: $msg")
54+
assertEquals(true, msg.contains("Fp16DenseTensorData"), "msg: $msg")
55+
}
56+
57+
@Test
58+
fun require_unsupported_target_fails_with_explicit_message() {
59+
val ex = assertFailsWith<IllegalArgumentException> {
60+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Require(Int8))
61+
}
62+
val msg = ex.message ?: ""
63+
assertEquals(true, msg.contains("Require(Int8)"), "msg: $msg")
64+
assertEquals(true, msg.contains("cannot fabricate"), "msg: $msg")
65+
}
66+
67+
@Test
68+
fun prefer_bf16_maps_to_keep_native() {
69+
assertEquals(
70+
Bf16LoadPolicy.KEEP_NATIVE,
71+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Prefer(BF16)),
72+
)
73+
}
74+
75+
@Test
76+
fun prefer_fp32_or_anything_else_maps_to_dequant() {
77+
assertEquals(
78+
Bf16LoadPolicy.DEQUANT_TO_FP32,
79+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Prefer(FP32)),
80+
)
81+
assertEquals(
82+
Bf16LoadPolicy.DEQUANT_TO_FP32,
83+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Prefer(FP16)),
84+
"Prefer is soft — unsatisfiable preferences fall through silently, no throw",
85+
)
86+
}
87+
88+
@Test
89+
fun oneOf_with_bf16_maps_to_keep_native() {
90+
assertEquals(
91+
Bf16LoadPolicy.KEEP_NATIVE,
92+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.OneOf(setOf(BF16, FP32))),
93+
)
94+
}
95+
96+
@Test
97+
fun oneOf_without_bf16_maps_to_dequant() {
98+
assertEquals(
99+
Bf16LoadPolicy.DEQUANT_TO_FP32,
100+
SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.OneOf(setOf(FP32, FP16))),
101+
)
102+
}
103+
104+
@Test
105+
fun parity_with_bf16LoadPolicy_toDTypePolicy() {
106+
// Round-trip property: the BF16 enum's adapter should land on a
107+
// policy that the inverse mapper sends back to the original enum.
108+
for (arm in Bf16LoadPolicy.entries) {
109+
val asDTypePolicy = arm.toDTypePolicy()
110+
val back = SafeTensorsParametersLoader.mapPolicyToBf16(asDTypePolicy)
111+
assertEquals(arm, back, "round-trip failed for $arm via $asDTypePolicy")
112+
}
113+
}
114+
}

0 commit comments

Comments
 (0)