Skip to content

Commit d1e0d51

Browse files
committed
Add KAN Sinus approximator UI.
1 parent ad5ba74 commit d1e0d51

6 files changed

Lines changed: 165 additions & 30 deletions

File tree

SinusApproximator/composeApp/src/commonMain/kotlin/sk/ai/net/samples/kmp/sinus/approximator/SinusSliderViewModel.kt

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import kotlinx.coroutines.flow.MutableStateFlow
1010
import kotlinx.coroutines.flow.StateFlow
1111
import kotlinx.coroutines.flow.asStateFlow
1212
import kotlinx.coroutines.launch
13-
import kotlinx.io.Source
13+
import sk.ainet.app.samples.sinus.KanSinusCalculator
1414
import sk.ainet.app.samples.sinus.MLPSinusCalculator
1515
import kotlin.math.abs
1616
import kotlin.math.pow
@@ -26,24 +26,40 @@ sealed interface ModelLoadingState {
2626

2727
class SinusSliderViewModel() : ViewModel() {
2828
private val calculator = MLPSinusCalculator()
29+
private val kanCalculator = KanSinusCalculator()
30+
2931
private val _modelLoadingState = MutableStateFlow<ModelLoadingState>(ModelLoadingState.Initial)
3032
val modelLoadingState: StateFlow<ModelLoadingState> = _modelLoadingState.asStateFlow()
3133

32-
// Expose the neural network model
33-
val neuralNetworkModel get() = calculator.model
34+
// Expose the neural network model (use KAN calculator model)
35+
val neuralNetworkModel get() = kanCalculator.model
3436

3537
var sliderValue by mutableStateOf(0f)
3638
private set
3739

3840
var sinusValue by mutableStateOf(0.0)
3941
private set
4042

43+
// For backward compatibility (kept but not used by UI anymore). Defaults to KAN value.
4144
var modelSinusValue by mutableStateOf(0.0f)
4245
private set
4346

47+
// Both models at once
48+
var modelSinusValueKan by mutableStateOf(0.0f)
49+
private set
50+
51+
var modelSinusValueMlp by mutableStateOf(0.0f)
52+
private set
53+
4454
var errorValue by mutableStateOf(0.0)
4555
private set
4656

57+
var errorValueKan by mutableStateOf(0.0)
58+
private set
59+
60+
var errorValueMlp by mutableStateOf(0.0)
61+
private set
62+
4763
// Formatted values for display
4864
var formattedAngle by mutableStateOf("0.0000")
4965
private set
@@ -57,6 +73,19 @@ class SinusSliderViewModel() : ViewModel() {
5773
var formattedErrorValue by mutableStateOf("0.000000")
5874
private set
5975

76+
// New formatted values for dual display
77+
var formattedModelSinusValueKan by mutableStateOf("0.000000")
78+
private set
79+
80+
var formattedModelSinusValueMlp by mutableStateOf("0.000000")
81+
private set
82+
83+
var formattedErrorValueKan by mutableStateOf("0.000000")
84+
private set
85+
86+
var formattedErrorValueMlp by mutableStateOf("0.000000")
87+
private set
88+
6089
private fun Double.formatDecimal(decimals: Int): String {
6190
val factor = 10.0.pow(decimals.toDouble())
6291
return (round(this * factor) / factor).toString()
@@ -71,23 +100,39 @@ class SinusSliderViewModel() : ViewModel() {
71100
formattedSinusValue = sinusValue.formatDecimal(6)
72101
formattedModelSinusValue = modelSinusValue.formatDecimal(6)
73102
formattedErrorValue = errorValue.formatDecimal(6)
103+
104+
// New ones
105+
formattedModelSinusValueKan = modelSinusValueKan.formatDecimal(6)
106+
formattedModelSinusValueMlp = modelSinusValueMlp.formatDecimal(6)
107+
formattedErrorValueKan = errorValueKan.formatDecimal(6)
108+
formattedErrorValueMlp = errorValueMlp.formatDecimal(6)
74109
}
75110

76111
fun updateSliderValue(value: Float) {
77112
sliderValue = value
78113
sinusValue = sin(value.toDouble())
79-
modelSinusValue = calculator.calculate(value.toFloat())
80-
errorValue = abs(sinusValue - modelSinusValue)
114+
// Compute both models
115+
modelSinusValueKan = kanCalculator.calculate(value)
116+
modelSinusValueMlp = calculator.calculate(value)
117+
118+
// Keep legacy fields aligned to KAN for compatibility
119+
modelSinusValue = modelSinusValueKan
120+
121+
// Errors
122+
errorValueKan = abs(sinusValue - modelSinusValueKan)
123+
errorValueMlp = abs(sinusValue - modelSinusValueMlp)
124+
// Legacy single error equals KAN error for now
125+
errorValue = errorValueKan
81126
updateFormattedValues()
82127
}
83128

84129
fun loadModel() {
85130
viewModelScope.launch {
86131
_modelLoadingState.value = ModelLoadingState.Loading
87132
try {
88-
launch(Dispatchers.Default) {
89-
calculator.loadModel()
90-
}.join()
133+
// Load both models (currently no-ops as they are preloaded with weights)
134+
kanCalculator.loadModel()
135+
calculator.loadModel()
91136
_modelLoadingState.value = ModelLoadingState.Success
92137
// Recalculate values after model is loaded
93138
updateSliderValue(sliderValue)

SinusApproximator/composeApp/src/commonMain/kotlin/sk/ai/net/samples/kmp/sinus/approximator/SinusVisualization.kt

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ import kotlin.math.sin
1717
fun SinusVisualization(
1818
sliderValue: Float,
1919
actualSinus: Double,
20-
approximatedSinus: Float,
20+
approximatedSinusKan: Float,
21+
approximatedSinusMlp: Float,
2122
modifier: Modifier = Modifier
2223
) {
2324
val primary = MaterialTheme.colorScheme.primary
2425
val secondary = MaterialTheme.colorScheme.secondary
26+
val tertiary = MaterialTheme.colorScheme.tertiary
2527
val error = MaterialTheme.colorScheme.error
2628

2729
Canvas(
@@ -56,20 +58,31 @@ fun SinusVisualization(
5658

5759
// Draw points for actual and approximated values
5860
val actualY = centerY - (actualSinus * centerY).toFloat()
59-
val approximatedY = centerY - (approximatedSinus * centerY).toFloat()
61+
val approximatedYKan = centerY - (approximatedSinusKan * centerY).toFloat()
62+
val approximatedYMlp = centerY - (approximatedSinusMlp * centerY).toFloat()
6063

61-
// Draw line between points to show error
64+
// Draw lines between points to show errors (use same colors as the dots)
6265
drawLine(
63-
error.copy(alpha = 0.5f),
66+
secondary.copy(alpha = 0.5f),
6467
Offset(x, actualY),
65-
Offset(x, approximatedY),
68+
Offset(x, approximatedYKan),
69+
strokeWidth = 2f,
70+
cap = StrokeCap.Round
71+
)
72+
drawLine(
73+
tertiary.copy(alpha = 0.5f),
74+
Offset(x, actualY),
75+
Offset(x, approximatedYMlp),
6676
strokeWidth = 2f,
6777
cap = StrokeCap.Round
6878
)
6979

7080
// Draw points
7181
drawCircle(primary, 6f, Offset(x, actualY))
72-
drawCircle(secondary, 6f, Offset(x, approximatedY))
82+
// KAN dot
83+
drawCircle(secondary, 6f, Offset(x, approximatedYKan))
84+
// MLP dot
85+
drawCircle(tertiary, 6f, Offset(x, approximatedYMlp))
7386
}
7487
}
7588

SinusApproximator/composeApp/src/commonMain/kotlin/sk/ai/net/samples/kmp/sinus/approximator/SliderScreen.kt

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ fun SinusSliderScreen() {
2929
SinusVisualization(
3030
sliderValue = viewModel.sliderValue,
3131
actualSinus = viewModel.sinusValue,
32-
approximatedSinus = viewModel.modelSinusValue,
32+
approximatedSinusKan = viewModel.modelSinusValueKan,
33+
approximatedSinusMlp = viewModel.modelSinusValueMlp,
3334
modifier = Modifier.padding(vertical = 16.dp)
3435
)
3536

@@ -39,8 +40,9 @@ fun SinusSliderScreen() {
3940
) {
4041
Column(
4142
modifier = Modifier.padding(16.dp),
42-
verticalArrangement = Arrangement.spacedBy(8.dp)
43+
verticalArrangement = Arrangement.spacedBy(12.dp)
4344
) {
45+
// Header values spanning full width
4446
Text(
4547
text = "Angle: ${viewModel.formattedAngle}",
4648
style = MaterialTheme.typography.titleSmall
@@ -50,16 +52,43 @@ fun SinusSliderScreen() {
5052
style = MaterialTheme.typography.bodyMedium,
5153
color = MaterialTheme.colorScheme.primary
5254
)
53-
Text(
54-
text = "Approximated sin: ${viewModel.formattedModelSinusValue}",
55-
style = MaterialTheme.typography.bodyMedium,
56-
color = MaterialTheme.colorScheme.secondary
57-
)
58-
Text(
59-
text = "Error: ${viewModel.formattedErrorValue}",
60-
style = MaterialTheme.typography.bodySmall,
61-
color = MaterialTheme.colorScheme.error
62-
)
55+
56+
// Two columns: KAN on the left, MLP on the right
57+
Row(
58+
modifier = Modifier.fillMaxWidth(),
59+
horizontalArrangement = Arrangement.spacedBy(16.dp)
60+
) {
61+
Column(
62+
modifier = Modifier.weight(1f),
63+
verticalArrangement = Arrangement.spacedBy(4.dp)
64+
) {
65+
Text(
66+
text = "KAN approximated sin: ${viewModel.formattedModelSinusValueKan}",
67+
style = MaterialTheme.typography.bodyMedium,
68+
color = MaterialTheme.colorScheme.secondary
69+
)
70+
Text(
71+
text = "KAN error: ${viewModel.formattedErrorValueKan}",
72+
style = MaterialTheme.typography.bodySmall,
73+
color = MaterialTheme.colorScheme.secondary
74+
)
75+
}
76+
Column(
77+
modifier = Modifier.weight(1f),
78+
verticalArrangement = Arrangement.spacedBy(4.dp)
79+
) {
80+
Text(
81+
text = "MLP approximated sin: ${viewModel.formattedModelSinusValueMlp}",
82+
style = MaterialTheme.typography.bodyMedium,
83+
color = MaterialTheme.colorScheme.tertiary
84+
)
85+
Text(
86+
text = "MLP error: ${viewModel.formattedErrorValueMlp}",
87+
style = MaterialTheme.typography.bodySmall,
88+
color = MaterialTheme.colorScheme.tertiary
89+
)
90+
}
91+
}
6392
}
6493
}
6594

SinusApproximator/gradle/libs.versions.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ kotlin = "2.2.21"
1919
kotlinx-coroutines = "1.10.2"
2020
ktor = "3.1.3"
2121
logback = "1.5.18"
22-
skainet = "0.2.0"
23-
kotlinxIo = "0.8.0"
22+
skainet = "0.3.0"
23+
kotlinxIo = "0.8.2"
2424

2525
[libraries]
2626
kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
@@ -34,10 +34,12 @@ kotlinx-io-bytestring = { module = "org.jetbrains.kotlinx:kotlinx-io-bytestring"
3434
# SKaiNET
3535
skainet-lang-core = { module = "sk.ainet.core:skainet-lang-core", version.ref = "skainet" }
3636
skainet-lang-models = { module = "sk.ainet.core:skainet-lang-models", version.ref = "skainet" }
37+
skainet-lang-kan = { module = "sk.ainet.core:skainet-lang-kan", version.ref = "skainet" }
3738
skainet-compile-core = { module = "sk.ainet.core:skainet-compile-core", version.ref = "skainet" }
3839
skainet-backend-cpu = { module = "sk.ainet.core:skainet-backend-cpu", version.ref = "skainet" }
3940
skainet-backend-cpu-jvm = { module = "sk.ainet.core:skainet-backend-cpu-jvm", version.ref = "skainet" }
40-
41+
skainet-data-api = { module = "sk.ainet.core:skainet-data-api", version.ref = "skainet" }
42+
skainet-data-simple = { module = "sk.ainet.core:skainet-data-basic", version.ref = "skainet" }
4143

4244
[plugins]
4345
androidApplication = { id = "com.android.application", version.ref = "agp" }

SinusApproximator/shared/build.gradle.kts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import org.gradle.kotlin.dsl.implementation
12
import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl
23

34
plugins {
@@ -36,6 +37,12 @@ kotlin {
3637
implementation(libs.skainet.lang.models)
3738
implementation(libs.skainet.compile.core)
3839
implementation(libs.skainet.backend.cpu)
40+
implementation(libs.skainet.lang.core)
41+
implementation(libs.skainet.lang.models)
42+
implementation(libs.skainet.lang.kan)
43+
implementation(libs.skainet.data.api)
44+
implementation(libs.skainet.data.simple)
45+
3946
}
4047

4148
wasmJsMain.dependencies {

SinusApproximator/shared/src/commonMain/kotlin/sk/ainet/app/samples/sinus/MLPSinusCalculator.kt

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import sk.ainet.lang.nn.network
1010
import sk.ainet.lang.tensor.dsl.tensor
1111
import sk.ainet.lang.tensor.relu
1212
import sk.ainet.lang.types.FP32
13+
import sk.ainet.lang.kan.examples.SineKanPretrained
1314

1415

1516
class SineNN(private val ctx: ExecutionContext) {
@@ -90,7 +91,7 @@ class SineNN(private val ctx: ExecutionContext) {
9091
}
9192
}
9293

93-
model.forward (inputTensor, ctx).data[0, 0]
94+
model.forward(inputTensor, ctx).data[0, 0]
9495
}
9596
}
9697

@@ -107,4 +108,42 @@ class MLPSinusCalculator() : SinusCalculator {
107108
}
108109
}
109110

111+
class KanSinusCalculator() : SinusCalculator {
112+
private val ctx = DirectCpuExecutionContext()
113+
114+
fun sk.ainet.lang.nn.Module<FP32, Float>.calcSine(ctx: ExecutionContext, angle: Float): Float {
115+
val model_: sk.ainet.lang.nn.Module<FP32, kotlin.Float> = this
116+
return computation<Float>(ctx) {
117+
// Create a simple input tensor compatible with the model's expected input size (1)
118+
model_.forward(
119+
data<FP32, Float>(ctx) {
120+
tensor<FP32, Float>() {
121+
// Using shape(1, 1) to represent a single scalar input in 2D form
122+
shape(1, 1) {
123+
fromArray(
124+
floatArrayOf(angle)
125+
)
126+
}
127+
}
128+
}, ctx
129+
).data[0, 0]
130+
}
131+
}
132+
133+
134+
val _model = SineKanPretrained.create(ctx)
135+
val model = _model
136+
137+
138+
override fun calculate(angle: Float): Float = _model.calcSine(ctx, angle)
139+
140+
override suspend fun loadModel() {
141+
// TODO model has pretrained weights as a part of model
142+
}
143+
}
144+
145+
/*
146+
147+
*/
148+
110149

0 commit comments

Comments
 (0)