Skip to content

Commit 404fefa

Browse files
committed
Add safetensor module and fix tests
Related-To: #10
1 parent 4dd7c22 commit 404fefa

21 files changed

Lines changed: 5093 additions & 8 deletions

File tree

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package sk.ai.net.autograd
2+
3+
import sk.ai.net.graph.tensor.shape.Shape
4+
import kotlin.test.Test
5+
import kotlin.test.assertTrue
6+
import kotlin.test.assertContains
7+
8+
/**
 * Tests for the Graphviz (DOT) export of autograd computational graphs.
 *
 * Each test builds a small graph with [AutogradFactory]-created tensors and
 * verifies that `toGraphviz()` emits the expected DOT structure: a
 * `digraph ComputationalGraph` wrapper, tensor/operation node declarations,
 * and `->` edges connecting them.
 */
class GraphvizExporterTest {

    @Test
    fun testBasicGraphExport() {
        // Two leaf tensors participating in gradient tracking.
        val x = AutogradFactory.tensor(Shape(2, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0), requiresGrad = true)
        val y = AutogradFactory.tensor(Shape(2, 2), doubleArrayOf(2.0, 3.0, 4.0, 5.0), requiresGrad = true)

        // Build a two-operation graph: w = (x + y) matmul x.
        val z = x.plus(y) as AutogradTensor
        val w = z.matmul(x) as AutogradTensor

        // Exporting from the final tensor must include the whole ancestor graph.
        val dotGraph = w.toGraphviz()

        // Well-formed DOT wrapper.
        assertTrue(dotGraph.startsWith("digraph ComputationalGraph {"))
        assertTrue(dotGraph.endsWith("}"))

        // Default node styling is declared once up front.
        assertContains(dotGraph, "node [shape=box, style=filled, color=lightblue]")

        // Tensor nodes carry shape and grad metadata.
        // NOTE(review): the doubled "Shape: Shape:" matches the exporter's
        // current label format (prefix + Shape.toString()) — confirm intended.
        assertContains(dotGraph, "Shape: Shape: Dimensions = [2 x 2]")
        assertContains(dotGraph, "Requires Grad: true")

        // Both operations that produced w appear as nodes.
        assertContains(dotGraph, "AddOperation")
        assertContains(dotGraph, "MatmulOperation")

        // At least one edge connects the nodes.
        assertContains(dotGraph, "->")
    }

    @Test
    fun testGraphExportWithGradients() {
        val x = AutogradFactory.tensor(Shape(2, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0), requiresGrad = true)
        val y = AutogradFactory.tensor(Shape(2, 2), doubleArrayOf(2.0, 3.0, 4.0, 5.0), requiresGrad = true)

        val z = x.plus(y) as AutogradTensor

        // Populate gradients so the exporter has something to render.
        z.backward()

        // With includeGradients the node labels must expose gradient values.
        val dotGraph = z.toGraphviz(includeGradients = true)
        assertContains(dotGraph, "Grad: ")
    }

    @Test
    fun testComplexGraphExport() {
        // Three-operation chain: d = a.matmul(b), e = d + c, f = relu(e).
        val a = AutogradFactory.tensor(Shape(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), requiresGrad = true)
        val b = AutogradFactory.tensor(Shape(3, 2), doubleArrayOf(7.0, 8.0, 9.0, 10.0, 11.0, 12.0), requiresGrad = true)
        val c = AutogradFactory.tensor(Shape(2, 2), doubleArrayOf(0.1, 0.2, 0.3, 0.4), requiresGrad = true)

        val d = a.matmul(b) as AutogradTensor
        val e = d.plus(c) as AutogradTensor
        val f = e.relu() as AutogradTensor

        // Each result tensor must record its producing operation and parents.
        assertTrue(d.parents.size == 2, "matmul result should have 2 parents, found ${d.parents.size}")
        assertTrue(e.parents.size == 2, "plus result should have 2 parents, found ${e.parents.size}")
        assertTrue(f.parents.size == 1, "relu result should have 1 parent, found ${f.parents.size}")

        // Export from the final tensor: because toGraphviz() walks ancestors
        // (see testBasicGraphExport), f's graph must contain all three ops.
        val dotGraph = f.toGraphviz()

        assertTrue(dotGraph.startsWith("digraph ComputationalGraph {"))
        assertTrue(dotGraph.endsWith("}"))

        assertContains(dotGraph, "ReluOperation")
        assertContains(dotGraph, "AddOperation")
        assertContains(dotGraph, "MatmulOperation")

        // Every shape that occurs in the chain must be labelled somewhere.
        assertContains(dotGraph, "Dimensions = [2 x 3]")
        assertContains(dotGraph, "Dimensions = [3 x 2]")
        assertContains(dotGraph, "Dimensions = [2 x 2]")

        // A 7-node graph (4 tensors + 3 ops) needs at least 6 edges; at
        // minimum the exporter must emit edge syntax at all.
        assertContains(dotGraph, "->")

        // Intermediate exports must also be self-contained DOT documents.
        val dotGraphD = d.toGraphviz()
        assertContains(dotGraphD, "MatmulOperation")
        val dotGraphE = e.toGraphviz()
        assertContains(dotGraphE, "AddOperation")
    }
}

docs/tasks.md

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Improvement Tasks for SK-AI-Net
2+
3+
This document outlines a comprehensive list of improvement tasks for the SK-AI-Net project, organized by category. These tasks are based on an analysis of the current codebase structure and functionality.
4+
5+
## 1. Code Completion and Bug Fixes
6+
7+
- [ ] Complete the implementation of Conv2d.kt (currently has commented-out code)
8+
- [ ] Implement the missing `pow(scalar: Double)` method in DoublesTensor.kt
9+
- [ ] Add proper error handling for tensor operations with incompatible shapes
10+
- [ ] Fix potential numerical stability issues in softmax implementation
11+
- [ ] Implement proper broadcasting for tensor operations
12+
13+
## 2. Architecture Improvements
14+
15+
- [ ] Create a Sequential module for composing neural network layers
16+
- [ ] Implement a proper computational graph for automatic differentiation
17+
- [ ] Add support for model serialization and deserialization
18+
- [ ] Implement a proper optimizer framework (SGD, Adam, etc.)
19+
- [ ] Create a loss function framework
20+
- [ ] Implement a training loop abstraction
21+
- [ ] Add support for model checkpointing
22+
23+
## 3. Performance Optimizations
24+
25+
- [ ] Optimize matrix multiplication for large tensors
26+
- [ ] Implement parallelization for tensor operations
27+
- [ ] Add GPU support for tensor operations
28+
- [ ] Implement memory-efficient tensor operations
29+
- [ ] Add support for sparse tensors
30+
- [ ] Optimize convolution operations
31+
32+
## 4. Documentation Improvements
33+
34+
- [ ] Add comprehensive KDoc comments to all public APIs
35+
- [ ] Create a getting started guide
36+
- [ ] Add examples for common use cases
37+
- [ ] Document the tensor operations and their behavior
38+
- [ ] Create architecture diagrams
39+
- [ ] Add benchmarks and performance guidelines
40+
- [ ] Document the module system and how to create custom modules
41+
42+
## 5. Testing Improvements
43+
44+
- [ ] Add unit tests for all tensor operations
45+
- [ ] Add integration tests for neural network modules
46+
- [ ] Create benchmarks for performance-critical operations
47+
- [ ] Add property-based tests for tensor operations
48+
- [ ] Implement test fixtures for common neural network architectures
49+
- [ ] Add tests for numerical stability
50+
51+
## 6. Feature Additions
52+
53+
- [ ] Add more activation functions (LeakyReLU, ELU, GELU, etc.)
54+
- [ ] Implement more layer types (BatchNorm, LayerNorm, Dropout, etc.)
55+
- [ ] Add support for recurrent neural networks (LSTM, GRU, etc.)
56+
- [ ] Implement attention mechanisms
57+
- [ ] Add support for transformer architectures
58+
- [ ] Implement common loss functions (MSE, CrossEntropy, etc.)
59+
- [ ] Add support for custom initialization schemes
60+
61+
## 7. Usability Improvements
62+
63+
- [ ] Create a DSL for building neural networks
64+
- [ ] Add a high-level API for common tasks
65+
- [ ] Implement a progress tracking system for training
66+
- [ ] Add visualization tools for model architecture
67+
- [ ] Create a model zoo with pre-trained models
68+
- [ ] Implement a data loading and preprocessing framework
69+
- [ ] Add support for distributed training
70+
71+
## 8. Project Infrastructure
72+
73+
- [ ] Set up continuous integration
74+
- [ ] Add code quality checks
75+
- [ ] Implement automated release process
76+
- [ ] Create comprehensive documentation website
77+
- [ ] Add contribution guidelines
78+
- [ ] Set up issue templates
79+
- [ ] Implement a versioning strategy
80+
81+
## 9. Interoperability
82+
83+
- [ ] Add support for importing models from other frameworks (PyTorch, TensorFlow)
84+
- [ ] Implement ONNX support for model exchange
85+
- [ ] Create bindings for popular languages (Python, JavaScript)
86+
- [ ] Add support for common model formats (ONNX, TorchScript)
87+
- [ ] Implement interoperability with Arrow for data exchange
88+
89+
## 10. Long-term Vision
90+
91+
- [ ] Develop a comprehensive deep learning framework
92+
- [ ] Create specialized modules for computer vision, NLP, and other domains
93+
- [ ] Implement distributed training support
94+
- [ ] Add support for quantization and model compression
95+
- [ ] Develop deployment tools for various platforms
96+
- [ ] Create a model serving infrastructure
97+
- [ ] Implement AutoML capabilities

gradle/libs.versions.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,15 @@ kotlinx-coroutines = "1.10.2"
55
android-minSdk = "24"
66
android-compileSdk = "35"
77
kotlinxSerializationJson = "1.8.1"
8-
nexus-publish = "2.0.0"
9-
testng = "7.10.2"
108
binaryCompatibility = "0.17.0"
119
moduleGraphSouza = "0.12.0"
1210
kotlinxIo = "0.7.0"
1311

1412
[libraries]
1513
kotlinx-coroutines = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" }
1614
kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
15+
kotlinx-io-core-wasm-js = { module = "org.jetbrains.kotlinx:kotlinx-io-core-wasm-js", version.ref = "kotlinxIo" }
1716
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinxSerializationJson" }
18-
nexus-publish = { module = "io.github.gradle-nexus.publish-plugin:io.github.gradle-nexus.publish-plugin.gradle.plugin", version.ref = "nexus-publish" }
1917
kotlinx-io-core = { module = "org.jetbrains.kotlinx:kotlinx-io-core", version.ref = "kotlinxIo" }
2018

2119
[plugins]

0 commit comments

Comments
 (0)