Skip to content

Commit c6c9be3

Browse files
committed
data: parameterize Hugging Face auth
1 parent ed357d7 commit c6c9be3

14 files changed

Lines changed: 277 additions & 29 deletions

File tree

docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,20 @@ dependencies {
4141

4242
`JvmDataSourceResolver` materializes remote artifacts into a cache and returns
4343
a `DataSourceArtifact` that opens a `kotlinx.io.Source`. Public Hugging Face
44-
files do not need credentials. Private files can use an `Authorization` header,
45-
or the JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the
46-
environment when the URI provider is Hugging Face.
44+
files do not need credentials. For private files, pass an explicit
45+
`DataSourceAuthToken` on the request or resolver. Existing `Authorization`
46+
headers still take precedence over the token. On the JVM, the resolver can also read `HF_TOKEN` /
47+
`HUGGING_FACE_HUB_TOKEN` from the environment as an opt-in convenience fallback.
4748

4849
[source,kotlin]
4950
----
51+
import sk.ainet.data.source.DataSourceAuthToken
5052
import sk.ainet.data.source.DataSourceRequest
5153
import sk.ainet.data.source.JvmDataSourceResolver
5254
53-
val resolver = JvmDataSourceResolver()
55+
val resolver = JvmDataSourceResolver(
56+
huggingFaceToken = DataSourceAuthToken.from("hf_...")
57+
)
5458
val artifact = resolver.resolve(
5559
DataSourceRequest(
5660
uri = "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json"
@@ -71,6 +75,28 @@ try {
7175
val bytes = artifact.readBytes()
7276
----
7377

78+
For per-request credentials, pass the token directly on `DataSourceRequest`.
79+
This is useful when one resolver works with more than one private repository:
80+
81+
[source,kotlin]
82+
----
83+
val privateArtifact = resolver.resolve(
84+
DataSourceRequest(
85+
uri = "hf://datasets/your-org/private-dataset@main/data/train.bin",
86+
huggingFaceToken = DataSourceAuthToken.from("hf_...")
87+
)
88+
)
89+
----
90+
91+
To opt into the JVM environment-variable fallback:
92+
93+
[source,kotlin]
94+
----
95+
val resolver = JvmDataSourceResolver(
96+
useEnvironmentHuggingFaceToken = true
97+
)
98+
----
99+
74100
=== Use sources with built-in loaders
75101

76102
MNIST and Fashion-MNIST expose per-file URI overrides. CIFAR-10 exposes an
@@ -82,10 +108,12 @@ locations, so existing code keeps working.
82108
import sk.ainet.data.mnist.MNIST
83109
import sk.ainet.data.mnist.MNISTLoaderConfig
84110
111+
val token = "hf_..."
85112
val train = MNIST.loadTrain(
86113
MNISTLoaderConfig(
87114
trainImagesUri = "file:///datasets/mnist/train-images-idx3-ubyte",
88-
trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz"
115+
trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz",
116+
huggingFaceTokenProvider = { token }
89117
)
90118
)
91119

skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext
66
import sk.ainet.context.ExecutionContext
77
import sk.ainet.data.DataBatch
88
import sk.ainet.data.Dataset
9+
import sk.ainet.data.common.DatasetHuggingFaceTokenProvider
910
import sk.ainet.lang.tensor.Shape
1011
import sk.ainet.lang.tensor.Tensor
1112
import sk.ainet.lang.types.DType
@@ -145,7 +146,9 @@ public data class CIFAR10Dataset(
145146
public data class CIFAR10LoaderConfig(
146147
val cacheDir: String = "cifar10-data",
147148
val useCache: Boolean = true,
148-
val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL
149+
val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL,
150+
val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? = null,
151+
val useEnvironmentHuggingFaceToken: Boolean = false
149152
)
150153

151154
/**
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package sk.ainet.data.common
2+
3+
/**
4+
* Supplies a Hugging Face token for built-in dataset loaders when their source
5+
* URIs point at private Hugging Face artifacts.
*
* Declared as a `fun interface`, so a provider can be written as a lambda:
* `DatasetHuggingFaceTokenProvider { "hf_..." }`.
6+
*/
7+
public fun interface DatasetHuggingFaceTokenProvider {
8+
/** Returns the raw token string, or `null` to supply no token. */
public fun token(): String?
9+
}

skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext
66
import sk.ainet.context.ExecutionContext
77
import sk.ainet.data.DataBatch
88
import sk.ainet.data.Dataset
9+
import sk.ainet.data.common.DatasetHuggingFaceTokenProvider
910
import sk.ainet.lang.tensor.Shape
1011
import sk.ainet.lang.tensor.Tensor
1112
import sk.ainet.lang.types.DType
@@ -150,7 +151,9 @@ public data class FashionMNISTLoaderConfig(
150151
val trainImagesUri: String = FashionMNISTConstants.TRAIN_IMAGES_URL,
151152
val trainLabelsUri: String = FashionMNISTConstants.TRAIN_LABELS_URL,
152153
val testImagesUri: String = FashionMNISTConstants.TEST_IMAGES_URL,
153-
val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL
154+
val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL,
155+
val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? = null,
156+
val useEnvironmentHuggingFaceToken: Boolean = false
154157
)
155158

156159
/**

skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext
66
import sk.ainet.context.ExecutionContext
77
import sk.ainet.data.DataBatch
88
import sk.ainet.data.Dataset
9+
import sk.ainet.data.common.DatasetHuggingFaceTokenProvider
910
import sk.ainet.lang.tensor.Shape
1011
import sk.ainet.lang.tensor.Tensor
1112
import sk.ainet.lang.types.DType
@@ -128,7 +129,9 @@ public data class MNISTLoaderConfig(
128129
val trainImagesUri: String = MNISTConstants.TRAIN_IMAGES_URL,
129130
val trainLabelsUri: String = MNISTConstants.TRAIN_LABELS_URL,
130131
val testImagesUri: String = MNISTConstants.TEST_IMAGES_URL,
131-
val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL
132+
val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL,
133+
val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? = null,
134+
val useEnvironmentHuggingFaceToken: Boolean = false
132135
)
133136

134137
/**

skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@ import java.io.FileOutputStream
1515
* @property config The configuration for the CIFAR-10 loader.
1616
*/
1717
public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) {
18-
private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache)
18+
private val sources = JvmDatasetSourceReader(
19+
cacheDir = config.cacheDir,
20+
useCache = config.useCache,
21+
huggingFaceTokenProvider = config.huggingFaceTokenProvider,
22+
useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken
23+
)
1924

2025
/**
2126
* Downloads the CIFAR-10 archive and extracts the specified batch file.

skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package sk.ainet.data.common
22

33
import sk.ainet.data.source.CachePolicy
4+
import sk.ainet.data.source.DataSourceAuthToken
45
import sk.ainet.data.source.DataSourceRequest
56
import sk.ainet.data.source.JvmDataSourceResolver
67
import java.io.ByteArrayInputStream
@@ -9,16 +10,22 @@ import java.util.zip.GZIPInputStream
910

1011
internal class JvmDatasetSourceReader(
1112
cacheDir: String,
12-
useCache: Boolean
13+
useCache: Boolean,
14+
private val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? = null,
15+
useEnvironmentHuggingFaceToken: Boolean = false
1316
) {
14-
private val resolver = JvmDataSourceResolver(File(cacheDir, "sources"))
17+
private val resolver = JvmDataSourceResolver(
18+
cacheDir = File(cacheDir, "sources"),
19+
useEnvironmentHuggingFaceToken = useEnvironmentHuggingFaceToken
20+
)
1521
private val cachePolicy = if (useCache) CachePolicy.Use else CachePolicy.Refresh
1622

1723
suspend fun read(uri: String): ByteArray {
1824
val artifact = resolver.resolve(
1925
DataSourceRequest(
2026
uri = uri,
21-
cachePolicy = cachePolicy
27+
cachePolicy = cachePolicy,
28+
huggingFaceToken = DataSourceAuthToken.fromOrNull(huggingFaceTokenProvider?.token())
2229
)
2330
)
2431
return artifact.readBytes()

skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@ import sk.ainet.data.common.JvmDatasetSourceReader
88
* @property config The configuration for the Fashion-MNIST loader.
99
*/
1010
public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) {
11-
private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache)
11+
private val sources = JvmDatasetSourceReader(
12+
cacheDir = config.cacheDir,
13+
useCache = config.useCache,
14+
huggingFaceTokenProvider = config.huggingFaceTokenProvider,
15+
useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken
16+
)
1217

1318
/**
1419
* Resolves, caches, and decompresses a file when needed.

skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@ import sk.ainet.data.common.JvmDatasetSourceReader
88
* @property config The configuration for the MNIST loader.
99
*/
1010
public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) {
11-
private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache)
11+
private val sources = JvmDatasetSourceReader(
12+
cacheDir = config.cacheDir,
13+
useCache = config.useCache,
14+
huggingFaceTokenProvider = config.huggingFaceTokenProvider,
15+
useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken
16+
)
1217

1318
/**
1419
* Resolves, caches, and decompresses a file when needed.

skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,33 @@ public data class HuggingFaceLocation(
5050
public val path: String?
5151
)
5252

53+
/**
 * Authentication token for provider-specific data source requests.
 *
 * The raw value is intentionally hidden from [toString] output so tokens are
 * not leaked when requests or configs are logged.
 *
 * Instances compare by token value, so data classes that embed a token (such
 * as [DataSourceRequest]) keep structural equality when the same token string
 * is wrapped more than once.
 */
public class DataSourceAuthToken private constructor(
    private val value: String
) {
    /** Returns a fixed, masked representation; the raw token is never included. */
    override fun toString(): String = "DataSourceAuthToken(***)"

    /** Two tokens are equal when their normalized raw values are equal. */
    override fun equals(other: Any?): Boolean =
        this === other || (other is DataSourceAuthToken && value == other.value)

    /** Hash consistent with [equals]; derived from the normalized raw value. */
    override fun hashCode(): Int = value.hashCode()

    /** Formats the token as the value of an HTTP `Authorization` header. */
    internal fun authorizationHeaderValue(): String = "Bearer $value"

    public companion object {
        /**
         * Wraps [value] after trimming surrounding whitespace.
         *
         * @throws IllegalArgumentException if [value] is empty or whitespace only.
         */
        public fun from(value: String): DataSourceAuthToken =
            // Delegates to fromOrNull so trimming and the blank check live in one place.
            requireNotNull(fromOrNull(value)) { "Data source auth token cannot be blank" }

        /**
         * Wraps [value] after trimming surrounding whitespace, or returns `null`
         * when [value] is `null`, empty, or whitespace only.
         */
        public fun fromOrNull(value: String?): DataSourceAuthToken? {
            val normalized = value?.trim()?.takeIf { it.isNotEmpty() } ?: return null
            return DataSourceAuthToken(normalized)
        }
    }
}
79+
5380
/**
5481
* A normalized, provider-aware source URI.
5582
*/
@@ -70,7 +97,8 @@ public data class DataSourceRequest(
7097
public val uri: String,
7198
public val cachePolicy: CachePolicy = CachePolicy.Use,
7299
public val expectedSha256: String? = null,
73-
public val headers: Map<String, String> = emptyMap()
100+
public val headers: Map<String, String> = emptyMap(),
101+
public val huggingFaceToken: DataSourceAuthToken? = null
74102
)
75103

76104
/**

0 commit comments

Comments
 (0)