1+ package sk.ainet.data
2+
3+ import sk.ainet.core.tensor.DType
4+ import kotlin.math.min
5+
6+
7+ /* * Just abstract Dataset. */
8+ public abstract class Dataset <T , Y > {
9+ /* * Splits datasets on two sub-datasets according [splitRatio].*/
10+ public abstract fun split (splitRatio : Double ): Pair <Dataset <T , Y >, Dataset<T, Y>>
11+
12+ /* * Returns amount of data rows. */
13+ public abstract val xSize: Int
14+
15+ /* * Returns row by index [idx]. */
16+ public abstract fun getX (idx : Int ): T
17+
18+ /* * Returns label as [Int] by index [idx]. */
19+ public abstract fun getY (idx : Int ): Y
20+
21+ /* * Shuffles the dataset. */
22+ public abstract fun shuffle (): Dataset <T , Y >
23+
24+ /* *
25+ * An iterator over a [Dataset].
26+ */
27+ public inner class BatchIterator <T : DType , V > internal constructor(
28+ private val batchSize : Int
29+ ) : Iterator<DataBatch<T, V>> {
30+
31+ private var batchStart = 0
32+
33+ override fun hasNext (): Boolean = batchStart < xSize
34+
35+ override fun next (): DataBatch <T , V > {
36+ val batchLength = min(batchSize, xSize - batchStart)
37+ val batch = createDataBatch<T , V >(batchStart, batchLength)
38+ batchStart + = batchSize
39+ return batch
40+ }
41+ }
42+
43+ /* * Creates data batch that starts from [batchStart] with length [batchLength]. */
44+ protected abstract fun <T : DType , V > createDataBatch (batchStart : Int , batchLength : Int ): DataBatch <T , V >
45+
46+
47+ /* * Returns [BatchIterator] with fixed [batchSize]. */
48+ public fun <T : DType , V > batchIterator (batchSize : Int ): BatchIterator <T , V > {
49+ return BatchIterator (batchSize)
50+ }
51+ }
0 commit comments