Skip to content

Commit e457975

Browse files
committed
Add initial implementation for datasets
Relate-To: #112, #113
1 parent df90721 commit e457975

8 files changed

Lines changed: 951 additions & 0 deletions

File tree

dataset_arch.md

Lines changed: 525 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi
2+
import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl
3+
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
4+
5+
plugins {
6+
alias(libs.plugins.kotlinMultiplatform)
7+
alias(libs.plugins.androidLibrary)
8+
alias(libs.plugins.vanniktech.mavenPublish)
9+
}
10+
11+
kotlin {
12+
explicitApi()
13+
14+
androidTarget {
15+
@OptIn(ExperimentalKotlinGradlePluginApi::class)
16+
compilerOptions {
17+
jvmTarget.set(JvmTarget.JVM_11)
18+
}
19+
}
20+
21+
iosArm64()
22+
iosSimulatorArm64()
23+
macosArm64 ()
24+
linuxX64 ()
25+
linuxArm64 ()
26+
27+
jvm()
28+
29+
@OptIn(ExperimentalWasmDsl::class)
30+
wasmJs {
31+
browser()
32+
binaries.executable()
33+
}
34+
35+
sourceSets {
36+
val commonMain by getting {
37+
dependencies {
38+
implementation(project(":skainet-core:skainet-tensors-api"))
39+
implementation(project(":skainet-core:skainet-tensors"))
40+
}
41+
}
42+
43+
commonTest.dependencies {
44+
implementation(libs.kotlin.test)
45+
implementation(project(":skainet-core:skainet-performance"))
46+
}
47+
}
48+
}
49+
50+
android {
51+
namespace = "sk.ainet.core.api"
52+
compileSdk = libs.versions.android.compileSdk.get().toInt()
53+
54+
defaultConfig {
55+
minSdk = libs.versions.android.minSdk.get().toInt()
56+
}
57+
compileOptions {
58+
sourceCompatibility = JavaVersion.VERSION_11
59+
targetCompatibility = JavaVersion.VERSION_11
60+
}
61+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
POM_ARTIFACT_ID=data-api
2+
POM_NAME=skainet datasets API
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package sk.ainet.data
2+
3+
import sk.ainet.core.tensor.DType
4+
import sk.ainet.core.tensor.Tensor
5+
6+
public data class DataBatch<T : DType, V>(val x: Array<Tensor<T, V>>, val y: Tensor<T, V>) {
7+
override fun equals(other: Any?): Boolean {
8+
if (this === other) return true
9+
if (other == null || this::class != other::class) return false
10+
11+
other as DataBatch<*, *>
12+
13+
if (!x.contentEquals(other.x)) return false
14+
if (y != other.y) return false
15+
16+
return true
17+
}
18+
19+
override fun hashCode(): Int {
20+
var result = x.contentHashCode()
21+
result = 31 * result + y.hashCode()
22+
return result
23+
}
24+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package sk.ainet.data
2+
3+
import sk.ainet.core.tensor.DType
4+
import kotlin.math.min
5+
6+
7+
/** Just abstract Dataset. */
8+
public abstract class Dataset<T, Y> {
9+
/** Splits datasets on two sub-datasets according [splitRatio].*/
10+
public abstract fun split(splitRatio: Double): Pair<Dataset<T, Y>, Dataset<T, Y>>
11+
12+
/** Returns amount of data rows. */
13+
public abstract val xSize: Int
14+
15+
/** Returns row by index [idx]. */
16+
public abstract fun getX(idx: Int): T
17+
18+
/** Returns label as [Int] by index [idx]. */
19+
public abstract fun getY(idx: Int): Y
20+
21+
/** Shuffles the dataset. */
22+
public abstract fun shuffle(): Dataset<T, Y>
23+
24+
/**
25+
* An iterator over a [Dataset].
26+
*/
27+
public inner class BatchIterator<T : DType, V> internal constructor(
28+
private val batchSize: Int
29+
) : Iterator<DataBatch<T, V>> {
30+
31+
private var batchStart = 0
32+
33+
override fun hasNext(): Boolean = batchStart < xSize
34+
35+
override fun next(): DataBatch<T, V> {
36+
val batchLength = min(batchSize, xSize - batchStart)
37+
val batch = createDataBatch<T, V>(batchStart, batchLength)
38+
batchStart += batchSize
39+
return batch
40+
}
41+
}
42+
43+
/** Creates data batch that starts from [batchStart] with length [batchLength]. */
44+
protected abstract fun <T : DType, V> createDataBatch(batchStart: Int, batchLength: Int): DataBatch<T, V>
45+
46+
47+
/** Returns [BatchIterator] with fixed [batchSize]. */
48+
public fun <T : DType, V> batchIterator(batchSize: Int): BatchIterator<T, V> {
49+
return BatchIterator(batchSize)
50+
}
51+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi
2+
import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl
3+
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
4+
5+
plugins {
6+
alias(libs.plugins.kotlinMultiplatform)
7+
alias(libs.plugins.androidLibrary)
8+
alias(libs.plugins.vanniktech.mavenPublish)
9+
}
10+
11+
kotlin {
12+
explicitApi()
13+
14+
androidTarget {
15+
@OptIn(ExperimentalKotlinGradlePluginApi::class)
16+
compilerOptions {
17+
jvmTarget.set(JvmTarget.JVM_11)
18+
}
19+
}
20+
21+
iosArm64()
22+
iosSimulatorArm64()
23+
macosArm64 ()
24+
linuxX64 ()
25+
linuxArm64 ()
26+
27+
jvm()
28+
29+
@OptIn(ExperimentalWasmDsl::class)
30+
wasmJs {
31+
browser()
32+
binaries.executable()
33+
}
34+
35+
sourceSets {
36+
val commonMain by getting {
37+
dependencies {
38+
implementation(project(":skainet-core:skainet-tensors-api"))
39+
implementation(project(":skainet-core:skainet-tensors-api"))
40+
}
41+
}
42+
43+
commonTest.dependencies {
44+
implementation(libs.kotlin.test)
45+
implementation(project(":skainet-core:skainet-performance"))
46+
}
47+
}
48+
}
49+
50+
android {
51+
namespace = "sk.ainet.core.api"
52+
compileSdk = libs.versions.android.compileSdk.get().toInt()
53+
54+
defaultConfig {
55+
minSdk = libs.versions.android.minSdk.get().toInt()
56+
}
57+
compileOptions {
58+
sourceCompatibility = JavaVersion.VERSION_11
59+
targetCompatibility = JavaVersion.VERSION_11
60+
}
61+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
POM_ARTIFACT_ID=data-basic
2+
POM_NAME=skainet neural basic datasets

0 commit comments

Comments
 (0)