Skip to content

Commit 14a88b3

Browse files
Merge pull request #10 from SKaiNET-developers/feature/apertus
Feature/apertus
2 parents 943e059 + 736ff5d commit 14a88b3

20 files changed

Lines changed: 4079 additions & 0 deletions

File tree

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
plugins {
    kotlin("jvm")
    alias(libs.plugins.shadow)
    application
}

// Single source of truth for the CLI entry point. It is used by both the
// `application` plugin and the fat-jar manifest so the two cannot drift apart.
val cliMainClass = "sk.ainet.apps.kapertus.cli.MainKt"

// JVM flags shared by every Test and JavaExec task in this project:
// preview features plus the incubating `jdk.incubator.vector` module.
val vectorJvmArgs = listOf("--enable-preview", "--add-modules", "jdk.incubator.vector")

application {
    mainClass.set(cliMainClass)
}

dependencies {
    implementation(project(":llm-runtime:kapertus"))
}

// Fat jar. `configureEach` keeps the configuration lazy, matching the
// Test/JavaExec blocks below, instead of eagerly realizing the task.
tasks.withType<com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar>().configureEach {
    // Produces `kapertus-all.jar` (no version segment in the file name).
    archiveBaseName.set("kapertus")
    archiveClassifier.set("all")
    archiveVersion.set("")

    manifest {
        attributes(
            "Main-Class" to cliMainClass,
            "Add-Opens" to "java.base/jdk.internal.misc",
            "Multi-Release" to "true"
        )
    }

    // Merge META-INF/services descriptors from all bundled dependencies.
    mergeServiceFiles()
}

tasks.withType<Test>().configureEach {
    jvmArgs(vectorJvmArgs)
}

tasks.withType<JavaExec>().configureEach {
    jvmArgs(vectorJvmArgs)
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl
2+
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
3+
4+
plugins {
    alias(libs.plugins.kotlinMultiplatform)
    alias(libs.plugins.androidMultiplatformLibrary)
    alias(libs.plugins.vanniktech.mavenPublish)
    alias(libs.plugins.kover)
    alias(libs.plugins.binary.compatibility.validator)
}

kotlin {
    // ---- Android ----
    android {
        namespace = "sk.ainet.models.apertus"
        minSdk = libs.versions.android.minSdk.get().toInt()
        compileSdk = libs.versions.android.compileSdk.get().toInt()
        compilerOptions {
            jvmTarget.set(JvmTarget.JVM_11)
        }
    }

    // ---- Native targets ----
    iosArm64()
    iosSimulatorArm64()
    macosArm64()
    linuxX64()
    linuxArm64()

    // ---- JVM ----
    jvm()

    // ---- JavaScript / WebAssembly ----
    js { browser() }

    @OptIn(ExperimentalWasmDsl::class)
    wasmJs { browser() }

    @OptIn(ExperimentalWasmDsl::class)
    wasmWasi { nodejs() }

    sourceSets {
        // Dependencies shared by every target.
        commonMain.dependencies {
            implementation(libs.skainet.lang.core)
            implementation(libs.skainet.io.core)
            implementation(libs.skainet.io.gguf)
            implementation(libs.skainet.io.safetensors)
            implementation(libs.skainet.compile.core)
            implementation(project(":llm-core"))
            implementation(libs.kotlinx.io.core)
            implementation(libs.kotlinx.coroutines)
        }

        commonTest.dependencies {
            implementation(libs.kotlin.test)
            implementation(libs.skainet.backend.cpu)
        }

        // JVM-only test dependencies (JUnit and coroutine test utilities).
        getByName("jvmTest") {
            dependencies {
                implementation(libs.kotlin.test)
                implementation(libs.junit)
                implementation(libs.kotlinx.coroutines.test)
                implementation(libs.skainet.backend.cpu)
            }
        }
    }
}

// Test JVMs get a 6g heap, up to 12g of direct memory, preview features and
// the incubating `jdk.incubator.vector` module.
tasks.withType<Test>().configureEach {
    maxHeapSize = "6g"
    jvmArgs("--enable-preview", "--add-modules", "jdk.incubator.vector", "-XX:MaxDirectMemorySize=12g")
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package sk.ainet.models.apertus
2+
3+
import sk.ainet.lang.tensor.Tensor
4+
import sk.ainet.lang.types.DType
5+
6+
/**
7+
* Strategy interface for Apertus attention computation.
8+
*
9+
* Similar to LLaMA's AttentionBackend but receives Q/K after QK-norm has been applied.
10+
* Applies RoPE encoding, KV cache management, and GQA attention scoring.
11+
*
12+
* Contract:
13+
* - Input: q [1, dim], k [1, kvDim], v [1, kvDim], layerIdx, position
14+
* - Output: attention output [1, dim]
15+
*/
16+
public interface ApertusAttentionBackend<T : DType> {

    /**
     * Runs attention for a single token of layer [layerIdx] located at [position].
     *
     * The caller must already have applied QK-norm to [q] and [k]. The
     * implementation is responsible for RoPE, for storing [k] and [v] in the
     * KV cache, and for producing the attention-weighted output.
     */
    public fun attention(
        q: Tensor<T, Float>,
        k: Tensor<T, Float>,
        v: Tensor<T, Float>,
        layerIdx: Int,
        position: Int,
    ): Tensor<T, Float>

    /**
     * Runs attention for several tokens at once, the first of which sits at [startPos].
     *
     * Backends without batch support keep this default, which returns `null`;
     * the runtime then falls back to processing the tokens sequentially.
     */
    public fun batchAttention(
        q: Tensor<T, Float>,
        k: Tensor<T, Float>,
        v: Tensor<T, Float>,
        layerIdx: Int,
        startPos: Int,
    ): Tensor<T, Float>? {
        return null
    }

    /**
     * Clears the backend's internal state (KV caches, position tracking, etc.).
     */
    public fun reset()
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package sk.ainet.models.apertus
2+
3+
/**
4+
* Parses HuggingFace `config.json` into [ApertusModelMetadata].
5+
*
6+
* Uses lightweight manual JSON parsing to avoid external dependencies.
7+
*/
8+
public object ApertusConfigParser {

    /**
     * Parse a HuggingFace config.json string into ApertusModelMetadata.
     *
     * Required fields: hidden_size, num_hidden_layers, num_attention_heads,
     * intermediate_size, vocab_size.
     *
     * Optional fields and their fallbacks:
     * - num_key_value_heads: value of num_attention_heads
     * - max_position_embeddings: 2048
     * - head_dim: hidden_size / num_attention_heads
     * - model_type: "apertus"
     * - rope_theta: 12000000
     * - qk_norm: true when absent; otherwise enabled only for `true` or
     *   `per_head` (case-insensitive)
     * - hidden_act: "xielu"
     * - tie_word_embeddings: false
     *
     * @throws IllegalStateException if [json] is not a JSON object, contains an
     *   unterminated string literal, a required field is missing or not an
     *   integer, or head_dim has to be derived while num_attention_heads is 0.
     */
    public fun parse(json: String): ApertusModelMetadata {
        val map = parseJsonObject(json.trim())

        val hiddenSize = map.requireInt("hidden_size")
        val numLayers = map.requireInt("num_hidden_layers")
        val numHeads = map.requireInt("num_attention_heads")
        val numKvHeads = map.intOrNull("num_key_value_heads") ?: numHeads
        val intermediateSize = map.requireInt("intermediate_size")
        val vocabSize = map.requireInt("vocab_size")
        val contextLength = map.intOrNull("max_position_embeddings") ?: 2048
        val headDim = map.intOrNull("head_dim") ?: run {
            // Report a config problem instead of a bare ArithmeticException.
            check(numHeads != 0) {
                "config.json: 'num_attention_heads' must be non-zero to derive 'head_dim'"
            }
            hiddenSize / numHeads
        }
        val architecture = map.stringOrNull("model_type") ?: "apertus"
        val ropeTheta = map.floatOrNull("rope_theta") ?: 12000000f
        val qkNorm = map["qk_norm"]?.lowercase()?.let { it == "true" || it == "per_head" } ?: true
        val hiddenAct = map.stringOrNull("hidden_act") ?: "xielu"
        val tiedEmbeddings = map.isTrue("tie_word_embeddings")

        return ApertusModelMetadata(
            architecture = architecture,
            embeddingLength = hiddenSize,
            contextLength = contextLength,
            blockCount = numLayers,
            headCount = numHeads,
            kvHeadCount = numKvHeads,
            feedForwardLength = intermediateSize,
            ropeDimensionCount = headDim,
            vocabSize = vocabSize,
            ropeTheta = ropeTheta,
            qkNorm = qkNorm,
            hiddenAct = hiddenAct,
            tiedEmbeddings = tiedEmbeddings
        )
    }

    /**
     * Check if config.json indicates tied word embeddings.
     *
     * Returns `true` only when the top-level `tie_word_embeddings` value is
     * exactly `true`; a missing key or any other value yields `false`.
     *
     * @throws IllegalStateException if [json] is not a JSON object or contains
     *   an unterminated string literal.
     */
    public fun isTiedEmbeddings(json: String): Boolean =
        parseJsonObject(json.trim()).isTrue("tie_word_embeddings")

    // ========== Lightweight JSON parsing ==========

    /**
     * Splits the top level of a JSON object into a key -> raw value text map.
     *
     * Only the outermost object is split into entries. Nested objects and
     * arrays are stored as their raw text. String values lose their
     * surrounding quotes, but escape sequences are not decoded.
     */
    private fun parseJsonObject(json: String): Map<String, String> {
        if (!json.startsWith("{") || !json.endsWith("}")) {
            error("config.json: expected JSON object")
        }
        val content = json.substring(1, json.length - 1)
        val result = mutableMapOf<String, String>()

        var i = 0
        while (i < content.length) {
            while (i < content.length && content[i].isWhitespace()) i++
            if (i >= content.length) break

            // Anything that is not the start of a key is skipped.
            if (content[i] != '"') { i++; continue }
            val keyEnd = findStringEnd(content, i)
            val key = content.substring(i + 1, keyEnd)
            i = keyEnd + 1

            while (i < content.length && (content[i].isWhitespace() || content[i] == ':')) i++

            val valueStart = i
            i = skipValue(content, i)
            result[key] = content.substring(valueStart, i).trim().removeSurrounding("\"")

            while (i < content.length && (content[i].isWhitespace() || content[i] == ',')) i++
        }

        return result
    }

    /**
     * Returns the index of the quote that closes the string literal opened at
     * [start], stepping over backslash escapes.
     *
     * Fails with a parser error when the literal is never closed. Returning
     * the text length here used to push callers past the end of the text and
     * surface as a StringIndexOutOfBoundsException.
     */
    private fun findStringEnd(s: String, start: Int): Int {
        var i = start + 1
        while (i < s.length) {
            when (s[i]) {
                '"' -> return i
                '\\' -> i += 2
                else -> i++
            }
        }
        error("config.json: unterminated string literal")
    }

    /**
     * Returns the index just past the JSON value that begins at [start].
     */
    private fun skipValue(s: String, start: Int): Int {
        if (start >= s.length) return start
        return when (s[start]) {
            '"' -> findStringEnd(s, start) + 1
            '{' -> findMatching(s, start, '{', '}')
            '[' -> findMatching(s, start, '[', ']')
            else -> {
                // Number / boolean / null: runs up to the next delimiter.
                var i = start
                while (i < s.length && s[i] != ',' && s[i] != '}' && s[i] != ']') i++
                i
            }
        }
    }

    /**
     * Returns the index just past the [close] bracket matching the [open]
     * bracket at [start], ignoring brackets inside string literals. Returns
     * the text length when no matching bracket exists.
     */
    private fun findMatching(s: String, start: Int, open: Char, close: Char): Int {
        var depth = 0
        var i = start
        var inString = false
        while (i < s.length) {
            val c = s[i]
            when {
                inString -> {
                    if (c == '"') inString = false
                    else if (c == '\\') i++ // skip the escaped character
                }
                c == '"' -> inString = true
                c == open -> depth++
                c == close -> { depth--; if (depth == 0) return i + 1 }
            }
            i++
        }
        return s.length
    }

    private fun Map<String, String>.requireInt(key: String): Int =
        this[key]?.toIntOrNull() ?: error("config.json: missing or invalid '$key'")

    private fun Map<String, String>.intOrNull(key: String): Int? =
        this[key]?.toIntOrNull()

    private fun Map<String, String>.floatOrNull(key: String): Float? =
        this[key]?.toFloatOrNull()

    private fun Map<String, String>.stringOrNull(key: String): String? =
        this[key]?.takeIf { it.isNotBlank() }

    /** True only when the raw value stored under [key] is exactly `true`. */
    private fun Map<String, String>.isTrue(key: String): Boolean =
        this[key] == "true"
}

0 commit comments

Comments
 (0)