11package sk.ainet.lang.graph.utils
22
33import sk.ainet.lang.graph.*
4+ import sk.ainet.lang.tensor.Tensor
5+ import sk.ainet.lang.tensor.dsl.tensor
6+ import sk.ainet.lang.tensor.ops.plus
7+ import sk.ainet.lang.types.FP32
48import kotlin.test.Test
59import kotlin.test.assertNotNull
610import kotlin.test.assertTrue
@@ -16,19 +20,19 @@ class GraphVizExportTest {
1620 override val type : String ,
1721 override val parameters : Map <String , Any > = emptyMap()
1822 ) : BaseOperation(name, type, parameters) {
19-
23+
2024 override fun <T : sk.ainet.lang.types.DType , V > execute (inputs : List <sk.ainet.lang.tensor.Tensor <T , V >>): List <sk.ainet.lang.tensor.Tensor <T , V >> {
2125 return inputs // Pass through for testing
2226 }
23-
27+
2428 override fun validateInputs (inputs : List <TensorSpec >): ValidationResult {
2529 return ValidationResult .Valid
2630 }
27-
31+
2832 override fun inferOutputs (inputs : List <TensorSpec >): List <TensorSpec > {
2933 return inputs // Pass through for testing
3034 }
31-
35+
3236 override fun clone (newParameters : Map <String , Any >): Operation {
3337 return TestOperation (name, type, newParameters)
3438 }
@@ -37,123 +41,188 @@ class GraphVizExportTest {
3741 @Test
3842 fun testBasicGraphVizExport () {
3943 println (" [DEBUG_LOG] Testing basic GraphViz export functionality" )
40-
44+
4145 // Create a simple compute graph
4246 val graph = DefaultComputeGraph ()
43-
47+
4448 // Create test operations
4549 val inputOp = TestOperation (" input" , " input" )
4650 val processOp = TestOperation (" process" , " compute" , mapOf (" kernel_size" to 3 , " stride" to 1 ))
4751 val outputOp = TestOperation (" output" , " output" )
48-
52+
4953 // Create test nodes
5054 val inputNode = GraphNode (
5155 id = " input_node" ,
5256 operation = inputOp,
5357 inputs = emptyList(),
5458 outputs = listOf (TensorSpec (" input_out" , listOf (1 , 10 ), " FP32" ))
5559 )
56-
60+
5761 val processNode = GraphNode (
58- id = " process_node" ,
62+ id = " process_node" ,
5963 operation = processOp,
6064 inputs = listOf (TensorSpec (" process_in" , listOf (1 , 10 ), " FP32" )),
6165 outputs = listOf (TensorSpec (" process_out" , listOf (1 , 5 ), " FP32" ))
6266 )
63-
67+
6468 val outputNode = GraphNode (
6569 id = " output_node" ,
6670 operation = outputOp,
6771 inputs = listOf (TensorSpec (" output_in" , listOf (1 , 5 ), " FP32" )),
6872 outputs = listOf (TensorSpec (" output_out" , listOf (1 , 5 ), " FP32" ))
6973 )
70-
74+
7175 // Add nodes to graph
7276 graph.addNode(inputNode)
7377 graph.addNode(processNode)
7478 graph.addNode(outputNode)
75-
79+
7680 // Add edges to connect them
7781 graph.addEdge(GraphEdge (" edge1" , inputNode, processNode, 0 , 0 , inputNode.outputs.first()))
7882 graph.addEdge(GraphEdge (" edge2" , processNode, outputNode, 0 , 0 , processNode.outputs.first()))
79-
83+
8084 println (" [DEBUG_LOG] Created graph with ${graph.nodes.size} nodes and ${graph.edges.size} edges" )
81-
85+
8286 // Test full graph export
8387 val dotGraph = drawDot(graph)
8488 assertNotNull(dotGraph, " DOT graph should not be null" )
8589 assertNotNull(dotGraph.content, " DOT content should not be null" )
8690 assertTrue(dotGraph.content.isNotEmpty(), " DOT content should not be empty" )
87-
91+
8892 println (" [DEBUG_LOG] Full graph DOT export:" )
8993 println (dotGraph.content)
90-
94+
9195 // Verify DOT content structure
9296 assertTrue(dotGraph.content.contains(" digraph {" ), " Should contain digraph declaration" )
9397 assertTrue(dotGraph.content.contains(" rankdir=LR" ), " Should contain rankdir setting" )
9498 assertTrue(dotGraph.content.contains(" input_node" ), " Should contain input node" )
95- assertTrue(dotGraph.content.contains(" process_node" ), " Should contain process node" )
99+ assertTrue(dotGraph.content.contains(" process_node" ), " Should contain process node" )
96100 assertTrue(dotGraph.content.contains(" output_node" ), " Should contain output node" )
97101 assertTrue(dotGraph.content.contains(" ->" ), " Should contain edges" )
98102 assertTrue(dotGraph.content.contains(" }" ), " Should contain closing brace" )
99-
103+
100104 // Test parameters are included
101105 assertTrue(dotGraph.content.contains(" kernel_size" ), " Should include operation parameters" )
102-
106+
103107 println (" [DEBUG_LOG] Basic GraphViz export test passed" )
104108 }
105-
109+
106110 @Test
107111 fun testSubsetGraphVizExport () {
108112 println (" [DEBUG_LOG] Testing subset GraphViz export functionality" )
109-
113+
110114 // Create a compute graph
111115 val graph = DefaultComputeGraph ()
112-
116+
113117 // Create test nodes
114- val node1 = GraphNode (" node1" , TestOperation (" op1" , " type1" ), emptyList(), listOf (TensorSpec (" out1" , listOf (1 ), " FP32" )))
115- val node2 = GraphNode (" node2" , TestOperation (" op2" , " type2" ), listOf (TensorSpec (" in2" , listOf (1 ), " FP32" )), listOf (TensorSpec (" out2" , listOf (1 ), " FP32" )))
116- val node3 = GraphNode (" node3" , TestOperation (" op3" , " type3" ), listOf (TensorSpec (" in3" , listOf (1 ), " FP32" )), listOf (TensorSpec (" out3" , listOf (1 ), " FP32" )))
117-
118+ val node1 = GraphNode (
119+ " node1" ,
120+ TestOperation (" op1" , " type1" ),
121+ emptyList(),
122+ listOf (TensorSpec (" out1" , listOf (1 ), " FP32" ))
123+ )
124+ val node2 = GraphNode (
125+ " node2" ,
126+ TestOperation (" op2" , " type2" ),
127+ listOf (TensorSpec (" in2" , listOf (1 ), " FP32" )),
128+ listOf (TensorSpec (" out2" , listOf (1 ), " FP32" ))
129+ )
130+ val node3 = GraphNode (
131+ " node3" ,
132+ TestOperation (" op3" , " type3" ),
133+ listOf (TensorSpec (" in3" , listOf (1 ), " FP32" )),
134+ listOf (TensorSpec (" out3" , listOf (1 ), " FP32" ))
135+ )
136+
118137 graph.addNode(node1)
119138 graph.addNode(node2)
120139 graph.addNode(node3)
121-
140+
122141 graph.addEdge(GraphEdge (" edge1" , node1, node2, 0 , 0 , node1.outputs.first()))
123142 graph.addEdge(GraphEdge (" edge2" , node2, node3, 0 , 0 , node2.outputs.first()))
124-
143+
125144 // Test subset export (only from node3 backward)
126145 val dotGraphSubset = drawDot(graph, listOf (node3))
127146 assertNotNull(dotGraphSubset, " Subset DOT graph should not be null" )
128147 assertTrue(dotGraphSubset.content.isNotEmpty(), " Subset DOT content should not be empty" )
129-
148+
130149 println (" [DEBUG_LOG] Subset DOT export:" )
131150 println (dotGraphSubset.content)
132-
151+
133152 // All nodes should be included since node3 depends on all previous nodes
134153 assertTrue(dotGraphSubset.content.contains(" node1" ), " Should contain node1 in subset" )
135154 assertTrue(dotGraphSubset.content.contains(" node2" ), " Should contain node2 in subset" )
136155 assertTrue(dotGraphSubset.content.contains(" node3" ), " Should contain node3 in subset" )
137-
156+
138157 println (" [DEBUG_LOG] Subset GraphViz export test passed" )
139158 }
140-
159+
141160 @Test
142161 fun testDifferentRankDirections () {
143162 println (" [DEBUG_LOG] Testing different rank directions" )
144-
163+
145164 val graph = DefaultComputeGraph ()
146- val node = GraphNode (" test" , TestOperation (" test" , " test" ), emptyList(), listOf (TensorSpec (" out" , listOf (1 ), " FP32" )))
165+ val node =
166+ GraphNode (" test" , TestOperation (" test" , " test" ), emptyList(), listOf (TensorSpec (" out" , listOf (1 ), " FP32" )))
147167 graph.addNode(node)
148-
168+
149169 // Test LR direction (default)
150170 val dotLR = drawDot(graph, " LR" )
151171 assertTrue(dotLR.content.contains(" rankdir=LR" ), " Should set LR rank direction" )
152-
172+
153173 // Test TB direction
154174 val dotTB = drawDot(graph, " TB" )
155175 assertTrue(dotTB.content.contains(" rankdir=TB" ), " Should set TB rank direction" )
156-
176+
157177 println (" [DEBUG_LOG] Rank direction test passed" )
158178 }
179+ @Test
180+ fun testSimpleExpressionToGraphviz () {
181+ println (" [DEBUG_LOG] Testing simple expression to GraphViz" )
182+
183+ // Create tensors exactly as specified in the issue
184+ val a = tensor<FP32 , Float > {
185+ shape(1 ) { ones() }
186+ }
187+
188+ val b = tensor<FP32 , Float > {
189+ shape(1 ) { ones() }
190+ }
191+
192+ // Execute graph operation exactly as specified
193+ val result: GraphExecutionResult <FP32 , Float , Tensor <FP32 , Float >> = exec<FP32 , Float , Tensor <FP32 , Float >> {
194+ a + b
195+ }
196+
197+ println (" [DEBUG_LOG] Graph has ${result.graph.nodes.size} nodes" )
198+ for (node in result.graph.nodes) {
199+ println (" [DEBUG_LOG] Node: ${node.id} , Operation: ${node.operationName} , Type: ${node.operation.type} " )
200+ }
201+
202+ // Test TB direction
203+ val dotTB = drawDot(result.graph, " TB" )
204+ println (" [DEBUG_LOG] Generated DOT content:" )
205+ println (dotTB.content)
206+
207+ // Verify that we have exactly 3 nodes as expected
208+ assertEquals(3 , result.graph.nodes.size, " Should have exactly 3 nodes" )
209+
210+ // Verify the DOT content contains 3 node definitions
211+ val nodeDefinitions = dotTB.content.lines().filter { it.trim().contains(" [label=" ) }
212+ assertEquals(3 , nodeDefinitions.size, " DOT should contain 3 node definitions" )
213+
214+ // Verify we have 2 input nodes (rectangles with values) and 1 addition operation node
215+ val nodesByType = result.graph.nodes.groupBy { it.operation.type }
216+ assertTrue(nodesByType.containsKey(" input" ), " Should contain input operations" )
217+ assertTrue(nodesByType.containsKey(" math" ), " Should contain math operations" )
218+ assertEquals(2 , nodesByType[" input" ]!! .size, " Should have exactly 2 input nodes" )
219+ assertEquals(1 , nodesByType[" math" ]!! .size, " Should have exactly 1 math operation node" )
220+
221+ // Verify the addition operation name
222+ val addNode = nodesByType[" math" ]!! .first()
223+ assertEquals(" add" , addNode.operationName, " Math operation should be 'add'" )
224+
225+ println (" [DEBUG_LOG] Simple expression GraphViz test completed" )
226+ }
227+
159228}
0 commit comments