Skip to content

Commit 2c3e5fa

Browse files
committed
Add kotlin translated with InteliJ from JLama's java
1 parent 22bacec commit 2c3e5fa

13 files changed

Lines changed: 2146 additions & 0 deletions

File tree

safetensors/build.gradle.kts

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi
2+
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
3+
4+
plugins {
5+
alias(libs.plugins.kotlinMultiplatform)
6+
alias(libs.plugins.androidLibrary)
7+
alias(libs.plugins.vanniktech.mavenPublish)
8+
}
9+
10+
kotlin {
11+
jvm()
12+
androidTarget {
13+
publishLibraryVariants("release")
14+
@OptIn(ExperimentalKotlinGradlePluginApi::class)
15+
compilerOptions {
16+
jvmTarget.set(JvmTarget.JVM_1_8)
17+
}
18+
}
19+
iosX64()
20+
iosArm64()
21+
iosSimulatorArm64()
22+
wasmJs().nodejs()
23+
macosX64 ()
24+
linuxX64 ()
25+
26+
27+
sourceSets {
28+
val commonMain by getting {
29+
dependencies {
30+
implementation(libs.kotlinx.io.core)
31+
}
32+
}
33+
val commonTest by getting {
34+
dependencies {
35+
implementation(libs.kotlin.test)
36+
}
37+
}
38+
}
39+
}
40+
41+
android {
42+
namespace = "sk.ai.net.core"
43+
compileSdk = libs.versions.android.compileSdk.get().toInt()
44+
defaultConfig {
45+
minSdk = libs.versions.android.minSdk.get().toInt()
46+
}
47+
}
48+
49+
publishing {
50+
repositories {
51+
maven {
52+
name = "githubPackages"
53+
url = uri("https://maven.pkg.github.com/sk-ai-net/skainet")
54+
credentials {
55+
credentials(PasswordCredentials::class)
56+
}
57+
}
58+
}
59+
}
60+
61+
mavenPublishing {
62+
63+
coordinates(group.toString(), "gguf", version.toString())
64+
65+
pom {
66+
description.set("skainet")
67+
name.set(project.name)
68+
url.set("https://github.com/sk-ai-net/skainet/")
69+
licenses {
70+
license {
71+
name.set("MIT")
72+
distribution.set("repo")
73+
}
74+
}
75+
scm {
76+
url.set("https://github.com/sk-ai-net/skainet/")
77+
connection.set("scm:git:git@github.com:sk-ai-net/skainet.git")
78+
developerConnection.set("scm:git:ssh://git@github.com:sk-ai-net/skainet.git")
79+
}
80+
developers {
81+
developer {
82+
id.set("sk-ai-net")
83+
name.set("sk-ai-net")
84+
}
85+
}
86+
}
87+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package sk.ai.net.safetensors
2+
3+
interface BiMap<K : Any, V : Any> : Map<K, V> {
4+
override val values: Set<V>
5+
val inverse: BiMap<V, K>
6+
}
7+
8+
interface MutableBiMap<K : Any, V : Any> : BiMap<K, V>, MutableMap<K, V> {
9+
override val values: MutableSet<V>
10+
override val inverse: MutableBiMap<V, K>
11+
12+
fun forcePut(key: K, value: V): V?
13+
}
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
/*
2+
* Copyright 2024 T Jake Luciani
3+
*
4+
* The Jlama Project licenses this file to you under the Apache License,
5+
* version 2.0 (the "License"); you may not use this file except in compliance
6+
* with the License. You may obtain a copy of the License at:
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations
14+
* under the License.
15+
*/
16+
package sk.ai.net.safetensors
17+
18+
import com.github.tjake.jlama.math.ActivationFunction
19+
import kotlin.concurrent.Volatile
20+
21+
class Config(
22+
val contextLength: Int,
23+
val embeddingLength: Int,
24+
val hiddenLength: Int,
25+
val numberOfHeads: Int,
26+
val numberOfKeyValueHeads: Int,
27+
val numberOfLayers: Int,
28+
val layerNormEps: Float,
29+
val vocabularySize: Int,
30+
val bosToken: Int,
31+
eosTokens: List<Int>,
32+
//activationFunction: ActivationFunction.Type?,
33+
ropeFreqsTheta: Double?,
34+
ropeScalingFactor: Double?,
35+
classifcationLabels: Map<String, Int>,
36+
headSize: Int,
37+
finalLogitSoftCapping: Float?,
38+
attnLogitSoftCapping: Float?,
39+
residualMultiplier: Float?,
40+
attentionMultiplier: Float?,
41+
embeddingMultiplier: Float?,
42+
logitMultiplier: Float?
43+
) {
44+
val attentionLength: Int
45+
val headSize: Int
46+
val activationFunction: ActivationFunction.Type?
47+
val headGroupSize: Int
48+
val kvLength: Int
49+
val isGQA: Boolean
50+
val finalLogitSoftCapping: Float?
51+
val attnLogitSoftCapping: Float?
52+
val residualMultiplier: Float?
53+
val attentionMultiplier: Float?
54+
val embeddingMultiplier: Float?
55+
val logitMultiplier: Float?
56+
val eosTokens: List<Int?>?
57+
val classifcationLabels: BiMap<String, Int>
58+
59+
60+
61+
constructor(
62+
contextLength: Int,
63+
embeddingLength: Int,
64+
hiddenLength: Int,
65+
numberOfHeads: Int,
66+
numberOfKeyValueHeads: Int,
67+
numberOfLayers: Int,
68+
layerNormEps: Float,
69+
vocabularySize: Int,
70+
bosToken: Int,
71+
eosToken: List<Integer?>?,
72+
activationFunction: ActivationFunction.Type?,
73+
ropeFreqsTheta: Double?,
74+
ropeScalingFactor: Double?,
75+
headSize: Integer?,
76+
attnLogitSoftCapping: Float?,
77+
finalLogitSoftCapping: Float?
78+
) : this(
79+
contextLength,
80+
embeddingLength,
81+
hiddenLength,
82+
numberOfHeads,
83+
numberOfKeyValueHeads,
84+
numberOfLayers,
85+
layerNormEps,
86+
vocabularySize,
87+
bosToken,
88+
eosToken,
89+
activationFunction,
90+
ropeFreqsTheta,
91+
ropeScalingFactor,
92+
null,
93+
if (headSize == null) embeddingLength / numberOfHeads else headSize,
94+
attnLogitSoftCapping,
95+
finalLogitSoftCapping,
96+
null,
97+
null,
98+
null,
99+
null
100+
)
101+
102+
constructor(
103+
contextLength: Int,
104+
embeddingLength: Int,
105+
hiddenLength: Int,
106+
numberOfHeads: Int,
107+
numberOfKeyValueHeads: Int,
108+
numberOfLayers: Int,
109+
layerNormEps: Float,
110+
vocabularySize: Int,
111+
bosToken: Int,
112+
eosToken: List<Int>,
113+
//activationFunction: ActivationFunction.Type?,
114+
ropeFreqsTheta: Double?,
115+
ropeScalingFactor: Double?
116+
) : this(
117+
contextLength,
118+
embeddingLength,
119+
hiddenLength,
120+
numberOfHeads,
121+
numberOfKeyValueHeads,
122+
numberOfLayers,
123+
layerNormEps,
124+
vocabularySize,
125+
bosToken,
126+
eosToken,
127+
//activationFunction,
128+
ropeFreqsTheta,
129+
ropeScalingFactor,
130+
null,
131+
embeddingLength / numberOfHeads,
132+
null,
133+
null,
134+
null,
135+
null,
136+
null,
137+
null
138+
)
139+
140+
constructor(
141+
contextLength: Int,
142+
embeddingLength: Int,
143+
hiddenLength: Int,
144+
numberOfHeads: Int,
145+
numberOfKeyValueHeads: Int,
146+
numberOfLayers: Int,
147+
layerNormEps: Float,
148+
vocabularySize: Int,
149+
bosToken: Int,
150+
eosToken: List<Int>,
151+
//activationFunction: ActivationFunction.Type?,
152+
ropeFreqsTheta: Double?,
153+
ropeScalingFactor: Double?,
154+
residualMultiplier: Float?,
155+
attentionMultiplier: Float?,
156+
embeddingMultiplier: Float?,
157+
logitMultiplier: Float?
158+
) : this(
159+
contextLength,
160+
embeddingLength,
161+
hiddenLength,
162+
numberOfHeads,
163+
numberOfKeyValueHeads,
164+
numberOfLayers,
165+
layerNormEps,
166+
vocabularySize,
167+
bosToken,
168+
eosToken,
169+
activationFunction,
170+
ropeFreqsTheta,
171+
ropeScalingFactor,
172+
null,
173+
embeddingLength / numberOfHeads,
174+
null,
175+
null,
176+
residualMultiplier,
177+
attentionMultiplier,
178+
embeddingMultiplier,
179+
logitMultiplier
180+
)
181+
182+
constructor(
183+
contextLength: Int,
184+
embeddingLength: Int,
185+
hiddenLength: Int,
186+
numberOfHeads: Int,
187+
numberOfKeyValueHeads: Int,
188+
numberOfLayers: Int,
189+
layerNormEps: Float,
190+
vocabularySize: Int,
191+
bosToken: Int,
192+
eosToken: List<Int>,
193+
//activationFunction: ActivationFunction.Type?,
194+
ropeFreqsTheta: Double?,
195+
ropeScalingFactor: Double?,
196+
classifcationLabels: Map<String, Int>
197+
) : this(
198+
contextLength,
199+
embeddingLength,
200+
hiddenLength,
201+
numberOfHeads,
202+
numberOfKeyValueHeads,
203+
numberOfLayers,
204+
layerNormEps,
205+
vocabularySize,
206+
bosToken,
207+
eosToken,
208+
activationFunction,
209+
ropeFreqsTheta,
210+
ropeScalingFactor,
211+
classifcationLabels,
212+
embeddingLength / numberOfHeads,
213+
null,
214+
null,
215+
null,
216+
null,
217+
null,
218+
null
219+
)
220+
221+
init {
222+
this.attentionLength = numberOfHeads * headSize
223+
this.eosTokens = eosTokens
224+
this.headSize = headSize
225+
this.headGroupSize = numberOfHeads / numberOfKeyValueHeads
226+
this.kvLength = numberOfKeyValueHeads * headSize
227+
this.isGQA = numberOfKeyValueHeads < numberOfHeads
228+
//this.activationFunction = activationFunction
229+
230+
this.classifcationLabels = classifcationLabels
231+
// if (classifcationLabels == null) Optional.empty() else Optional.of(ImmutableBiMap.copyOf(classifcationLabels))
232+
233+
this.finalLogitSoftCapping = finalLogitSoftCapping
234+
this.attnLogitSoftCapping = attnLogitSoftCapping
235+
this.residualMultiplier = residualMultiplier
236+
this.attentionMultiplier = attentionMultiplier
237+
this.embeddingMultiplier = embeddingMultiplier
238+
this.logitMultiplier = logitMultiplier
239+
240+
// Set default values
241+
this.dctx = DistributedContext.builder(this).build()
242+
}
243+
244+
fun setDistributedContext(dctx: DistributedContext?) {
245+
this.dctx = dctx
246+
}
247+
248+
fun setWorkingDirectory(workingDirectory: File?) {
249+
if (workingDirectory == null) {
250+
this.workingDirectory = Files.createTempDir()
251+
this.workingDirectory.deleteOnExit()
252+
} else {
253+
Preconditions.checkArgument(workingDirectory.isDirectory())
254+
this.workingDirectory = workingDirectory
255+
}
256+
}
257+
258+
fun workingDirectory(): Optional<File?> {
259+
return Optional.ofNullable(this.workingDirectory)
260+
}
261+
262+
fun dctx(): DistributedContext? {
263+
return dctx
264+
}
265+
266+
fun maybeMapToGroupHead(head: Int): Int {
267+
if (!isGQA) return head
268+
return Math.floorDiv(head, headGroupSize)
269+
}
270+
271+
val isClassifier: Boolean
272+
get() = classifcationLabels.isPresent()
273+
}

0 commit comments

Comments
 (0)