1- package sk.ai.net .io.data.mnist
1+ package sk.ainet .io.data.mnist
22
33import kotlinx.coroutines.runBlocking
4+ import sk.ainet.data.mnist.MNISTConstants
5+ import sk.ainet.data.mnist.MNISTLoaderConfig
6+ import sk.ainet.data.mnist.MNISTLoaderFactory
47import kotlin.test.Test
58import kotlin.test.assertEquals
69import kotlin.test.assertNotNull
710import kotlin.test.assertTrue
811import java.io.File
12+ import kotlin.io.path.ExperimentalPathApi
13+ import kotlin.io.path.absolute
14+ import kotlin.io.path.createTempDirectory
15+ import kotlin.io.path.deleteRecursively
916
1017/* *
1118 * Tests for the MNIST loader.
@@ -22,27 +29,27 @@ class MNISTLoaderTest {
2229 try {
2330 // Create a loader with the temporary directory
2431 val loader = MNISTLoaderFactory .create(tempDir.absolutePath)
25-
32+
2633 // Load the training data
2734 val dataset = loader.loadTrainingData()
28-
35+
2936 // Verify the dataset
3037 assertNotNull(dataset)
3138 assertTrue(dataset.size > 0 )
3239 assertEquals(60000 , dataset.size) // MNIST training set has 60,000 images
33-
40+
3441 // Verify the first image
3542 val firstImage = dataset.images[0 ]
3643 assertNotNull(firstImage)
3744 assertEquals(MNISTConstants .IMAGE_PIXELS , firstImage.image.size)
3845 assertTrue(firstImage.label >= 0 && firstImage.label <= 9 )
39-
46+
4047 // Verify that the cache files were created
4148 val trainImagesFile = File (tempDir, MNISTConstants .TRAIN_IMAGES_FILENAME .removeSuffix(" .gz" ))
4249 val trainLabelsFile = File (tempDir, MNISTConstants .TRAIN_LABELS_FILENAME .removeSuffix(" .gz" ))
4350 assertTrue(trainImagesFile.exists())
4451 assertTrue(trainLabelsFile.exists())
45-
52+
4653 // Load the data again to test caching
4754 val cachedDataset = loader.loadTrainingData()
4855 assertEquals(dataset.size, cachedDataset.size)
@@ -55,40 +62,42 @@ class MNISTLoaderTest {
5562 /* *
5663 * Tests loading the MNIST test dataset.
5764 */
65+ @OptIn(ExperimentalPathApi ::class )
5866 @Test
5967 fun testLoadTestData () = runBlocking {
6068 // Create a temporary directory for caching
61- val tempDir = createTempDir(" mnist-test" )
69+ val tempDirPath = createTempDirectory(" mnist-test" )
70+ val tempDir = tempDirPath.absolute().toString()
6271 try {
6372 // Create a loader with the temporary directory
64- val loader = MNISTLoaderFactory .create(tempDir.absolutePath )
65-
73+ val loader = MNISTLoaderFactory .create(tempDir)
74+
6675 // Load the test data
6776 val dataset = loader.loadTestData()
68-
77+
6978 // Verify the dataset
7079 assertNotNull(dataset)
7180 assertTrue(dataset.size > 0 )
7281 assertEquals(10000 , dataset.size) // MNIST test set has 10,000 images
73-
82+
7483 // Verify the first image
7584 val firstImage = dataset.images[0 ]
7685 assertNotNull(firstImage)
7786 assertEquals(MNISTConstants .IMAGE_PIXELS , firstImage.image.size)
7887 assertTrue(firstImage.label >= 0 && firstImage.label <= 9 )
79-
88+
8089 // Verify that the cache files were created
8190 val testImagesFile = File (tempDir, MNISTConstants .TEST_IMAGES_FILENAME .removeSuffix(" .gz" ))
8291 val testLabelsFile = File (tempDir, MNISTConstants .TEST_LABELS_FILENAME .removeSuffix(" .gz" ))
8392 assertTrue(testImagesFile.exists())
8493 assertTrue(testLabelsFile.exists())
85-
94+
8695 // Load the data again to test caching
8796 val cachedDataset = loader.loadTestData()
8897 assertEquals(dataset.size, cachedDataset.size)
8998 } finally {
9099 // Clean up
91- tempDir .deleteRecursively()
100+ tempDirPath .deleteRecursively()
92101 }
93102 }
94103
@@ -99,13 +108,13 @@ class MNISTLoaderTest {
99108 fun testDatasetSubset () = runBlocking {
100109 // Create a loader with the default configuration
101110 val loader = MNISTLoaderFactory .create()
102-
111+
103112 // Load the training data
104113 val dataset = loader.loadTrainingData()
105-
114+
106115 // Create a subset
107116 val subset = dataset.subset(0 , 100 )
108-
117+
109118 // Verify the subset
110119 assertEquals(100 , subset.size)
111120 assertEquals(dataset.images[0 ], subset.images[0 ])
@@ -123,7 +132,7 @@ class MNISTLoaderTest {
123132 useCache = false
124133 )
125134 val loader = MNISTLoaderFactory .create(config)
126-
135+
127136 // Verify that the loader was created successfully
128137 assertNotNull(loader)
129138 }
0 commit comments