Skip to content

Commit 6085ef5

Browse files
committed
Update to the last SKaiNET Release and improve UI.
1 parent 642b7ef commit 6085ef5

13 files changed

Lines changed: 413 additions & 125 deletions

File tree

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package sk.ainet.samples.kmp.sinus.approximator
2+
3+
actual val isWasmPlatform: Boolean = false

SinusApproximator/composeApp/src/commonMain/kotlin/sk/ainet/samples/kmp/sinus/approximator/App.kt

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,33 @@ import androidx.compose.ui.Alignment
77
import androidx.compose.ui.Modifier
88
import androidx.compose.ui.unit.dp
99
import sk.ainet.samples.kmp.sinus.approximator.ui.SKaiNETTheme
10+
import sk.ainet.samples.kmp.sinus.approximator.ui.ThemeController
1011

12+
@OptIn(ExperimentalMaterial3Api::class)
1113
@Composable
1214
fun App() {
1315
var selectedTab by remember { mutableStateOf(0) }
1416
val sliderViewModel = remember { SinusSliderViewModel() }
1517
val trainingViewModel = remember { SinusTrainingViewModel() }
1618

17-
SKaiNETTheme {
19+
// Create theme controller only for WASM platform
20+
val themeController = remember { if (isWasmPlatform) ThemeController() else null }
21+
22+
SKaiNETTheme(themeController = themeController) {
1823
Scaffold(
24+
topBar = {
25+
// Show theme toggle only on WASM
26+
if (themeController != null) {
27+
TopAppBar(
28+
title = { Text("Sinus Approximator") },
29+
actions = {
30+
IconButton(onClick = { themeController.toggleTheme() }) {
31+
Text(if (themeController.isDarkTheme) "☀️" else "🌙")
32+
}
33+
}
34+
)
35+
}
36+
},
1937
bottomBar = {
2038
NavigationBar {
2139
NavigationBarItem(
@@ -45,33 +63,63 @@ fun App() {
4563
1 -> SinusTrainingScreen(trainingViewModel)
4664
2 -> {
4765
val modelLoadingState by sliderViewModel.modelLoadingState.collectAsState()
48-
if (modelLoadingState == ModelLoadingState.Success) {
49-
Column(
50-
modifier = Modifier
51-
.fillMaxSize()
52-
.padding(16.dp),
53-
horizontalAlignment = Alignment.CenterHorizontally,
54-
verticalArrangement = Arrangement.spacedBy(16.dp)
55-
) {
56-
Text(
57-
text = "Model Visualization",
58-
style = MaterialTheme.typography.headlineMedium
59-
)
60-
Card(
61-
modifier = Modifier.fillMaxWidth(),
62-
) {
63-
NeuralNetworkVisualization(
64-
model = sliderViewModel.neuralNetworkModel,
65-
modifier = Modifier.padding(16.dp)
66+
Column(
67+
modifier = Modifier
68+
.fillMaxSize()
69+
.padding(16.dp),
70+
horizontalAlignment = Alignment.CenterHorizontally,
71+
verticalArrangement = Arrangement.spacedBy(16.dp)
72+
) {
73+
Text(
74+
text = "Model Visualization",
75+
style = MaterialTheme.typography.headlineMedium
76+
)
77+
78+
when (modelLoadingState) {
79+
ModelLoadingState.Initial -> {
80+
Spacer(modifier = Modifier.weight(1f))
81+
Button(
82+
onClick = { sliderViewModel.loadModel() }
83+
) {
84+
Text("Load Model")
85+
}
86+
Spacer(modifier = Modifier.weight(1f))
87+
}
88+
ModelLoadingState.Loading -> {
89+
Spacer(modifier = Modifier.weight(1f))
90+
CircularProgressIndicator()
91+
Text(
92+
text = "Loading model...",
93+
style = MaterialTheme.typography.bodyMedium,
94+
modifier = Modifier.padding(top = 8.dp)
6695
)
96+
Spacer(modifier = Modifier.weight(1f))
97+
}
98+
ModelLoadingState.Success -> {
99+
Card(
100+
modifier = Modifier.fillMaxWidth(),
101+
) {
102+
NeuralNetworkVisualization(
103+
model = sliderViewModel.neuralNetworkModel,
104+
modifier = Modifier.padding(16.dp)
105+
)
106+
}
107+
}
108+
is ModelLoadingState.Error -> {
109+
Spacer(modifier = Modifier.weight(1f))
110+
Text(
111+
text = "Error: ${(modelLoadingState as ModelLoadingState.Error).message}",
112+
style = MaterialTheme.typography.bodyMedium,
113+
color = MaterialTheme.colorScheme.error
114+
)
115+
Button(
116+
onClick = { sliderViewModel.loadModel() },
117+
modifier = Modifier.padding(top = 8.dp)
118+
) {
119+
Text("Retry")
120+
}
121+
Spacer(modifier = Modifier.weight(1f))
67122
}
68-
}
69-
} else {
70-
Box(
71-
modifier = Modifier.fillMaxSize(),
72-
contentAlignment = Alignment.Center
73-
) {
74-
Text("Please load the model in the Approximation tab first.")
75123
}
76124
}
77125
}

SinusApproximator/composeApp/src/commonMain/kotlin/sk/ainet/samples/kmp/sinus/approximator/LossVisualization.kt

Lines changed: 137 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,164 @@ import androidx.compose.ui.geometry.Offset
1010
import androidx.compose.ui.graphics.Color
1111
import androidx.compose.ui.graphics.Path
1212
import androidx.compose.ui.graphics.drawscope.Stroke
13+
import androidx.compose.ui.graphics.nativeCanvas
1314
import androidx.compose.ui.unit.dp
1415

1516
@Composable
1617
fun LossVisualization(
1718
lossHistory: List<Float>,
19+
totalEpochs: Int,
1820
modifier: Modifier = Modifier
1921
) {
2022
val color = MaterialTheme.colorScheme.error
23+
val axisColor = Color.Gray
24+
25+
// Padding for axis labels
26+
val leftPadding = 50f
27+
val bottomPadding = 25f
28+
val topPadding = 10f
29+
val rightPadding = 10f
2130

2231
Canvas(
2332
modifier = modifier
2433
.height(150.dp)
2534
.fillMaxWidth()
2635
) {
27-
if (lossHistory.size < 2) return@Canvas
36+
val totalWidth = size.width
37+
val totalHeight = size.height
38+
39+
// Plot area dimensions
40+
val plotWidth = totalWidth - leftPadding - rightPadding
41+
val plotHeight = totalHeight - topPadding - bottomPadding
42+
val plotBottom = totalHeight - bottomPadding
43+
val plotTop = topPadding
44+
val plotLeft = leftPadding
45+
val plotRight = totalWidth - rightPadding
46+
47+
// Draw Y-axis
48+
drawLine(
49+
axisColor,
50+
Offset(plotLeft, plotTop),
51+
Offset(plotLeft, plotBottom),
52+
strokeWidth = 1.5f
53+
)
54+
55+
// Draw X-axis
56+
drawLine(
57+
axisColor,
58+
Offset(plotLeft, plotBottom),
59+
Offset(plotRight, plotBottom),
60+
strokeWidth = 1.5f
61+
)
62+
63+
if (lossHistory.size < 2) {
64+
// Draw axis labels even with no data
65+
drawContext.canvas.nativeCanvas.apply {
66+
val textPaint = org.jetbrains.skia.Paint().apply {
67+
this.color = 0xFF808080.toInt()
68+
}
69+
val font = org.jetbrains.skia.Font().apply {
70+
size = 11f
71+
}
72+
// X-axis label
73+
drawString("0", plotLeft - 3f, plotBottom + 18f, font, textPaint)
74+
drawString("$totalEpochs", plotRight - 15f, plotBottom + 18f, font, textPaint)
75+
}
76+
return@Canvas
77+
}
2878

29-
val width = size.width
30-
val height = size.height
31-
3279
val maxLoss = lossHistory.maxOrNull() ?: 1f
3380
val minLoss = lossHistory.minOrNull() ?: 0f
3481
val range = (maxLoss - minLoss).coerceAtLeast(0.0001f)
3582

83+
// Helper functions for coordinate conversion
84+
fun valueToY(loss: Float): Float = plotTop + ((maxLoss - loss) / range * plotHeight)
85+
fun epochToX(epoch: Int): Float = plotLeft + (epoch.toFloat() / (totalEpochs - 1).coerceAtLeast(1) * plotWidth)
86+
87+
// Y-axis tick marks (min, mid, max loss)
88+
val midLoss = (maxLoss + minLoss) / 2
89+
val yTicks = listOf(minLoss, midLoss, maxLoss)
90+
for (tick in yTicks) {
91+
val y = valueToY(tick)
92+
// Tick mark
93+
drawLine(
94+
axisColor,
95+
Offset(plotLeft - 5f, y),
96+
Offset(plotLeft, y),
97+
strokeWidth = 1.5f
98+
)
99+
// Grid line (light)
100+
if (tick != minLoss) {
101+
drawLine(
102+
axisColor.copy(alpha = 0.2f),
103+
Offset(plotLeft, y),
104+
Offset(plotRight, y),
105+
strokeWidth = 1f
106+
)
107+
}
108+
}
109+
110+
// X-axis tick marks (0, mid, total epochs)
111+
val midEpoch = totalEpochs / 2
112+
val xTicks = listOf(0, midEpoch, totalEpochs)
113+
for (tick in xTicks) {
114+
val x = plotLeft + (tick.toFloat() / totalEpochs * plotWidth)
115+
// Tick mark
116+
drawLine(
117+
axisColor,
118+
Offset(x, plotBottom),
119+
Offset(x, plotBottom + 5f),
120+
strokeWidth = 1.5f
121+
)
122+
// Grid line (light) - only for middle
123+
if (tick == midEpoch) {
124+
drawLine(
125+
axisColor.copy(alpha = 0.2f),
126+
Offset(x, plotTop),
127+
Offset(x, plotBottom),
128+
strokeWidth = 1f
129+
)
130+
}
131+
}
132+
133+
// Draw axis labels
134+
drawContext.canvas.nativeCanvas.apply {
135+
val textPaint = org.jetbrains.skia.Paint().apply {
136+
this.color = 0xFF808080.toInt()
137+
}
138+
val font = org.jetbrains.skia.Font().apply {
139+
size = 11f
140+
}
141+
142+
// Format loss values (cross-platform)
143+
fun formatLoss(loss: Float): String {
144+
val str = loss.toString()
145+
val dotIndex = str.indexOf('.')
146+
return if (dotIndex == -1) {
147+
str
148+
} else {
149+
val decimals = if (loss < 0.01f) 4 else if (loss < 1f) 3 else 2
150+
val endIndex = minOf(dotIndex + decimals + 1, str.length)
151+
str.substring(0, endIndex)
152+
}
153+
}
154+
155+
// Y-axis labels
156+
drawString(formatLoss(maxLoss), 2f, valueToY(maxLoss) + 4f, font, textPaint)
157+
drawString(formatLoss(midLoss), 2f, valueToY(midLoss) + 4f, font, textPaint)
158+
drawString(formatLoss(minLoss), 2f, valueToY(minLoss) + 4f, font, textPaint)
159+
160+
// X-axis labels (epochs)
161+
drawString("0", plotLeft - 3f, plotBottom + 18f, font, textPaint)
162+
drawString("$midEpoch", plotLeft + (midEpoch.toFloat() / totalEpochs * plotWidth) - 10f, plotBottom + 18f, font, textPaint)
163+
drawString("$totalEpochs", plotRight - 15f, plotBottom + 18f, font, textPaint)
164+
}
165+
166+
// Draw loss curve
36167
val path = Path().apply {
37168
lossHistory.forEachIndexed { index, loss ->
38-
val x = index.toFloat() / (lossHistory.size - 1) * width
39-
val y = height - ((loss - minLoss) / range * height)
169+
val x = epochToX(index)
170+
val y = valueToY(loss)
40171
if (index == 0) {
41172
moveTo(x, y)
42173
} else {
@@ -50,13 +181,5 @@ fun LossVisualization(
50181
color = color,
51182
style = Stroke(width = 2.dp.toPx())
52183
)
53-
54-
// Draw baseline
55-
drawLine(
56-
color = Color.Gray.copy(alpha = 0.5f),
57-
start = Offset(0f, height),
58-
end = Offset(width, height),
59-
strokeWidth = 1f
60-
)
61184
}
62185
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package sk.ainet.samples.kmp.sinus.approximator
2+
3+
/**
4+
* Platform detection for conditional behavior.
5+
*/
6+
expect val isWasmPlatform: Boolean

SinusApproximator/composeApp/src/commonMain/kotlin/sk/ainet/samples/kmp/sinus/approximator/SinusTrainingScreen.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ fun SinusTrainingScreen(viewModel: SinusTrainingViewModel) {
6161
Text("Loss History", style = MaterialTheme.typography.titleSmall)
6262
LossVisualization(
6363
lossHistory = trainingState.lossHistory,
64-
modifier = Modifier.fillMaxWidth().height(100.dp)
64+
totalEpochs = trainingState.totalEpochs,
65+
modifier = Modifier.fillMaxWidth().height(150.dp)
6566
)
6667
}
6768
}

0 commit comments

Comments
 (0)