Skip to content

Commit d2950c6

Browse files
Merge pull request #634 from SKaiNET-developers/feature/test-hygiene-pr2
Clarify ignored tests and enable BatchNorm coverage
2 parents 327a8fc + fafec12 commit d2950c6

5 files changed

Lines changed: 14 additions & 15 deletions

File tree

skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/lang/nn/normalization/BatchNormalizationTest.kt

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import kotlin.test.Test
44
import kotlin.test.assertEquals
55
import kotlin.test.assertFailsWith
66
import kotlin.test.assertNotNull
7-
import kotlin.test.Ignore
87
import sk.ainet.context.DirectCpuExecutionContext
8+
import sk.ainet.context.Phase
99
import sk.ainet.lang.tensor.Shape
1010
import sk.ainet.lang.types.FP32
1111
import sk.ainet.lang.tensor.Tensor
@@ -37,33 +37,32 @@ class BatchNormalizationTest {
3737
}
3838
}
3939

40-
@Ignore
4140
@Test
4241
fun train_then_eval_works_and_preserves_shape() {
43-
val exec = DirectCpuExecutionContext()
44-
val x = makeInput2x2(exec)
42+
val trainExec = DirectCpuExecutionContext(phase = Phase.TRAIN)
43+
val evalExec = DirectCpuExecutionContext(phase = Phase.EVAL)
44+
val x = makeInput2x2(trainExec)
4545
val bn = BatchNormalization<FP32, Float>(
4646
numFeatures = 2,
4747
affine = false,
4848
name = "bn"
4949
)
5050
// training pass initializes running stats
5151
bn.train()
52-
val yTrain = bn.forward(x, exec)
52+
val yTrain = bn.forward(x, trainExec)
5353
assertNotNull(yTrain)
5454
assertEquals(x.shape, yTrain.shape)
5555

5656
// eval should now work using running stats
5757
bn.eval()
58-
val yEval = bn.forward(x, exec)
58+
val yEval = bn.forward(x, evalExec)
5959
assertNotNull(yEval)
6060
assertEquals(x.shape, yEval.shape)
6161
}
6262

63-
@Ignore
6463
@Test
6564
fun simple_2x2_batch_is_normalized_per_channel() {
66-
val exec = DirectCpuExecutionContext()
65+
val exec = DirectCpuExecutionContext(phase = Phase.TRAIN)
6766
val x = makeInput2x2(exec)
6867
val bn = BatchNormalization<FP32, Float>(
6968
numFeatures = 2,

skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/DefaultExecutionContextTest.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import kotlin.test.Test
44
import kotlin.test.Ignore
55
import kotlin.test.assertTrue
66

7-
@Ignore
7+
@Ignore("GraphExecution DSL tests are parked until the API drift is resolved")
88
class GraphExecutionDSLTest {
99
@Test
1010
fun placeholder() {
@@ -298,4 +298,4 @@ class GraphExecutionDSLTest {
298298
299299
}
300300
*/
301-
}
301+
}

skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/MnistMplGraphvizTest.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import kotlin.test.assertTrue
1919
* val graph = model.toGraph()
2020
* graph.toGraphviz()
2121
*/
22-
@Ignore
22+
@Ignore("Graphviz snippet is parked until the MnistMpl graph API drift is resolved")
2323
class MnistMplGraphvizTest {
2424

2525

@@ -30,4 +30,4 @@ class MnistMplGraphvizTest {
3030
// Keeping body minimal to allow compilation when @Ignore handling differs across targets.
3131
assertTrue(true)
3232
}
33-
}
33+
}

skainet-compile/skainet-compile-dag/src/jvmTest/kotlin/sk/ainet/graph/TapeToGraphUnitTests.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import kotlin.test.Test
44
import kotlin.test.Ignore
55
import kotlin.test.assertTrue
66

7-
@Ignore
7+
@Ignore("Placeholder suite parked until the JVM-specific execution helper is migrated")
88
class TapeToGraphUnitTests {
99
@Test
1010
fun placeholder() {

skainet-io/skainet-io-onnx/src/jvmTest/kotlin/sk/ainet/io/onnx/OnnxResourceReadTest.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import kotlin.test.Ignore
1111

1212
class OnnxResourceReadTest {
1313

14-
@Ignore
14+
@Ignore("Requires run14.onnx test fixture, which is not checked into the repository")
1515
@Test
1616
fun `read run14 onnx from resources and build graph view`() {
1717
val inputStream: InputStream = requireNotNull(javaClass.getResourceAsStream("/run14.onnx")) {
@@ -38,7 +38,7 @@ class OnnxResourceReadTest {
3838
)
3939
}
4040

41-
@Ignore
41+
@Ignore("Requires run14.onnx test fixture, which is not checked into the repository")
4242
@Test
4343
fun `run14 onnx ops are covered by importer mapping`() {
4444
val bytes = loadResourceBytes("run14.onnx")

0 commit comments

Comments
 (0)