|
| 1 | +package sk.ainet.io.safetensors |
| 2 | + |
| 3 | +import java.nio.file.Files |
| 4 | +import java.nio.file.Path |
| 5 | +import kotlin.io.path.deleteIfExists |
| 6 | +import kotlin.io.path.exists |
| 7 | +import kotlin.test.Test |
| 8 | +import kotlin.test.assertEquals |
| 9 | +import kotlin.test.assertNotNull |
| 10 | +import kotlin.test.assertTrue |
| 11 | +import kotlinx.coroutines.runBlocking |
| 12 | +import kotlinx.io.buffered |
| 13 | +import kotlinx.io.files.Path as IoPath |
| 14 | +import kotlinx.io.files.SystemFileSystem |
| 15 | +import sk.ainet.lang.tensor.storage.BufferHandle |
| 16 | + |
| 17 | +/** |
| 18 | + * JVM-only end-to-end coverage of [StreamingShardedSafeTensorsReader], |
| 19 | + * focused on the `loadTensorStorageMapped` entry points added to support |
| 20 | + * file-backed `TensorStorage` reads on tensors that exceed the JVM |
| 21 | + * `ByteArray` limit (used by the Gemma 4 PLE token-embedding table). |
| 22 | + * |
| 23 | + * Builds a real tiny single-shard SafeTensors file via [SafeTensorsWriter] |
| 24 | + * and a hand-crafted index JSON, opens via the sharded reader, and |
| 25 | + * verifies the returned [sk.ainet.lang.tensor.storage.TensorStorage] |
| 26 | + * carries the right metadata and a [BufferHandle.FileBacked] handle. |
| 27 | + */ |
| 28 | +class StreamingShardedSafeTensorsReaderJvmTest { |
| 29 | + |
| 30 | + @Test |
| 31 | + fun `loadTensorStorageMapped returns file-backed storage with correct metadata`() { |
| 32 | + val tempDir = Files.createTempDirectory("sharded-mapped-") |
| 33 | + try { |
| 34 | + val shardName = "model.safetensors" |
| 35 | + val shardPath = tempDir.resolve(shardName) |
| 36 | + val indexPath = tempDir.resolve("model.safetensors.index.json") |
| 37 | + |
| 38 | + // Build a 2x2 F32 tensor and write it as a single-shard SafeTensors file. |
| 39 | + val data = floatArrayOf(1f, 2f, 3f, 4f) |
| 40 | + writeSingleTensor(shardPath, name = "test.weight", shape = listOf(2L, 2L), data = data) |
| 41 | + |
| 42 | + // Hand-craft the index. The sharded reader treats this as a 1-shard |
| 43 | + // sharded model — equivalent to how a single-file gemma-4-e2b-it |
| 44 | + // checkpoint gets routed through the sharded loader. |
| 45 | + val totalSize = Files.size(shardPath) |
| 46 | + val indexJson = """ |
| 47 | + { |
| 48 | + "metadata": {"total_size": $totalSize}, |
| 49 | + "weight_map": {"test.weight": "$shardName"} |
| 50 | + } |
| 51 | + """.trimIndent() |
| 52 | + Files.writeString(indexPath, indexJson) |
| 53 | + |
| 54 | + runBlocking { |
| 55 | + StreamingShardedSafeTensorsReader.openFromIndex(indexPath.toString()).use { reader -> |
| 56 | + assertEquals(1, reader.tensors.size) |
| 57 | + assertEquals("test.weight", reader.tensors[0].name) |
| 58 | + |
| 59 | + // 1) Look-up by name overload. |
| 60 | + val byName = reader.loadTensorStorageMapped("test.weight") |
| 61 | + assertEquals(listOf(2, 2), byName.shape.dimensions.toList()) |
| 62 | + assertTrue(byName.isFileBacked) |
| 63 | + val handle = byName.buffer |
| 64 | + assertTrue(handle is BufferHandle.FileBacked, "Expected FileBacked, got ${handle::class}") |
| 65 | + handle as BufferHandle.FileBacked |
| 66 | + assertEquals(shardPath.toString(), handle.path) |
| 67 | + assertEquals(16L, handle.sizeInBytes) // 4 floats × 4 bytes |
| 68 | + |
| 69 | + // 2) Look-up by ShardedTensorInfo overload — must produce |
| 70 | + // the same TensorStorage shape. |
| 71 | + val byInfo = reader.loadTensorStorageMapped(reader.tensors[0]) |
| 72 | + assertEquals(byName.shape.dimensions.toList(), byInfo.shape.dimensions.toList()) |
| 73 | + assertTrue(byInfo.isFileBacked) |
| 74 | + } |
| 75 | + } |
| 76 | + } finally { |
| 77 | + tempDir.toFile().deleteRecursively() |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + @Test |
| 82 | + fun `loadTensorStorageMapped throws on missing tensor name`() { |
| 83 | + val tempDir = Files.createTempDirectory("sharded-mapped-missing-") |
| 84 | + try { |
| 85 | + val shardPath = tempDir.resolve("model.safetensors") |
| 86 | + val indexPath = tempDir.resolve("model.safetensors.index.json") |
| 87 | + writeSingleTensor(shardPath, "real.weight", listOf(1L), floatArrayOf(1f)) |
| 88 | + val total = Files.size(shardPath) |
| 89 | + Files.writeString( |
| 90 | + indexPath, |
| 91 | + """{"metadata":{"total_size": $total},"weight_map":{"real.weight":"model.safetensors"}}""", |
| 92 | + ) |
| 93 | + |
| 94 | + runBlocking { |
| 95 | + StreamingShardedSafeTensorsReader.openFromIndex(indexPath.toString()).use { reader -> |
| 96 | + val thrown = runCatching { reader.loadTensorStorageMapped("missing.weight") } |
| 97 | + .exceptionOrNull() |
| 98 | + assertNotNull(thrown, "Expected an exception for missing tensor name") |
| 99 | + assertTrue(thrown is IllegalArgumentException) |
| 100 | + } |
| 101 | + } |
| 102 | + } finally { |
| 103 | + tempDir.toFile().deleteRecursively() |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + private fun writeSingleTensor( |
| 108 | + outputPath: Path, |
| 109 | + name: String, |
| 110 | + shape: List<Long>, |
| 111 | + data: FloatArray, |
| 112 | + ) { |
| 113 | + outputPath.deleteIfExists() |
| 114 | + val ioPath = IoPath(outputPath.toString()) |
| 115 | + SystemFileSystem.sink(ioPath).buffered().use { sink -> |
| 116 | + SafeTensorsWriter.write(sink) { |
| 117 | + tensorF32(name, shape, data) |
| 118 | + } |
| 119 | + } |
| 120 | + check(outputPath.exists()) { "Failed to write fixture at $outputPath" } |
| 121 | + } |
| 122 | +} |
0 commit comments