Skip to content

Commit 82156a1

Browse files
committed
data: share source resolver core
1 parent 8f4aaee commit 82156a1

8 files changed

Lines changed: 404 additions & 172 deletions

File tree

skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@ package sk.ainet.data.cifar10
22

33
import kotlinx.coroutines.Dispatchers
44
import kotlinx.coroutines.withContext
5-
import sk.ainet.data.source.CachePolicy
6-
import sk.ainet.data.source.DataSourceRequest
7-
import sk.ainet.data.source.JvmDataSourceResolver
8-
import java.io.ByteArrayInputStream
5+
import sk.ainet.data.common.JvmDatasetSourceReader
6+
import sk.ainet.data.common.gunzip
97
import java.io.File
108
import java.io.FileOutputStream
11-
import java.util.zip.GZIPInputStream
129

1310
/**
1411
* JVM implementation of the CIFAR-10 loader.
@@ -18,7 +15,7 @@ import java.util.zip.GZIPInputStream
1815
* @property config The configuration for the CIFAR-10 loader.
1916
*/
2017
public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) {
21-
private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources"))
18+
private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache)
2219

2320
/**
2421
* Downloads the CIFAR-10 archive and extracts the specified batch file.
@@ -43,14 +40,8 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon
4340

4441
// Check if we need to resolve and extract the archive
4542
if (!extractedDir.exists() || !config.useCache) {
46-
val archive = resolver.resolve(
47-
DataSourceRequest(
48-
uri = config.archiveUri,
49-
cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh
50-
)
51-
)
5243
println("Extracting CIFAR-10 archive...")
53-
extractTarGz(archive.readBytes(), cacheDir.path)
44+
extractTarGz(sources.read(config.archiveUri), cacheDir.path)
5445
}
5546

5647
if (!batchFile.exists()) {
@@ -68,14 +59,7 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon
6859
*/
6960
private fun extractTarGz(archiveBytes: ByteArray, outputDir: String) {
7061
val outputDirFile = File(outputDir)
71-
72-
// First, decompress gzip to get the tar content
73-
val tarBytes = GZIPInputStream(ByteArrayInputStream(archiveBytes)).use { gzipIn ->
74-
gzipIn.readBytes()
75-
}
76-
77-
// Parse the TAR archive
78-
extractTar(tarBytes, outputDirFile)
62+
extractTar(archiveBytes.gunzip(), outputDirFile)
7963
}
8064

8165
/**
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package sk.ainet.data.common
2+
3+
import sk.ainet.data.source.CachePolicy
4+
import sk.ainet.data.source.DataSourceRequest
5+
import sk.ainet.data.source.JvmDataSourceResolver
6+
import java.io.ByteArrayInputStream
7+
import java.io.File
8+
import java.util.zip.GZIPInputStream
9+
10+
internal class JvmDatasetSourceReader(
11+
cacheDir: String,
12+
useCache: Boolean
13+
) {
14+
private val resolver = JvmDataSourceResolver(File(cacheDir, "sources"))
15+
private val cachePolicy = if (useCache) CachePolicy.Use else CachePolicy.Refresh
16+
17+
suspend fun read(uri: String): ByteArray {
18+
val artifact = resolver.resolve(
19+
DataSourceRequest(
20+
uri = uri,
21+
cachePolicy = cachePolicy
22+
)
23+
)
24+
return artifact.readBytes()
25+
}
26+
27+
suspend fun readGzipDecoded(uri: String): ByteArray = read(uri).gunzipIfNeeded()
28+
}
29+
30+
internal fun ByteArray.gunzip(): ByteArray {
31+
return GZIPInputStream(ByteArrayInputStream(this)).use { it.readBytes() }
32+
}
33+
34+
internal fun ByteArray.gunzipIfNeeded(): ByteArray {
35+
return if (isGzip()) gunzip() else this
36+
}
37+
38+
private fun ByteArray.isGzip(): Boolean {
39+
return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte()
40+
}

skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
package sk.ainet.data.fashionmnist
22

3-
import kotlinx.coroutines.Dispatchers
4-
import kotlinx.coroutines.withContext
5-
import sk.ainet.data.source.CachePolicy
6-
import sk.ainet.data.source.DataSourceRequest
7-
import sk.ainet.data.source.JvmDataSourceResolver
8-
import java.io.ByteArrayInputStream
9-
import java.util.zip.GZIPInputStream
10-
import java.io.File
3+
import sk.ainet.data.common.JvmDatasetSourceReader
114

125
/**
136
* JVM implementation of the Fashion-MNIST loader.
147
*
158
* @property config The configuration for the Fashion-MNIST loader.
169
*/
1710
public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) {
18-
private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources"))
11+
private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache)
1912

2013
/**
2114
* Resolves, caches, and decompresses a file when needed.
@@ -24,23 +17,8 @@ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMN
2417
* @param filename The name of the file to save.
2518
* @return The bytes of the decompressed file.
2619
*/
27-
override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) {
28-
val artifact = resolver.resolve(
29-
DataSourceRequest(
30-
uri = url,
31-
cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh
32-
)
33-
)
34-
return@withContext maybeGunzip(artifact.readBytes())
35-
}
36-
37-
private fun maybeGunzip(bytes: ByteArray): ByteArray {
38-
if (!bytes.isGzip()) return bytes
39-
return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() }
40-
}
41-
42-
private fun ByteArray.isGzip(): Boolean {
43-
return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte()
20+
override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray {
21+
return sources.readGzipDecoded(url)
4422
}
4523

4624
public companion object {

skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
package sk.ainet.data.mnist
22

3-
import kotlinx.coroutines.Dispatchers
4-
import kotlinx.coroutines.withContext
5-
import sk.ainet.data.source.CachePolicy
6-
import sk.ainet.data.source.DataSourceRequest
7-
import sk.ainet.data.source.JvmDataSourceResolver
8-
import java.io.ByteArrayInputStream
9-
import java.util.zip.GZIPInputStream
10-
import java.io.File
3+
import sk.ainet.data.common.JvmDatasetSourceReader
114

125
/**
136
* JVM implementation of the MNIST loader.
147
*
158
* @property config The configuration for the MNIST loader.
169
*/
1710
public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) {
18-
private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources"))
11+
private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache)
1912

2013
/**
2114
* Resolves, caches, and decompresses a file when needed.
@@ -24,23 +17,8 @@ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi
2417
* @param filename The name of the file to save.
2518
* @return The bytes of the decompressed file.
2619
*/
27-
override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) {
28-
val artifact = resolver.resolve(
29-
DataSourceRequest(
30-
uri = url,
31-
cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh
32-
)
33-
)
34-
return@withContext maybeGunzip(artifact.readBytes())
35-
}
36-
37-
private fun maybeGunzip(bytes: ByteArray): ByteArray {
38-
if (!bytes.isGzip()) return bytes
39-
return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() }
40-
}
41-
42-
private fun ByteArray.isGzip(): Boolean {
43-
return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte()
20+
override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray {
21+
return sources.readGzipDecoded(url)
4422
}
4523

4624
public companion object {

skainet-data/skainet-data-source/build.gradle.kts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ kotlin {
2222

2323
commonTest.dependencies {
2424
implementation(libs.kotlin.test)
25+
implementation(libs.kotlinx.coroutines.test)
2526
}
2627

2728
jvmMain.dependencies {
@@ -31,8 +32,5 @@ kotlin {
3132
implementation(libs.kotlinx.coroutines.core.jvm)
3233
}
3334

34-
jvmTest.dependencies {
35-
implementation(libs.kotlinx.coroutines.test)
36-
}
3735
}
3836
}
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package sk.ainet.data.source
2+
3+
/**
4+
* Fetches a remote URI into memory. Kept injectable so tests and applications
5+
* can provide their own HTTP stack or policy layer.
6+
*/
7+
public fun interface RemoteDataSourceFetcher {
8+
public suspend fun fetch(uri: String, headers: Map<String, String>): ByteArray
9+
}
10+
11+
/**
12+
* Adds platform or application-specific headers to a resolved remote request.
13+
*/
14+
public fun interface DataSourceHeaderProvider {
15+
public fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map<String, String>
16+
}
17+
18+
/**
19+
* Computes checksums for integrity verification without tying resolver policy
20+
* to a concrete platform crypto API.
21+
*/
22+
public fun interface DataSourceChecksum {
23+
public fun sha256Hex(bytes: ByteArray): String
24+
}
25+
26+
/**
27+
* Platform storage adapter used by [DefaultDataSourceResolver].
28+
*/
29+
public interface DataSourceByteStore {
30+
public suspend fun readLocal(path: String): DataSourceStoredArtifact?
31+
public suspend fun readCache(cacheKey: String): DataSourceStoredArtifact?
32+
public suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact
33+
}
34+
35+
/**
36+
* A platform materialized artifact used by the common resolver core.
37+
*/
38+
public class DataSourceStoredArtifact(
39+
public val localPath: String?,
40+
public val sizeBytes: Long?,
41+
private val byteReader: suspend () -> ByteArray
42+
) {
43+
public suspend fun readBytes(): ByteArray = byteReader()
44+
}
45+
46+
/**
47+
* Platform-neutral resolver implementation for local files, HTTP(S), and
48+
* Hugging Face source URIs. Storage, network, auth, and checksum details are
49+
* injected so this policy can be reused by each KMP target.
50+
*/
51+
public class DefaultDataSourceResolver(
52+
private val store: DataSourceByteStore,
53+
private val fetcher: RemoteDataSourceFetcher,
54+
private val checksum: DataSourceChecksum,
55+
private val headerProvider: DataSourceHeaderProvider = DataSourceHeaderProvider { request, _ ->
56+
request.headers
57+
}
58+
) : DataSourceResolver {
59+
override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact {
60+
val parsed = DataSourceUriParser.parse(request.uri)
61+
return when (parsed.provider) {
62+
DataSourceProvider.File -> resolveFile(request, parsed)
63+
DataSourceProvider.Http, DataSourceProvider.HuggingFace -> resolveRemote(request, parsed)
64+
}
65+
}
66+
67+
private suspend fun resolveFile(
68+
request: DataSourceRequest,
69+
parsed: ParsedDataSourceUri
70+
): DataSourceArtifact {
71+
val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}")
72+
val stored = store.readLocal(path)
73+
?: throw DataSourceException("Data source file not found: $path")
74+
request.expectedSha256?.let { verifySha256(stored.readBytes(), it, request.uri) }
75+
return stored.toArtifact(request, parsed, cacheHit = true)
76+
}
77+
78+
private suspend fun resolveRemote(
79+
request: DataSourceRequest,
80+
parsed: ParsedDataSourceUri
81+
): DataSourceArtifact {
82+
val canUseCache = request.cachePolicy == CachePolicy.Use || request.cachePolicy == CachePolicy.Offline
83+
if (canUseCache) {
84+
val cached = store.readCache(parsed.cacheKey)
85+
if (cached != null) {
86+
request.expectedSha256?.let { verifySha256(cached.readBytes(), it, request.uri) }
87+
return cached.toArtifact(request, parsed, cacheHit = true)
88+
}
89+
}
90+
91+
if (request.cachePolicy == CachePolicy.Offline) {
92+
throw DataSourceException("No cached artifact available for offline source: ${request.uri}")
93+
}
94+
95+
val bytes = fetcher.fetch(parsed.transportUri, headerProvider.headers(request, parsed))
96+
request.expectedSha256?.let { verifySha256(bytes, it, request.uri) }
97+
98+
if (request.cachePolicy == CachePolicy.Bypass) {
99+
return DataSourceArtifact(
100+
request = request,
101+
parsedUri = parsed,
102+
filename = parsed.filename,
103+
localPath = null,
104+
sizeBytes = bytes.size.toLong(),
105+
cacheHit = false,
106+
byteReader = { bytes }
107+
)
108+
}
109+
110+
val stored = store.writeCache(parsed.cacheKey, bytes)
111+
return stored.toArtifact(request, parsed, cacheHit = false)
112+
}
113+
114+
private suspend fun DataSourceStoredArtifact.toArtifact(
115+
request: DataSourceRequest,
116+
parsed: ParsedDataSourceUri,
117+
cacheHit: Boolean
118+
): DataSourceArtifact {
119+
return DataSourceArtifact(
120+
request = request,
121+
parsedUri = parsed,
122+
filename = parsed.filename,
123+
localPath = localPath,
124+
sizeBytes = sizeBytes,
125+
cacheHit = cacheHit,
126+
byteReader = { readBytes() }
127+
)
128+
}
129+
130+
private fun verifySha256(bytes: ByteArray, expected: String, uri: String) {
131+
val actual = checksum.sha256Hex(bytes)
132+
if (!actual.equals(expected, ignoreCase = true)) {
133+
throw DataSourceException(
134+
"SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual"
135+
)
136+
}
137+
}
138+
}

0 commit comments

Comments
 (0)