Skip to content

Commit aa77b4d

Browse files
Merge pull request #582 from SKaiNET-developers/feature/sharded-loadtensorstoragemapped
feat(safetensors): loadTensorStorageMapped on the sharded reader
2 parents b25d7e2 + 75e6b49 commit aa77b4d

2 files changed

Lines changed: 171 additions & 0 deletions

File tree

skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/StreamingShardedSafeTensorsReader.kt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import kotlinx.coroutines.flow.MutableSharedFlow
55
import sk.ainet.io.model.LoadingProgress
66
import sk.ainet.io.model.LoadingStage
77
import sk.ainet.io.model.ProgressReportingLoader
8+
import sk.ainet.lang.tensor.storage.TensorStorage
89

910
/**
1011
* Streaming reader for sharded/multi-file SafeTensors models.
@@ -96,6 +97,54 @@ public class StreamingShardedSafeTensorsReader private constructor(
9697
return reader.loadTensorData(tensor.name)
9798
}
9899

100+
/**
 * Describes a tensor as a file-backed [TensorStorage] instead of
 * copying its bytes into a heap [ByteArray], so callers can memory-map
 * the tensor's byte range straight from the shard file rather than
 * going through a 2 GB-capped `ByteArray` round-trip.
 *
 * This is the sharded counterpart of
 * [StreamingSafeTensorsReader.loadTensorStorageMapped]. Both the shard
 * reader and the shard's file path are resolved here, so the caller
 * does not need to know which physical file contains the tensor.
 *
 * The returned storage carries a
 * [sk.ainet.lang.tensor.storage.BufferHandle.FileBacked] handle that
 * refers to the shard file by path; the mmap lifecycle is owned by the
 * caller (or by the runtime that consumes the storage).
 *
 * @param tensor Entry taken from [tensors] identifying the tensor to map.
 * @return A [TensorStorage] whose file-backed buffer handle points at
 *   the tensor's byte range inside its shard file.
 * @throws IllegalStateException if no reader is registered for the
 *   tensor's shard, or if that shard's reader does not list the tensor
 *   (consistency check).
 */
public fun loadTensorStorageMapped(tensor: ShardedTensorInfo): TensorStorage {
    val shardReader = checkNotNull(shardReaders[tensor.shardFilename]) {
        "Shard not loaded: ${tensor.shardFilename}"
    }
    // The per-shard reader is handed one of its own tensor entries, so the
    // tensor is looked up again by name inside that shard.
    val shardLocalTensor = checkNotNull(shardReader.tensors.find { it.name == tensor.name }) {
        "Tensor '${tensor.name}' not found in shard '${tensor.shardFilename}'"
    }
    val shardPath = resolveShardPath(tensor.shardFilename)
    return shardReader.loadTensorStorageMapped(shardLocalTensor, shardPath)
}
134+
135+
/**
136+
* Convenience overload for [loadTensorStorageMapped] that looks up
137+
* the tensor by name. Mirrors the [loadTensorData] name-based
138+
* overload.
139+
*
140+
* @throws IllegalArgumentException if no tensor matches [name].
141+
*/
142+
public fun loadTensorStorageMapped(name: String): TensorStorage {
143+
val tensor = _tensors.firstOrNull { it.name == name }
144+
?: throw IllegalArgumentException("Tensor not found: $name")
145+
return loadTensorStorageMapped(tensor)
146+
}
147+
99148
override fun close() {
100149
shardReaders.values.forEach { it.close() }
101150
shardReaders.clear()
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package sk.ainet.io.safetensors
2+
3+
import java.nio.file.Files
import java.nio.file.Path
import kotlin.io.path.deleteIfExists
import kotlin.io.path.exists
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import kotlinx.coroutines.runBlocking
import kotlinx.io.buffered
import kotlinx.io.files.Path as IoPath
import kotlinx.io.files.SystemFileSystem
import sk.ainet.lang.tensor.storage.BufferHandle
16+
17+
/**
 * JVM-only end-to-end coverage of [StreamingShardedSafeTensorsReader],
 * focused on the `loadTensorStorageMapped` entry points added to support
 * file-backed `TensorStorage` reads on tensors that exceed the JVM
 * `ByteArray` limit (used by the Gemma 4 PLE token-embedding table).
 *
 * Each test writes a real, tiny single-shard SafeTensors file via
 * [SafeTensorsWriter] plus a hand-crafted index JSON, opens the pair
 * through the sharded reader, and inspects what the reader returns.
 */
class StreamingShardedSafeTensorsReaderJvmTest {

    /**
     * Both overloads must return file-backed storage. The by-name result is
     * additionally checked for shape, backing file path, and byte size.
     */
    @Test
    fun `loadTensorStorageMapped returns file-backed storage with correct metadata`() {
        withTempDir("sharded-mapped-") { tempDir ->
            val shardName = "model.safetensors"
            val shardPath = tempDir.resolve(shardName)
            val indexPath = tempDir.resolve("model.safetensors.index.json")

            // Build a 2x2 F32 tensor and write it as a single-shard SafeTensors file.
            writeSingleTensor(
                outputPath = shardPath,
                name = "test.weight",
                shape = listOf(2L, 2L),
                data = floatArrayOf(1f, 2f, 3f, 4f),
            )

            // Hand-craft the index. The sharded reader treats this as a 1-shard
            // sharded model — equivalent to how a single-file gemma-4-e2b-it
            // checkpoint gets routed through the sharded loader.
            val totalSize = Files.size(shardPath)
            val indexJson = """
                {
                  "metadata": {"total_size": $totalSize},
                  "weight_map": {"test.weight": "$shardName"}
                }
            """.trimIndent()
            Files.writeString(indexPath, indexJson)

            runBlocking {
                StreamingShardedSafeTensorsReader.openFromIndex(indexPath.toString()).use { reader ->
                    assertEquals(1, reader.tensors.size)
                    assertEquals("test.weight", reader.tensors[0].name)

                    // 1) Look-up by name overload.
                    val byName = reader.loadTensorStorageMapped("test.weight")
                    assertEquals(listOf(2, 2), byName.shape.dimensions.toList())
                    assertTrue(byName.isFileBacked)
                    val buffer = byName.buffer
                    // assertNotNull returns the non-null value, so the handle is
                    // already typed as FileBacked without a separate cast.
                    val handle = assertNotNull(
                        buffer as? BufferHandle.FileBacked,
                        "Expected FileBacked, got ${buffer::class}",
                    )
                    assertEquals(shardPath.toString(), handle.path)
                    assertEquals(16L, handle.sizeInBytes) // 4 floats × 4 bytes

                    // 2) Look-up by ShardedTensorInfo overload — must produce
                    //    the same TensorStorage shape.
                    val byInfo = reader.loadTensorStorageMapped(reader.tensors[0])
                    assertEquals(byName.shape.dimensions.toList(), byInfo.shape.dimensions.toList())
                    assertTrue(byInfo.isFileBacked)
                }
            }
        }
    }

    /** Asking for a tensor name absent from the index must fail with [IllegalArgumentException]. */
    @Test
    fun `loadTensorStorageMapped throws on missing tensor name`() {
        withTempDir("sharded-mapped-missing-") { tempDir ->
            val shardPath = tempDir.resolve("model.safetensors")
            val indexPath = tempDir.resolve("model.safetensors.index.json")
            writeSingleTensor(shardPath, "real.weight", listOf(1L), floatArrayOf(1f))
            val total = Files.size(shardPath)
            Files.writeString(
                indexPath,
                """{"metadata":{"total_size": $total},"weight_map":{"real.weight":"model.safetensors"}}""",
            )

            runBlocking {
                StreamingShardedSafeTensorsReader.openFromIndex(indexPath.toString()).use { reader ->
                    // assertFailsWith reports the actual exception type on a mismatch,
                    // unlike a bare `assertTrue(thrown is ...)`.
                    assertFailsWith<IllegalArgumentException>("Expected an exception for missing tensor name") {
                        reader.loadTensorStorageMapped("missing.weight")
                    }
                }
            }
        }
    }

    /**
     * Runs [block] against a freshly created temporary directory whose name
     * starts with [prefix], then deletes the directory recursively — also
     * when [block] throws.
     */
    private fun withTempDir(prefix: String, block: (Path) -> Unit) {
        val tempDir = Files.createTempDirectory(prefix)
        try {
            block(tempDir)
        } finally {
            tempDir.toFile().deleteRecursively()
        }
    }

    /**
     * Writes a SafeTensors file at [outputPath] containing one F32 tensor.
     *
     * Any existing file at [outputPath] is removed first.
     *
     * @param outputPath Destination of the fixture file.
     * @param name Tensor name stored in the file.
     * @param shape Tensor dimensions.
     * @param data Tensor values passed to the writer.
     * @throws IllegalStateException if no file exists at [outputPath] after writing.
     */
    private fun writeSingleTensor(
        outputPath: Path,
        name: String,
        shape: List<Long>,
        data: FloatArray,
    ) {
        outputPath.deleteIfExists()
        SystemFileSystem.sink(IoPath(outputPath.toString())).buffered().use { sink ->
            SafeTensorsWriter.write(sink) { tensorF32(name, shape, data) }
        }
        check(outputPath.exists()) { "Failed to write fixture at $outputPath" }
    }
}

0 commit comments

Comments
 (0)