Skip to content

Commit bafa4e6

Browse files
committed
Add support for "describe" fucntion for DNN.
1 parent da16f56 commit bafa4e6

15 files changed

Lines changed: 1027 additions & 206 deletions

File tree

settings.gradle.kts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ dependencyResolutionManagement {
1313
}
1414
}
1515

16-
rootProject.name = "skainet"
16+
rootProject.name = "SKaiNET"
1717

1818
include("skainet-core:skainet-tensors-api")
1919
include("skainet-core:skainet-tensors")
2020
include("skainet-core:skainet-performance")
21-
include("skainet-core:skainet-reflection")
21+
include("skainet-core:skainet-core-reflection")
2222
include("skainet-nn:skainet-nn-api")
2323
include("skainet-nn:skainet-nn-relection")
2424
include("skainet-data:skainet-data-api")

skainet-nn/skainet-nn-reflection/build.gradle.kts renamed to skainet-core/skainet-core-reflection/build.gradle.kts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ kotlin {
3737
val commonMain by getting {
3838
dependencies {
3939
implementation(project(":skainet-core:skainet-tensors-api"))
40-
implementation(project(":skainet-core:skainet-tensors-api"))
40+
implementation(project(":skainet-core:skainet-tensors"))
41+
implementation(project(":skainet-nn:skainet-nn-api"))
42+
4143
}
4244
}
4345

@@ -49,7 +51,7 @@ kotlin {
4951
}
5052

5153
android {
52-
namespace = "sk.ainet.core.api"
54+
namespace = "sk.ainet.core.reflection"
5355
compileSdk = libs.versions.android.compileSdk.get().toInt()
5456

5557
defaultConfig {
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
POM_ARTIFACT_ID=nn-api
1+
POM_ARTIFACT_ID=core-reflection
22
POM_NAME=skainet neural network API
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package sk.ainet.nn.reflection
2+
3+
import sk.ainet.core.tensor.DType
4+
import sk.ainet.core.tensor.Shape
5+
import sk.ainet.core.tensor.Tensor
6+
import sk.ainet.nn.Module
7+
import sk.ainet.nn.reflection.table.table
8+
import sk.ainet.nn.topology.ModuleParameters
9+
import sk.ainet.nn.topology.by
10+
11+
12+
public data class NodeSummary(val name: String, val input: Shape, val output: Shape, val params: Long)
13+
14+
public class Summary<T : DType, V> {
15+
16+
public val nodes: MutableList<NodeSummary> = mutableListOf<NodeSummary>()
17+
18+
private fun countParameters(module: Module<T, V>): Long {
19+
var params = 0L
20+
21+
if (module is ModuleParameters<*, *>) {
22+
module.params.forEach { param ->
23+
params += param.value.shape.volume
24+
}
25+
}
26+
27+
return params
28+
}
29+
30+
private fun nodeSummary(module: Module<T, V>, input: Shape, output: Shape): NodeSummary {
31+
val params = countParameters(module)
32+
33+
return NodeSummary(
34+
module.name,
35+
input,
36+
output,
37+
params
38+
)
39+
}
40+
41+
42+
private fun traverseModules(module: Module<T, V>, currentInput: Shape): List<NodeSummary> {
43+
val result = mutableListOf<NodeSummary>()
44+
45+
// For leaf modules (modules with parameters), create a summary
46+
if (module is ModuleParameters<*, *> && module.params.isNotEmpty()) {
47+
// We'll use the current input shape as both input and output for now
48+
// In a real implementation, this would require actual forward pass
49+
val summary = nodeSummary(module, currentInput, currentInput)
50+
result.add(summary)
51+
}
52+
53+
// Recursively traverse nested modules
54+
module.modules.forEach { subModule ->
55+
result.addAll(traverseModules(subModule, currentInput))
56+
}
57+
58+
return result
59+
}
60+
61+
public fun summary(model: Module<T, V>, input: Shape, batch_size: Int = -1): List<NodeSummary> {
62+
nodes.clear()
63+
val summaries = traverseModules(model, input)
64+
nodes.addAll(summaries)
65+
return summaries
66+
}
67+
68+
public fun printSummary(nodes: List<NodeSummary>): String =
69+
table {
70+
cellStyle {
71+
border = true
72+
}
73+
header {
74+
row {
75+
cell("Layer (type)")
76+
cell("Output Shape")
77+
cell("Param #")
78+
}
79+
}
80+
nodes.forEach { node ->
81+
row {
82+
cell(node.name)
83+
cell(node.output.toString())
84+
cell(node.params)
85+
}
86+
}
87+
}.toString()
88+
}
89+
90+
public fun <T : DType, V> Module<T, V>.describe(input: Shape, batch_size: Int = -1): String {
91+
val summary = Summary<T, V>()
92+
val nodes = summary.summary(this, input, batch_size)
93+
return summary.printSummary(nodes)
94+
}

skainet-nn/skainet-nn-reflection/src/commonMain/kotlin/sk/ainet/nn/reflection/table/TableBuilder.kt renamed to skainet-core/skainet-core-reflection/src/commonMain/kotlin/sk/ainet/utils/table/TableBuilder.kt

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ public fun table(block: Table.() -> Unit): Table {
88
// The Table DSL class
99
public class Table {
1010
// A cell style configuration object
11-
public val cellStyle = CellStyle()
11+
public val cellStyle: CellStyle = CellStyle()
1212

1313
// Optional header section
1414
public var header: Header? = null
1515

1616
// List of body rows
17-
public val rows = mutableListOf<Row>()
17+
public val rows: MutableList<Row> = mutableListOf<Row>()
1818

1919
// DSL function to configure the cell style
2020
public fun cellStyle(block: CellStyle.() -> Unit) {
@@ -97,28 +97,28 @@ public class Table {
9797
}
9898

9999
// A simple header container allowing multiple header rows.
100-
class Header {
101-
val rows = mutableListOf<Row>()
100+
public class Header {
101+
public val rows: MutableList<Row> = mutableListOf<Row>()
102102

103-
fun row(block: Row.() -> Unit) {
103+
public fun row(block: Row.() -> Unit): Unit {
104104
rows.add(Row().apply(block))
105105
}
106106
}
107107

108108
// Represents a row in the table.
109-
class Row {
110-
val cells = mutableListOf<Cell>()
109+
public class Row {
110+
public val cells: MutableList<Cell> = mutableListOf<Cell>()
111111

112112
// Adds a cell to the row.
113-
fun cell(value: Any?) {
113+
public fun cell(value: Any?): Unit {
114114
cells.add(Cell(value?.toString() ?: ""))
115115
}
116116
}
117117

118118
// Represents a cell containing text.
119-
class Cell(val content: String)
119+
public class Cell(public val content: String)
120120

121121
// A configuration class for cell style options.
122-
class CellStyle {
123-
var border: Boolean = false
122+
public class CellStyle {
123+
public var border: Boolean = false
124124
}

0 commit comments

Comments
 (0)