Skip to content

Commit 8e14fd5

Browse files
committed
Add different shape for inputs
Related-To #138
1 parent 3e6ae79 commit 8e14fd5

4 files changed

Lines changed: 187 additions & 41 deletions

File tree

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/graph/utils/graphviz.kt

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,14 @@ public fun drawDot(graph: ComputeGraph, rankdir: String = "LR"): DotGraph {
4444

4545
val labelContent = "${node.operationName} | ${node.id}"
4646

47-
dotContent.appendLine(" $nodeId [label=\"$labelContent\", shape=record];")
47+
// Assign shape based on operation type
48+
val shapeAttributes = when (node.operation.type) {
49+
"input" -> "shape=record, style=filled, fillcolor=lightblue"
50+
"math" -> "shape=circle"
51+
else -> "shape=record" // values and other operations use rectangle
52+
}
53+
54+
dotContent.appendLine(" $nodeId [label=\"$labelContent\", $shapeAttributes];")
4855

4956
// Add operation node if operation has parameters
5057
if (node.operation.parameters.isNotEmpty()) {
@@ -103,7 +110,14 @@ public fun drawDot(graph: ComputeGraph, outputNodes: List<GraphNode>, rankdir: S
103110

104111
val labelContent = "${node.operationName} | ${node.id}"
105112

106-
dotContent.appendLine(" $nodeId [label=\"$labelContent\", shape=record];")
113+
// Assign shape based on operation type
114+
val shapeAttributes = when (node.operation.type) {
115+
"input" -> "shape=record, style=filled, fillcolor=lightblue"
116+
"math" -> "shape=circle"
117+
else -> "shape=record" // values and other operations use rectangle
118+
}
119+
120+
dotContent.appendLine(" $nodeId [label=\"$labelContent\", $shapeAttributes];")
107121

108122
// Add operation node if operation has parameters
109123
if (node.operation.parameters.isNotEmpty()) {

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/GraphTensorOps.kt

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,42 @@ public class GraphTensorOps<V>(
3131
)
3232
}
3333

34+
private fun <T : DType> ensureInputNode(tensor: Tensor<T, V>, inputName: String): GraphNode {
35+
// Check if a node for this tensor already exists
36+
val existingNode = graph.nodes.find {
37+
it.operation.type == "input" && it.outputs.any { spec -> spec.name == inputName }
38+
}
39+
40+
if (existingNode != null) {
41+
return existingNode
42+
}
43+
44+
// Create new input node
45+
val inputOperation = InputOperation<T, V>()
46+
val inputNodeId = generateNodeId("input")
47+
val tensorSpec = createTensorSpec(tensor, inputName)
48+
49+
val inputNode = GraphNode(
50+
id = inputNodeId,
51+
operation = inputOperation,
52+
inputs = emptyList(),
53+
outputs = listOf(tensorSpec)
54+
)
55+
56+
graph.addNode(inputNode)
57+
return inputNode
58+
}
59+
3460
// Basic mathematical operations
3561
override fun <T : DType> add(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
3662
val result = baseOps.add(a, b)
3763

3864
if (executionContext.isRecording) {
65+
// Ensure input nodes exist for both tensors
66+
val inputNodeA = ensureInputNode(a, "tensor_a")
67+
val inputNodeB = ensureInputNode(b, "tensor_b")
68+
69+
// Create the addition operation node
3970
val operation = AddOperation<T, V>()
4071
val nodeId = generateNodeId("add")
4172
val inputs = listOf(
@@ -44,8 +75,12 @@ public class GraphTensorOps<V>(
4475
)
4576
val outputs = listOf(createTensorSpec(result, "output_0"))
4677

47-
val node = GraphNode(nodeId, operation, inputs, outputs)
48-
graph.addNode(node)
78+
val addNode = GraphNode(nodeId, operation, inputs, outputs)
79+
graph.addNode(addNode)
80+
81+
// Add edges from input nodes to the addition node
82+
graph.addEdge(GraphEdge("edge_a_to_add", inputNodeA, addNode, 0, 0, inputNodeA.outputs.first()))
83+
graph.addEdge(GraphEdge("edge_b_to_add", inputNodeB, addNode, 0, 1, inputNodeB.outputs.first()))
4984
}
5085

5186
return result

skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,34 @@ import sk.ainet.lang.tensor.Tensor
44
import sk.ainet.lang.types.DType
55
import sk.ainet.lang.graph.*
66

7+
/**
8+
* Input tensor operation for graph representation
9+
*/
10+
public class InputOperation<T : DType, V>(
11+
parameters: Map<String, Any> = emptyMap()
12+
) : BaseOperation("input", "input", parameters) {
13+
14+
override fun <T2 : DType, V2> execute(inputs: List<Tensor<T2, V2>>): List<Tensor<T2, V2>> {
15+
require(inputs.isEmpty()) { "Input operation should not have inputs" }
16+
throw UnsupportedOperationException("Input operations don't execute - they represent tensor values")
17+
}
18+
19+
override fun validateInputs(inputs: List<TensorSpec>): ValidationResult {
20+
if (inputs.isNotEmpty()) {
21+
return ValidationResult.Invalid(listOf("Input operation should not have inputs, got ${inputs.size}"))
22+
}
23+
return ValidationResult.Valid
24+
}
25+
26+
override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
27+
require(inputs.isEmpty()) { "Input operation should not have inputs" }
28+
// This will be set by the caller with the actual tensor spec
29+
return emptyList()
30+
}
31+
32+
override fun clone(newParameters: Map<String, Any>): Operation = InputOperation<T, V>(newParameters)
33+
}
34+
735
/**
836
* Basic math operations for graph-based execution
937
*/
Lines changed: 106 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package sk.ainet.lang.graph.utils
22

33
import sk.ainet.lang.graph.*
4+
import sk.ainet.lang.tensor.Tensor
5+
import sk.ainet.lang.tensor.dsl.tensor
6+
import sk.ainet.lang.tensor.ops.plus
7+
import sk.ainet.lang.types.FP32
48
import kotlin.test.Test
59
import kotlin.test.assertNotNull
610
import kotlin.test.assertTrue
@@ -16,19 +20,19 @@ class GraphVizExportTest {
1620
override val type: String,
1721
override val parameters: Map<String, Any> = emptyMap()
1822
) : BaseOperation(name, type, parameters) {
19-
23+
2024
override fun <T : sk.ainet.lang.types.DType, V> execute(inputs: List<sk.ainet.lang.tensor.Tensor<T, V>>): List<sk.ainet.lang.tensor.Tensor<T, V>> {
2125
return inputs // Pass through for testing
2226
}
23-
27+
2428
override fun validateInputs(inputs: List<TensorSpec>): ValidationResult {
2529
return ValidationResult.Valid
2630
}
27-
31+
2832
override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
2933
return inputs // Pass through for testing
3034
}
31-
35+
3236
override fun clone(newParameters: Map<String, Any>): Operation {
3337
return TestOperation(name, type, newParameters)
3438
}
@@ -37,123 +41,188 @@ class GraphVizExportTest {
3741
@Test
3842
fun testBasicGraphVizExport() {
3943
println("[DEBUG_LOG] Testing basic GraphViz export functionality")
40-
44+
4145
// Create a simple compute graph
4246
val graph = DefaultComputeGraph()
43-
47+
4448
// Create test operations
4549
val inputOp = TestOperation("input", "input")
4650
val processOp = TestOperation("process", "compute", mapOf("kernel_size" to 3, "stride" to 1))
4751
val outputOp = TestOperation("output", "output")
48-
52+
4953
// Create test nodes
5054
val inputNode = GraphNode(
5155
id = "input_node",
5256
operation = inputOp,
5357
inputs = emptyList(),
5458
outputs = listOf(TensorSpec("input_out", listOf(1, 10), "FP32"))
5559
)
56-
60+
5761
val processNode = GraphNode(
58-
id = "process_node",
62+
id = "process_node",
5963
operation = processOp,
6064
inputs = listOf(TensorSpec("process_in", listOf(1, 10), "FP32")),
6165
outputs = listOf(TensorSpec("process_out", listOf(1, 5), "FP32"))
6266
)
63-
67+
6468
val outputNode = GraphNode(
6569
id = "output_node",
6670
operation = outputOp,
6771
inputs = listOf(TensorSpec("output_in", listOf(1, 5), "FP32")),
6872
outputs = listOf(TensorSpec("output_out", listOf(1, 5), "FP32"))
6973
)
70-
74+
7175
// Add nodes to graph
7276
graph.addNode(inputNode)
7377
graph.addNode(processNode)
7478
graph.addNode(outputNode)
75-
79+
7680
// Add edges to connect them
7781
graph.addEdge(GraphEdge("edge1", inputNode, processNode, 0, 0, inputNode.outputs.first()))
7882
graph.addEdge(GraphEdge("edge2", processNode, outputNode, 0, 0, processNode.outputs.first()))
79-
83+
8084
println("[DEBUG_LOG] Created graph with ${graph.nodes.size} nodes and ${graph.edges.size} edges")
81-
85+
8286
// Test full graph export
8387
val dotGraph = drawDot(graph)
8488
assertNotNull(dotGraph, "DOT graph should not be null")
8589
assertNotNull(dotGraph.content, "DOT content should not be null")
8690
assertTrue(dotGraph.content.isNotEmpty(), "DOT content should not be empty")
87-
91+
8892
println("[DEBUG_LOG] Full graph DOT export:")
8993
println(dotGraph.content)
90-
94+
9195
// Verify DOT content structure
9296
assertTrue(dotGraph.content.contains("digraph {"), "Should contain digraph declaration")
9397
assertTrue(dotGraph.content.contains("rankdir=LR"), "Should contain rankdir setting")
9498
assertTrue(dotGraph.content.contains("input_node"), "Should contain input node")
95-
assertTrue(dotGraph.content.contains("process_node"), "Should contain process node")
99+
assertTrue(dotGraph.content.contains("process_node"), "Should contain process node")
96100
assertTrue(dotGraph.content.contains("output_node"), "Should contain output node")
97101
assertTrue(dotGraph.content.contains("->"), "Should contain edges")
98102
assertTrue(dotGraph.content.contains("}"), "Should contain closing brace")
99-
103+
100104
// Test parameters are included
101105
assertTrue(dotGraph.content.contains("kernel_size"), "Should include operation parameters")
102-
106+
103107
println("[DEBUG_LOG] Basic GraphViz export test passed")
104108
}
105-
109+
106110
@Test
107111
fun testSubsetGraphVizExport() {
108112
println("[DEBUG_LOG] Testing subset GraphViz export functionality")
109-
113+
110114
// Create a compute graph
111115
val graph = DefaultComputeGraph()
112-
116+
113117
// Create test nodes
114-
val node1 = GraphNode("node1", TestOperation("op1", "type1"), emptyList(), listOf(TensorSpec("out1", listOf(1), "FP32")))
115-
val node2 = GraphNode("node2", TestOperation("op2", "type2"), listOf(TensorSpec("in2", listOf(1), "FP32")), listOf(TensorSpec("out2", listOf(1), "FP32")))
116-
val node3 = GraphNode("node3", TestOperation("op3", "type3"), listOf(TensorSpec("in3", listOf(1), "FP32")), listOf(TensorSpec("out3", listOf(1), "FP32")))
117-
118+
val node1 = GraphNode(
119+
"node1",
120+
TestOperation("op1", "type1"),
121+
emptyList(),
122+
listOf(TensorSpec("out1", listOf(1), "FP32"))
123+
)
124+
val node2 = GraphNode(
125+
"node2",
126+
TestOperation("op2", "type2"),
127+
listOf(TensorSpec("in2", listOf(1), "FP32")),
128+
listOf(TensorSpec("out2", listOf(1), "FP32"))
129+
)
130+
val node3 = GraphNode(
131+
"node3",
132+
TestOperation("op3", "type3"),
133+
listOf(TensorSpec("in3", listOf(1), "FP32")),
134+
listOf(TensorSpec("out3", listOf(1), "FP32"))
135+
)
136+
118137
graph.addNode(node1)
119138
graph.addNode(node2)
120139
graph.addNode(node3)
121-
140+
122141
graph.addEdge(GraphEdge("edge1", node1, node2, 0, 0, node1.outputs.first()))
123142
graph.addEdge(GraphEdge("edge2", node2, node3, 0, 0, node2.outputs.first()))
124-
143+
125144
// Test subset export (only from node3 backward)
126145
val dotGraphSubset = drawDot(graph, listOf(node3))
127146
assertNotNull(dotGraphSubset, "Subset DOT graph should not be null")
128147
assertTrue(dotGraphSubset.content.isNotEmpty(), "Subset DOT content should not be empty")
129-
148+
130149
println("[DEBUG_LOG] Subset DOT export:")
131150
println(dotGraphSubset.content)
132-
151+
133152
// All nodes should be included since node3 depends on all previous nodes
134153
assertTrue(dotGraphSubset.content.contains("node1"), "Should contain node1 in subset")
135154
assertTrue(dotGraphSubset.content.contains("node2"), "Should contain node2 in subset")
136155
assertTrue(dotGraphSubset.content.contains("node3"), "Should contain node3 in subset")
137-
156+
138157
println("[DEBUG_LOG] Subset GraphViz export test passed")
139158
}
140-
159+
141160
@Test
142161
fun testDifferentRankDirections() {
143162
println("[DEBUG_LOG] Testing different rank directions")
144-
163+
145164
val graph = DefaultComputeGraph()
146-
val node = GraphNode("test", TestOperation("test", "test"), emptyList(), listOf(TensorSpec("out", listOf(1), "FP32")))
165+
val node =
166+
GraphNode("test", TestOperation("test", "test"), emptyList(), listOf(TensorSpec("out", listOf(1), "FP32")))
147167
graph.addNode(node)
148-
168+
149169
// Test LR direction (default)
150170
val dotLR = drawDot(graph, "LR")
151171
assertTrue(dotLR.content.contains("rankdir=LR"), "Should set LR rank direction")
152-
172+
153173
// Test TB direction
154174
val dotTB = drawDot(graph, "TB")
155175
assertTrue(dotTB.content.contains("rankdir=TB"), "Should set TB rank direction")
156-
176+
157177
println("[DEBUG_LOG] Rank direction test passed")
158178
}
179+
@Test
180+
fun testSimpleExpressionToGraphviz() {
181+
println("[DEBUG_LOG] Testing simple expression to GraphViz")
182+
183+
// Create tensors exactly as specified in the issue
184+
val a = tensor<FP32, Float> {
185+
shape(1) { ones() }
186+
}
187+
188+
val b = tensor<FP32, Float> {
189+
shape(1) { ones() }
190+
}
191+
192+
// Execute graph operation exactly as specified
193+
val result: GraphExecutionResult<FP32, Float, Tensor<FP32, Float>> = exec<FP32, Float, Tensor<FP32, Float>> {
194+
a + b
195+
}
196+
197+
println("[DEBUG_LOG] Graph has ${result.graph.nodes.size} nodes")
198+
for (node in result.graph.nodes) {
199+
println("[DEBUG_LOG] Node: ${node.id}, Operation: ${node.operationName}, Type: ${node.operation.type}")
200+
}
201+
202+
// Test TB direction
203+
val dotTB = drawDot(result.graph, "TB")
204+
println("[DEBUG_LOG] Generated DOT content:")
205+
println(dotTB.content)
206+
207+
// Verify that we have exactly 3 nodes as expected
208+
assertEquals(3, result.graph.nodes.size, "Should have exactly 3 nodes")
209+
210+
// Verify the DOT content contains 3 node definitions
211+
val nodeDefinitions = dotTB.content.lines().filter { it.trim().contains("[label=") }
212+
assertEquals(3, nodeDefinitions.size, "DOT should contain 3 node definitions")
213+
214+
// Verify we have 2 input nodes (rectangles with values) and 1 addition operation node
215+
val nodesByType = result.graph.nodes.groupBy { it.operation.type }
216+
assertTrue(nodesByType.containsKey("input"), "Should contain input operations")
217+
assertTrue(nodesByType.containsKey("math"), "Should contain math operations")
218+
assertEquals(2, nodesByType["input"]!!.size, "Should have exactly 2 input nodes")
219+
assertEquals(1, nodesByType["math"]!!.size, "Should have exactly 1 math operation node")
220+
221+
// Verify the addition operation name
222+
val addNode = nodesByType["math"]!!.first()
223+
assertEquals("add", addNode.operationName, "Math operation should be 'add'")
224+
225+
println("[DEBUG_LOG] Simple expression GraphViz test completed")
226+
}
227+
159228
}

0 commit comments

Comments
 (0)