Skip to content

Commit c80d69e

Browse files
committed
Fix failing tests
Relate-To: #112, #113
1 parent 4bd5548 commit c80d69e

1 file changed

Lines changed: 27 additions & 18 deletions

File tree

  • skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist

skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ai/net/io/data/mnist/MNISTLoaderTest.kt renamed to skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1-
package sk.ai.net.io.data.mnist
1+
package sk.ainet.io.data.mnist
22

33
import kotlinx.coroutines.runBlocking
4+
import sk.ainet.data.mnist.MNISTConstants
5+
import sk.ainet.data.mnist.MNISTLoaderConfig
6+
import sk.ainet.data.mnist.MNISTLoaderFactory
47
import kotlin.test.Test
58
import kotlin.test.assertEquals
69
import kotlin.test.assertNotNull
710
import kotlin.test.assertTrue
811
import java.io.File
12+
import kotlin.io.path.ExperimentalPathApi
13+
import kotlin.io.path.absolute
14+
import kotlin.io.path.createTempDirectory
15+
import kotlin.io.path.deleteRecursively
916

1017
/**
1118
* Tests for the MNIST loader.
@@ -22,27 +29,27 @@ class MNISTLoaderTest {
2229
try {
2330
// Create a loader with the temporary directory
2431
val loader = MNISTLoaderFactory.create(tempDir.absolutePath)
25-
32+
2633
// Load the training data
2734
val dataset = loader.loadTrainingData()
28-
35+
2936
// Verify the dataset
3037
assertNotNull(dataset)
3138
assertTrue(dataset.size > 0)
3239
assertEquals(60000, dataset.size) // MNIST training set has 60,000 images
33-
40+
3441
// Verify the first image
3542
val firstImage = dataset.images[0]
3643
assertNotNull(firstImage)
3744
assertEquals(MNISTConstants.IMAGE_PIXELS, firstImage.image.size)
3845
assertTrue(firstImage.label >= 0 && firstImage.label <= 9)
39-
46+
4047
// Verify that the cache files were created
4148
val trainImagesFile = File(tempDir, MNISTConstants.TRAIN_IMAGES_FILENAME.removeSuffix(".gz"))
4249
val trainLabelsFile = File(tempDir, MNISTConstants.TRAIN_LABELS_FILENAME.removeSuffix(".gz"))
4350
assertTrue(trainImagesFile.exists())
4451
assertTrue(trainLabelsFile.exists())
45-
52+
4653
// Load the data again to test caching
4754
val cachedDataset = loader.loadTrainingData()
4855
assertEquals(dataset.size, cachedDataset.size)
@@ -55,40 +62,42 @@ class MNISTLoaderTest {
5562
/**
5663
* Tests loading the MNIST test dataset.
5764
*/
65+
@OptIn(ExperimentalPathApi::class)
5866
@Test
5967
fun testLoadTestData() = runBlocking {
6068
// Create a temporary directory for caching
61-
val tempDir = createTempDir("mnist-test")
69+
val tempDirPath = createTempDirectory("mnist-test")
70+
val tempDir = tempDirPath.absolute().toString()
6271
try {
6372
// Create a loader with the temporary directory
64-
val loader = MNISTLoaderFactory.create(tempDir.absolutePath)
65-
73+
val loader = MNISTLoaderFactory.create(tempDir)
74+
6675
// Load the test data
6776
val dataset = loader.loadTestData()
68-
77+
6978
// Verify the dataset
7079
assertNotNull(dataset)
7180
assertTrue(dataset.size > 0)
7281
assertEquals(10000, dataset.size) // MNIST test set has 10,000 images
73-
82+
7483
// Verify the first image
7584
val firstImage = dataset.images[0]
7685
assertNotNull(firstImage)
7786
assertEquals(MNISTConstants.IMAGE_PIXELS, firstImage.image.size)
7887
assertTrue(firstImage.label >= 0 && firstImage.label <= 9)
79-
88+
8089
// Verify that the cache files were created
8190
val testImagesFile = File(tempDir, MNISTConstants.TEST_IMAGES_FILENAME.removeSuffix(".gz"))
8291
val testLabelsFile = File(tempDir, MNISTConstants.TEST_LABELS_FILENAME.removeSuffix(".gz"))
8392
assertTrue(testImagesFile.exists())
8493
assertTrue(testLabelsFile.exists())
85-
94+
8695
// Load the data again to test caching
8796
val cachedDataset = loader.loadTestData()
8897
assertEquals(dataset.size, cachedDataset.size)
8998
} finally {
9099
// Clean up
91-
tempDir.deleteRecursively()
100+
tempDirPath.deleteRecursively()
92101
}
93102
}
94103

@@ -99,13 +108,13 @@ class MNISTLoaderTest {
99108
fun testDatasetSubset() = runBlocking {
100109
// Create a loader with the default configuration
101110
val loader = MNISTLoaderFactory.create()
102-
111+
103112
// Load the training data
104113
val dataset = loader.loadTrainingData()
105-
114+
106115
// Create a subset
107116
val subset = dataset.subset(0, 100)
108-
117+
109118
// Verify the subset
110119
assertEquals(100, subset.size)
111120
assertEquals(dataset.images[0], subset.images[0])
@@ -123,7 +132,7 @@ class MNISTLoaderTest {
123132
useCache = false
124133
)
125134
val loader = MNISTLoaderFactory.create(config)
126-
135+
127136
// Verify that the loader was created successfully
128137
assertNotNull(loader)
129138
}

0 commit comments

Comments
 (0)