Skip to content

Commit 6c548d7

Browse files
Merge pull request #185 from SKaiNET-developers/feature/transformer-core
Extract transformer-core: NN primitives reusable on all targets (incl. androidNative)
2 parents b5c3fe1 + 5baae89 commit 6c548d7

20 files changed

Lines changed: 99 additions & 3 deletions

File tree

llm-core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ kotlin {
4848
// versions. Bumping the engine is then a one-line change at the
4949
// top of `gradle/libs.versions.toml`.
5050
implementation(project.dependencies.platform(project(":llm-bom")))
51+
api(project(":transformer-core"))
5152
implementation(libs.skainet.lang.core)
5253
implementation(libs.skainet.compile.dag)
5354
implementation(libs.skainet.compile.opt)

llm-core/src/commonMain/kotlin/sk/ainet/apps/llm/HybridTransformerBlock.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,11 @@ public class HybridTransformerBlock<T : DType, V>(
168168
// the same name — so MHA can't gate its own dump on the block id.
169169
// Toggle the static flag from here, where we know which block we're in.
170170
val isMhaCall = dumpMha && module is MultiHeadAttention<*, *>
171-
if (isMhaCall) sk.ainet.lang.nn.transformer.MultiHeadAttentionDiag.shouldDumpThisCall = true
171+
if (isMhaCall) {
172+
// wire transformer-core's MHA diagnostic sink to llm-core's platform dumpStats (idempotent)
173+
sk.ainet.lang.nn.transformer.mhaStatSink = { l, t -> sk.ainet.apps.llm.diag.dumpStats(l, t) }
174+
sk.ainet.lang.nn.transformer.MultiHeadAttentionDiag.shouldDumpThisCall = true
175+
}
172176
tmp = module.forward(tmp, ctx)
173177
if (isMhaCall) sk.ainet.lang.nn.transformer.MultiHeadAttentionDiag.shouldDumpThisCall = false
174178
outputs[i + 1] = tmp

settings.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ if (providers.gradleProperty("useLocalSkainet").orNull == "true") {
2424
rootProject.name = "SKaiNET-transformers"
2525

2626
include("llm-api")
27+
include("transformer-core")
2728
include("llm-core")
2829
include("llm-agent")
2930
include("llm-providers")

transformer-core/README.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# transformer-core
2+
3+
Framework NN primitives — attention, the KV-cache family, embedding, norms, RoPE, SwiGLU/GeGLU FFN,
4+
residual, linear projection — extracted from `llm-core` so they build on the **full Kotlin target matrix
5+
including `androidNativeArm32/Arm64`** (the on-device ARM path). Depends only on `skainet-lang-core`
6+
(which has androidNative); no io/compile/backend deps.
7+
8+
`llm-core` `api`-depends on this module and **re-exports** it, so existing consumers are unaffected.
9+
ARM-native consumers (e.g. `skainet-whisper-kmp`) depend on `transformer-core` directly and reuse
10+
KV-cache/attention instead of reimplementing.
11+
12+
## Why
13+
`llm-core`'s primitives only need `lang-core`, but were trapped there: `llm-core`'s *other* deps
14+
(`io-gguf`, `io-core`, `compile-*`, `backend-cpu`) lack androidNative, so ARM-native consumers couldn't
15+
depend on it. The primitives are **dtype-agnostic** (just call `ops.*`), so this target generalization is
16+
orthogonal to the quant/dtype generalization (issue #178) — they meet cleanly at these primitives.
17+
18+
## What moved (15 files, lang-core-only)
19+
`transformer/*` (KVCache, RoPE, ResidualAdd, MultiHeadAttention, GeGLUFFN, SwiGLUFFN, XIELUActivation,
20+
LayerScalarMul, LinearProjection, VoidDense), `layers/*` (Embedding*), `normalization/RMSNormalization`,
21+
`dsl/TransformerDsl`. **Kept in `llm-core`:** `dsl/decoder/*` (DecoderTransformerNetwork needs
22+
`apps.llm.HybridTransformerBlock`, which is compile-opt-coupled).
23+
24+
One back-reference decoupled: `MultiHeadAttention`'s diagnostic `dumpStats` → a settable `mhaStatSink`
25+
(default no-op) that `HybridTransformerBlock` wires to llm-core's platform `dumpStats` (no behaviour lost).
26+
27+
## Verified
28+
`:transformer-core:` compiles for jvm + androidNativeArm32 + androidNativeArm64; `:llm-core:jvmTest` green (5/5) via
29+
the re-export.
30+
31+
## Landing (for the maintainer)
32+
Branch `feature/transformer-core` was cut from `release/0.31.0`. To land on `develop` (which has #178's
33+
merged work from #179/#180):
34+
1. `git fetch origin && git rebase origin/develop` — **no conflicts expected on the moved files**: #178's
35+
merged work is in the model layer (`GemmaPackedWeights`) + engine (`ops.transpose` Q8_0/Q4_0), not these
36+
primitives. (Verified against local refs; re-check against fresh `develop`.)
37+
2. Build the full target matrix + `:llm-core:` tests; PR; CI-publish; bump the `skainet`/transformers pins.
38+
3. **Note for future quant work:** the pre-transpose-marker (#178 "Solution C") will land in
39+
`LinearProjection.kt`, which now lives **here**, not `llm-core`. And `RowDequantSource` + packed-weight
40+
packing (today in `sk.ainet.models.gemma`) are the next candidates to hoist into a shared `quant` layer
41+
or this module — that's what makes quant reusable across models *and* whisper.

transformer-core/build.gradle.kts

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl
2+
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
3+
4+
plugins {
5+
alias(libs.plugins.kotlinMultiplatform)
6+
alias(libs.plugins.androidMultiplatformLibrary)
7+
alias(libs.plugins.vanniktech.mavenPublish)
8+
}
9+
10+
// Framework NN primitives (attention, KV-cache family, embedding, norms, RoPE, FFNs) extracted from
11+
// llm-core so they build on the FULL target matrix — including androidNative (the 32-bit box + phones).
12+
// Depends ONLY on skainet-lang-core (which has androidNative); no io/compile/backend deps. llm-core
13+
// re-exports this module, so existing consumers are unaffected; ARM-native consumers depend on it directly.
14+
kotlin {
15+
android {
16+
namespace = "sk.ainet.lang.nn"
17+
compileSdk = libs.versions.android.compileSdk.get().toInt()
18+
minSdk = libs.versions.android.minSdk.get().toInt()
19+
compilerOptions { jvmTarget.set(JvmTarget.JVM_11) }
20+
}
21+
22+
jvm()
23+
androidNativeArm32()
24+
androidNativeArm64()
25+
iosArm64()
26+
iosSimulatorArm64()
27+
linuxX64()
28+
linuxArm64()
29+
macosArm64()
30+
js { browser() }
31+
@OptIn(ExperimentalWasmDsl::class) wasmJs { browser() }
32+
@OptIn(ExperimentalWasmDsl::class) wasmWasi { nodejs() }
33+
34+
sourceSets {
35+
commonMain.dependencies {
36+
implementation(project.dependencies.platform(project(":llm-bom")))
37+
api(libs.skainet.lang.core) // public API is lang-core-typed (Tensor/Module/ExecutionContext)
38+
}
39+
commonTest.dependencies {
40+
implementation(libs.kotlin.test)
41+
}
42+
}
43+
}

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/TransformerDsl.kt renamed to transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/TransformerDsl.kt

File renamed without changes.

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/layers/Embedding.kt renamed to transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/layers/Embedding.kt

File renamed without changes.

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/layers/EmbeddingAdapter.kt renamed to transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/layers/EmbeddingAdapter.kt

File renamed without changes.

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/layers/EmbeddingParams.kt renamed to transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/layers/EmbeddingParams.kt

File renamed without changes.

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/RMSNormalization.kt renamed to transformer-core/src/commonMain/kotlin/sk/ainet/lang/nn/normalization/RMSNormalization.kt

File renamed without changes.

0 commit comments

Comments
 (0)