Skip to content

Commit c0c2e57

Browse files
Added graph for pecs in Shape Classification
1 parent afcd50d commit c0c2e57

11 files changed

Lines changed: 1181 additions & 37 deletions

src/Graph.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import List, Tuple
2+
3+
class Graph:
4+
def __init__(self):
5+
self._nodes:List = []
6+
self._edges:List[Tuple] = []
7+
8+
@property
9+
def roots(self) -> List:
10+
roots:List = []
11+
for node in self._nodes:
12+
isChild = False
13+
for edge in self._edges:
14+
if edge[-1] == node:
15+
isChild = True
16+
continue
17+
if isChild == False:
18+
roots.append(node)
19+
return roots.copy()
20+
21+
@property
22+
def nodes(self) -> List:
23+
return self._nodes.copy()
24+
25+
@property
26+
def edges(self) -> List:
27+
return self._edges.copy()
28+
29+
@nodes.setter
30+
def nodes(self, nodes):
31+
self._nodes = list(nodes)
32+
33+
@edges.setter
34+
def edges(self, edges):
35+
self._edges = [tuple(e) for e in edges]
36+
37+
def add_node(self, node):
38+
if node not in self._nodes:
39+
self._nodes.append(node)
40+
41+
def add_edge(self, source, destination):
42+
if source not in self._nodes:
43+
self.add_node(source)
44+
if destination not in self._nodes:
45+
self.add_node(destination)
46+
if (source, destination) not in self._edges:
47+
self._edges.append((source, destination))
48+
49+
def get_connections(self):
50+
connections = {node: [] for node in self._nodes}
51+
for source, destination in self._edges:
52+
connections[source].append(destination)
53+
return connections
54+
55+
def getParentNodes(self) -> List:
56+
return [edge[0] for edge in self._edges]
57+
58+
def getChildNodes(self) -> List:
59+
return [edge[-1] for edge in self._edges]
60+
61+
#Necesita una revisión pero por ahora hace lo que necesito
62+
def prune_to_longest_paths(self):
63+
connections = self.get_connections()
64+
roots = [n for n in self._nodes if n not in self.getChildNodes()]
65+
66+
longest_paths = []
67+
68+
def dfs(node, path):
69+
path = path + [node]
70+
if node not in connections or not connections[node]:
71+
longest_paths.append(path)
72+
return
73+
for child in connections[node]:
74+
dfs(child, path)
75+
76+
for root in roots:
77+
dfs(root, [])
78+
79+
leaf_to_path = {}
80+
for path in longest_paths:
81+
leaf = path[-1]
82+
if leaf not in leaf_to_path or len(path) > len(leaf_to_path[leaf]):
83+
leaf_to_path[leaf] = path
84+
85+
new_nodes = set()
86+
new_edges = set()
87+
for path in leaf_to_path.values():
88+
new_nodes.update(path)
89+
new_edges.update([(path[i], path[i+1]) for i in range(len(path)-1)])
90+
91+
self._nodes = list(new_nodes)
92+
self._edges = list(new_edges)
93+
94+
def __str__(self):
95+
return f"Graph(Nodes: {self._nodes},\n Edges: {self._edges})"

src/ShapesClassification.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
from enum import Enum
22
import json
3+
import math
4+
from time import sleep
35
from typing import Any, Tuple, List, Dict
46

57
import gmsh
8+
9+
from src.Graph import Graph
610
from .BoundingBox import BoundingBox
711
from itertools import chain
812
import numpy as np
913

1014
class ShapesClassification:
15+
_ROUND_VALUE:int = 6
16+
1117
isOpenCase:bool
1218
crossSectionData: Dict
1319
pecs: Dict
1420
dielectrics: Dict
15-
21+
nestedGraph: Graph
1622

1723
def __init__(self, shapes, jsonFile:str):
1824
gmsh.model.occ.synchronize()
@@ -25,7 +31,7 @@ def __init__(self, shapes, jsonFile:str):
2531
self.dielectrics = self.get_dielectrics(shapes)
2632
self.shieldReference = dict()
2733
self.vacuum = dict()
28-
34+
self.nestedGraph = self.__getNestedGraph()
2935
self.isOpenCase = self.isOpenProblem()
3036

3137

@@ -35,7 +41,7 @@ def getNumberFromName(entity_name: str, label: str):
3541
num = int(entity_name[ini:])
3642
return num
3743

38-
def get_pecs(self, entity_tags):
44+
def get_pecs(self, entity_tags) -> Dict[str, Dict[str,any]]:
3945
pecNames = self.__getGeometryNamesByMaterialType('PEC')
4046
pecs = dict()
4147
for s in entity_tags:
@@ -46,7 +52,7 @@ def get_pecs(self, entity_tags):
4652

4753
return pecs
4854

49-
def get_dielectrics(self, entity_tags):
55+
def get_dielectrics(self, entity_tags) -> Dict[str, Dict[str,any]]:
5056
dielectricNames = self.__getGeometryNamesByMaterialType('Dielectric')
5157
dielectrics = dict()
5258
for s in entity_tags:
@@ -64,30 +70,11 @@ def __getGeometryNamesByMaterialType(self, materialType:str) -> List[str]:
6470
if geometry['material']['type'] == materialType
6571
]
6672
return names
67-
68-
def isOpenProblem(self):
69-
elements = list(chain(self.pecs.values()))
70-
isOpenCase = True
71-
for idx, element in enumerate(elements):
72-
intersectWithAll = True
73-
intersect = []
74-
for otheridx, otherElement in enumerate(elements):
75-
if element != otherElement:
76-
intersect = gmsh.model.occ.intersect(
77-
element,
78-
otherElement,
79-
removeObject=False,
80-
tag=300+otheridx,
81-
removeTool=False
82-
)[0]
83-
if len(intersect) == 0:
84-
intersectWithAll = False
85-
else:
86-
isOpenCase = False
87-
if intersectWithAll:
88-
print(element, otherElement)
89-
self.shieldReference = {list(self.pecs.keys())[idx] : element}
90-
return isOpenCase
73+
74+
def isOpenProblem(self) -> None:
75+
if len(self.nestedGraph.roots) > 1:
76+
return True
77+
return False
9178

9279
def removeConductorsFromDielectrics(self):
9380
for num, diel in self.dielectrics.items():
@@ -208,4 +195,26 @@ def _buildDefaultVacuumDomain(self):
208195

209196
return dict([[0, nearVacuum], [1, farVacuum]])
210197

211-
198+
def __getNestedGraph(self):
199+
gmsh.model.occ.synchronize()
200+
graph = Graph()
201+
for key in self.pecs.keys():
202+
graph.add_node(key)
203+
for i, keyA in enumerate(self.pecs.keys()):
204+
for j, keyB in enumerate(self.pecs.keys()):
205+
if i < j:
206+
inter = gmsh.model.occ.intersect(
207+
self.pecs[keyA],
208+
self.pecs[keyB],
209+
removeObject=False,
210+
removeTool=False
211+
)
212+
if len(inter[1][0]) == 0: #comprueba las intersecciones en las que interfiere el objeto
213+
continue
214+
else:
215+
if inter[1][0] == self.pecs[keyA]:
216+
graph.add_edge(keyB, keyA)
217+
elif inter[1][0] == self.pecs[keyB]:
218+
graph.add_edge(keyA, keyB)
219+
graph.prune_to_longest_paths()
220+
return graph

test/test_ShapesClassification.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from copy import copy
12
import os
3+
from typing import Dict, List, Tuple
24
import unittest
35
import gmsh
46
import json
57

8+
69
from src.ShapesClassification import ShapesClassification
710

811

@@ -52,15 +55,21 @@ def testDielectricShieldedPairClassification(self) -> None:
5255
'RightDielectric': [(2,4)],
5356
'LeftDielectric': [(2,5)],
5457
}
55-
expectedShieldReference = {
56-
'ExternalShield': [(2,2)],
57-
}
5858
self.assertListEqual(self.shapeClassification.allShapes, expectedShapes)
5959
self.assertDictEqual(self.shapeClassification.pecs, expectedPecs)
6060
self.assertDictEqual(self.shapeClassification.dielectrics, expectedDielectrics)
61-
self.assertDictEqual(self.shapeClassification.shieldReference, expectedShieldReference)
6261
self.assertFalse(self.shapeClassification.isOpenCase)
6362

63+
def testFusedConductors(self) -> None:
64+
case = 'FusedConductor'
65+
filepath = self.inputFileFromCaseName(case)
66+
self.initShapeClassification(filepath)
67+
68+
def testComplexNesting(self) -> None:
69+
case = 'ComplexNesting'
70+
filepath = self.inputFileFromCaseName(case)
71+
self.initShapeClassification(filepath)
72+
6473
def testDielectricUnshieldedPairClassification(self) -> None:
6574
case = 'DielectricUnshieldedPair'
6675
filepath = self.inputFileFromCaseName(case)
@@ -71,16 +80,14 @@ def testDielectricUnshieldedPairClassification(self) -> None:
7180
(0, 1),(0, 1),(0, 2),(0, 2),(0, 3),(0, 3),(0, 4),(0, 4),
7281
]
7382
expectedPecs = {
74-
'RightConductor': [(2,1)],
75-
'LeftConductor': [(2,2)],
83+
'LeftConductor': [(2, 2)],
84+
'RightConductor': [(2, 1)],
7685
}
7786
expectedDielectrics = {
7887
'RightDielectric': [(2,3)],
7988
'LeftDielectric': [(2,4)],
8089
}
81-
expectedShieldReference = {}
8290
self.assertListEqual(self.shapeClassification.allShapes, expectedShapes)
8391
self.assertDictEqual(self.shapeClassification.pecs, expectedPecs)
8492
self.assertDictEqual(self.shapeClassification.dielectrics, expectedDielectrics)
85-
self.assertDictEqual(self.shapeClassification.shieldReference, expectedShieldReference)
8693
self.assertTrue(self.shapeClassification.isOpenCase)

test/test_graph.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import unittest
2+
3+
from src.Graph import Graph
4+
5+
class TestGraph(unittest.TestCase):
6+
7+
def setUp(self) -> None:
8+
self.graph = Graph()
9+
10+
def tearDown(self):
11+
del self.graph
12+
13+
def test_addNode(self) -> None:
14+
self.graph.add_node('A')
15+
self.assertIn('A', self.graph.nodes)
16+
17+
self.graph.add_node('A')
18+
self.assertEqual(self.graph.nodes.count('A'), 1)
19+
20+
def test_addEdge(self) -> None:
21+
self.graph.add_edge('A', 'B')
22+
self.assertIn(('A', 'B'), self.graph.edges)
23+
24+
self.assertIn('A', self.graph.nodes)
25+
self.assertIn('B', self.graph.nodes)
26+
27+
self.graph.add_edge('A', 'B')
28+
self.assertEqual(self.graph.edges.count(('A', 'B')), 1)
29+
30+
def test_settersAndGetters(self) -> None:
31+
nodes = ['X', 'Y']
32+
edges = [('X', 'Y')]
33+
self.graph.nodes = nodes
34+
self.graph.edges = edges
35+
self.assertEqual(self.graph.nodes, nodes)
36+
self.assertEqual(self.graph.edges, edges)
37+
38+
def test_GetConnections(self) -> None:
39+
self.graph.add_edge('A', 'B')
40+
self.graph.add_edge('A', 'C')
41+
self.graph.add_node('D') # no connections
42+
connections = self.graph.get_connections()
43+
expected = {
44+
'A': ['B', 'C'],
45+
'B': [],
46+
'C': [],
47+
'D': []
48+
}
49+
self.assertEqual(connections, expected)
50+
51+
def test_str(self) -> None:
52+
self.graph.add_edge('A', 'B')
53+
s = str(self.graph)
54+
self.assertIn('A', s)
55+
self.assertIn('B', s)
56+
self.assertIn('Edges', s)
57+
58+
def testPruneToLongestPaths(self) -> None:
59+
self.graph.nodes = ['A' ,'B', 'C', 'D', 'E', 'F', 'G']
60+
self.graph.edges = [
61+
('A', 'B'), ('A', 'C'), ('A', 'D'), ('A', 'E'),
62+
('B', 'C'), ('B', 'E'),
63+
('C', 'E'),
64+
('F', 'G')
65+
]
66+
67+
expectedEdges = [
68+
('A', 'B'), ('A', 'D'),
69+
('B', 'C'),
70+
('C', 'E'),
71+
('F', 'G')
72+
]
73+
74+
self.graph.prune_to_longest_paths()
75+
self.assertListEqual(sorted(self.graph.edges), sorted(expectedEdges))
76+
77+
def testGetRoots(self) -> None:
78+
self.graph.nodes = ['A' ,'B', 'C', 'D', 'E', 'F', 'G']
79+
self.graph.edges = [
80+
('A', 'B'), ('A', 'D'),
81+
('B', 'C'),
82+
('C', 'E'),
83+
('F', 'G')
84+
]
85+
self.assertListEqual(self.graph.roots, ['A', 'F'])
86+
87+
if __name__ == '__main__':
88+
unittest.main()
22 KB
Binary file not shown.

0 commit comments

Comments
 (0)