-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpetri_net.py
More file actions
338 lines (258 loc) · 9.84 KB
/
Copy pathpetri_net.py
File metadata and controls
338 lines (258 loc) · 9.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
from drampyml.common.commands import Command
import rustworkx as rx
from sympy import Expr, sympify
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Protocol
# Shared mixin for every IP-specific PlaceType enum.
class BasePlaceType:
    """Mixin giving IP-specific PlaceType enums a bare-name text form.

    Combine it with ``Enum`` for each IP (``PlaceType``, ``UARTPlaceType``,
    ...). Both ``str()`` and ``repr()`` of a member then yield only the
    member name (e.g. ``ACTIVE``) instead of ``ClassName.ACTIVE``.
    """

    def __str__(self):
        # ``name`` comes from the Enum class this mixin is combined with.
        member_name = self.name  # type: ignore
        return member_name

    def __repr__(self):
        # Route through str() so both representations always agree.
        text = str(self)
        return text
# DDR/Memory-specific PlaceTypes (for DDR3, DDR4, LPDDR4, etc.)
class PlaceType(BasePlaceType, Enum):
    """PlaceType enum for DDR/memory standards.

    Members print as their bare name (behaviour inherited from
    ``BasePlaceType``). Values are assigned by ``auto()`` and therefore
    depend on declaration order; append new members rather than inserting
    them if the numeric values are ever relied upon.

    NOTE(review): the abbreviations (PDN, DPD, SREF, SRS, SL, CSP, NAW, ...)
    are not expanded anywhere in this file; check the memory standard's
    state diagram before relying on a particular reading.
    """

    ACTIVE = auto()
    PDN = auto()
    DPD = auto()
    PWR_ON = auto()
    SREF = auto()
    SREF_FLAG = auto()
    SRS = auto()
    SL = auto()
    # CMD_BUS, NAW_Pool, REF_Flag and REF_Pool places are removed by
    # PetriNet.pruned_graph() before the net is rendered.
    # NOTE(review): NAW_Pool / REF_Flag / REF_Pool do not follow the
    # UPPER_CASE naming of the other members; renaming would break callers.
    CMD_BUS = auto()
    NAW_Pool = auto()
    REF_Flag = auto()
    REF_Pool = auto()
    CSP = auto()
# UART-specific PlaceTypes
class UARTPlaceType(BasePlaceType, Enum):
    """PlaceType enum for UART standard.

    Members print as their bare name (behaviour inherited from
    ``BasePlaceType``); values are assigned by ``auto()`` in declaration
    order.
    """

    IDLE = auto()
    PARSING = auto()
    RD_PENDING = auto()
    WR_PENDING = auto()
    TX_RESPONSE = auto()
# AES-specific PlaceTypes
class AESPlaceType(BasePlaceType, Enum):
    """PlaceType enum for AES standard.

    Members print as their bare name (behaviour inherited from
    ``BasePlaceType``); values are assigned by ``auto()`` in declaration
    order.
    """

    WAIT_KEY = auto()
    KEY_EXPANDING = auto()
    WAIT_DATA = auto()
    INITIAL_ROUND = auto()
    DO_ROUND = auto()
    FINAL_ROUND = auto()
class Coordinate(Protocol):
    """Structural type for the location of a place or transition.

    The protocol declares no members, so any object satisfies it for a
    static type checker. Coordinates are used as components of the
    dictionary keys in ``PetriNet.places`` and ``PetriNet.transitions``,
    so concrete coordinate types must be hashable.
    """
    ...
@dataclass(frozen=True)
class Token:
    """Immutable token carrying the simulation time at which it was produced.

    Firing stamps new tokens with ``PetriNet.current_time``. The default
    timestamp lies far in the past (-2**32), so for any non-negative
    ``current_time`` a default token has an age of at least 2**32 time
    units when timed arcs check ``current_time - timestamp``.
    """

    timestamp: int = -(1 << 32)
@dataclass(unsafe_hash=True)
class Place:
    """A place of the Petri net holding a list of time-stamped tokens.

    Attributes:
        place_type: IP-specific kind of the place (a ``BasePlaceType`` enum
            member).
        coordinate: Location of the place; together with ``place_type`` it
            forms the lookup key in ``PetriNet.places``.
        tokens: Tokens currently held. ``PetriNet.fire_transition`` mutates
            this list in place.

    ``unsafe_hash=True`` makes dataclasses generate ``__hash__``. The token
    list is excluded from it (``hash=False``) for two reasons. A list is
    unhashable, so including it made ``hash(place)`` always raise
    ``TypeError``. The list also changes as the net runs. Equality still
    compares the tokens, and equal places still hash equally.
    """

    place_type: BasePlaceType
    coordinate: Coordinate
    tokens: list[Token] = field(default_factory=list, hash=False)
@dataclass
class Transition:
    """A transition of the Petri net."""

    # Command this transition stands for; together with ``coordinate`` it
    # forms the lookup key in ``PetriNet.transitions``.
    command: Command
    # Location of the transition.
    coordinate: Coordinate
    # Cached enabledness flag, refreshed by ``PetriNet.evaluate()``;
    # ``node_viz`` colours the node green/red from it.
    active: bool = False
@dataclass(frozen=True)
class Arc:
    """Ordinary weighted arc.

    Place -> transition: the transition needs at least ``weight`` tokens in
    the source place and removes ``weight`` tokens from the front of its
    token list when it fires.
    Transition -> place: firing appends ``weight`` new tokens stamped with
    the current time.
    """

    weight: int = 1
@dataclass(frozen=True)
class InhibitorArc:
    """Place -> transition arc that disables the transition.

    The transition is not enabled while the source place holds ``weight``
    or more tokens. Firing neither consumes nor produces tokens over this
    arc.
    """

    weight: int = 1
@dataclass(frozen=True)
class TimedArc:
    """Weighted arc with a minimum token age.

    As a place -> transition arc, only tokens whose age
    ``current_time - token.timestamp`` is at least
    ``lower_bound.subs(memspec)`` count toward ``weight``. All tokens count
    while the net ignores timing constraints. On firing, ``weight`` tokens
    are removed from the front of the source place's token list.
    As a transition -> place arc it produces ``weight`` tokens like ``Arc``.
    """

    weight: int = 1
    # sympy expression; defaults to the constant 0 (no minimum age).
    lower_bound: Expr = sympify(0)
@dataclass
class CustomArc:
    """Token-free timing constraint carried by an edge.

    When the transition at the arc's source fires, the net stamps
    ``timestamp`` with the firing time. While the arc is ``active`` it
    blocks the transition at the arc's target, unless the net ignores
    timing constraints.
    """

    # sympy expression: length of the blocking window after ``timestamp``.
    time_constraint: Expr
    # Time of the last stamping. The default is far in the past, so an arc
    # that was never stamped is inactive for windows up to about 2**32 time
    # units (assuming a non-negative current time).
    timestamp: int = -(1 << 32)

    def active(self, time, memspec):
        """Return whether the constraint is still pending at ``time``.

        The window ends at ``timestamp + time_constraint.subs(memspec)``.
        NOTE(review): assumes ``memspec`` binds every symbol in
        ``time_constraint``. Otherwise the comparison stays symbolic.
        """
        window_end = self.timestamp + self.time_constraint.subs(memspec)
        return time < window_end
@dataclass(frozen=True)
class ResetArc:
    """Place -> transition arc that empties the source place on firing.

    It does not influence whether the transition is enabled. When the
    transition fires, the source place's token list is cleared after the
    ordinary ``Arc`` / ``TimedArc`` consumption.
    """

    pass
@dataclass
class PetriNet:
    """Timed Petri net simulated on top of a ``rustworkx.PyDiGraph``.

    Nodes of ``graph`` carry ``Place`` / ``Transition`` payloads. Edges
    carry arc payloads (``Arc``, ``InhibitorArc``, ``TimedArc``,
    ``CustomArc``, ``ResetArc``). Symbolic timing expressions on arcs are
    made concrete with ``expr.subs(memspec)`` and compared against the
    integer ``current_time``.

    Attributes:
        graph: Net structure. Token lists and ``CustomArc`` timestamps in its
            payloads are mutated in place when transitions fire.
        memspec: Substitution map applied to the sympy timing expressions.
        places: ``(coordinate, place_type) -> node index``; built in
            ``__post_init__``.
        transitions: ``(coordinate, command) -> [node index, ...]``; built in
            ``__post_init__``.
        current_time: Time used for timing checks and for stamping produced
            tokens. This class never modifies it.
        ignore_timing_constraints: When true, ``TimedArc`` lower bounds and
            ``CustomArc`` constraints are ignored by ``can_fire_transition``.
    """

    graph: rx.PyDiGraph
    memspec: dict[Expr, int] = field(default_factory=dict)
    places: dict[tuple[Coordinate, BasePlaceType], int] = field(init=False)
    transitions: dict[tuple[Coordinate, Command], list[int]] = field(init=False)
    current_time: int = 0
    # Not annotated, so this is a plain class attribute rather than a
    # dataclass field. It is not an __init__ parameter; set it on the
    # instance (or the class) after construction.
    ignore_timing_constraints = False

    def __post_init__(self):
        self.evaluate()
        self._explore_transitons()
        self._explore_places()

    def _transition_indices(self):
        """Return the node indices of all ``Transition`` nodes."""
        return self.graph.filter_nodes(lambda node: isinstance(node, Transition))

    def _explore_transitons(self):
        """Index transition nodes by ``(coordinate, command)``.

        Several nodes may share a key, so each key maps to a list of indices.
        """
        self.transitions = {}
        for index in self._transition_indices():
            transition = self.graph[index]
            self.transitions.setdefault(
                (transition.coordinate, transition.command), []
            ).append(index)

    def _explore_places(self):
        """Index place nodes by ``(coordinate, place_type)``.

        If two places share a key, the one visited last wins.
        """
        self.places = {}
        for index in self.graph.filter_nodes(lambda node: isinstance(node, Place)):
            place = self.graph[index]
            self.places[(place.coordinate, place.place_type)] = index

    def evaluate(self):
        """Refresh the cached ``active`` flag of every transition."""
        for index in self._transition_indices():
            self.graph[index].active = self.can_fire_transition(index)

    def can_fire_transition(self, transition_index: int):
        """Return whether the transition is enabled at ``current_time``.

        Each incoming arc must allow firing:

        * ``CustomArc``: its constraint must not be pending (skipped when
          timing constraints are ignored).
        * ``TimedArc``: at least ``weight`` source tokens must have an age
          ``current_time - timestamp`` of at least the substituted
          ``lower_bound``. Every token counts when timing is ignored.
        * ``Arc``: at least ``weight`` source tokens are required.
        * ``InhibitorArc``: fewer than ``weight`` source tokens are required.

        ``ResetArc`` does not affect enabledness.
        NOTE(review): assumes ``memspec`` binds every symbol in the timing
        expressions; otherwise the comparisons stay symbolic.
        """
        for src_index, _, edge_data in self.graph.in_edges(transition_index):
            if isinstance(edge_data, CustomArc):
                # A CustomArc carries no tokens; it only blocks while pending.
                if not self.ignore_timing_constraints and edge_data.active(
                    self.current_time, self.memspec
                ):
                    return False
                continue

            src_tokens: list[Token] = self.graph[src_index].tokens
            if isinstance(edge_data, TimedArc):
                if self.ignore_timing_constraints:
                    valid_tokens = len(src_tokens)
                else:
                    lower_bound = edge_data.lower_bound.subs(self.memspec)
                    valid_tokens = sum(
                        1
                        for token in src_tokens
                        if lower_bound <= (self.current_time - token.timestamp)
                    )
                if valid_tokens < edge_data.weight:
                    return False
            elif isinstance(edge_data, Arc) and len(src_tokens) < edge_data.weight:
                return False
            elif (
                isinstance(edge_data, InhibitorArc)
                and len(src_tokens) >= edge_data.weight
            ):
                return False
        return True

    def fire_transition(self, transition_index) -> bool:
        """Fire a transition if it is enabled at ``current_time``.

        Enabledness is recomputed here and written back to the transition's
        ``active`` flag, instead of trusting the cached flag. That flag is
        only refreshed by ``evaluate()``, so after ``current_time`` is
        advanced without an ``evaluate()`` call it could disagree with
        ``who_can_fire()``.

        Firing proceeds in four steps:

        1. Remove ``weight`` tokens over every incoming ``Arc`` / ``TimedArc``.
        2. Empty the places connected by an incoming ``ResetArc``.
        3. Append ``weight`` tokens stamped ``current_time`` over every
           outgoing ``Arc`` / ``TimedArc``, and stamp outgoing ``CustomArc``
           edges with ``current_time``.
        4. Refresh all cached ``active`` flags.

        Returns:
            ``True`` if the transition fired, ``False`` if it was not enabled.
        """
        enabled = self.can_fire_transition(transition_index)
        self.graph[transition_index].active = enabled
        if not enabled:
            return False

        in_edges = self.graph.in_edges(transition_index)
        out_edges = self.graph.out_edges(transition_index)

        for src_index, _, edge_data in in_edges:
            if isinstance(edge_data, (Arc, TimedArc)):
                # NOTE(review): removes the front tokens. For a TimedArc these
                # are the oldest (valid) ones only while token lists stay
                # sorted by timestamp, which holds for tokens produced below
                # as long as current_time never decreases.
                del self.graph[src_index].tokens[: edge_data.weight]

        # Reset arcs are applied after the ordinary consumption above.
        for src_index, _, edge_data in in_edges:
            if isinstance(edge_data, ResetArc):
                self.graph[src_index].tokens.clear()

        for _, dst_index, edge_data in out_edges:
            if isinstance(edge_data, (Arc, TimedArc)):
                self.graph[dst_index].tokens.extend(
                    [Token(self.current_time) for _ in range(edge_data.weight)]
                )
            elif isinstance(edge_data, CustomArc):
                # Start this constraint's blocking window at the firing time.
                edge_data.timestamp = self.current_time

        self.evaluate()
        return True

    def who_can_fire(self) -> set[int]:
        """Return indices of transitions enabled at ``current_time`` (computed live)."""
        return {
            index
            for index in self._transition_indices()
            if self.can_fire_transition(index)
        }

    def who_cant_fire(self) -> set[int]:
        """Return indices of transitions not enabled at ``current_time`` (computed live)."""
        return {
            index
            for index in self._transition_indices()
            if not self.can_fire_transition(index)
        }

    def pruned_graph(self) -> rx.PyDiGraph:
        """Return a copy of ``graph`` reduced for visualisation.

        The copy drops two kinds of element:

        * ``CustomArc`` edges whose constraint is not pending at
          ``current_time``.
        * Places of type ``CMD_BUS``, ``NAW_Pool``, ``REF_Pool`` and
          ``REF_Flag``, together with their edges.

        ``self.graph`` itself keeps all nodes and edges.
        """
        hidden_place_types = (
            PlaceType.CMD_BUS,
            PlaceType.NAW_Pool,
            PlaceType.REF_Pool,
            PlaceType.REF_Flag,
        )
        temp_graph = self.graph.copy()

        # Remove inactive CustomArcs.
        for edge_idx in temp_graph.edge_indices():
            edge_data = temp_graph.get_edge_data_by_index(edge_idx)
            if isinstance(edge_data, CustomArc) and not edge_data.active(
                self.current_time, self.memspec
            ):
                temp_graph.remove_edge_from_index(edge_idx)

        # Remove command-bus, NAW and refresh pool/flag places in one pass.
        for place_idx in temp_graph.filter_nodes(
            lambda node: isinstance(node, Place)
            and node.place_type in hidden_place_types
        ):
            temp_graph.remove_node(place_idx)

        return temp_graph

    def write_dot(self, filename: str):
        """Write the pruned graph to ``filename`` in DOT format.

        Node and edge attributes come from ``node_viz`` and ``edge_viz``.
        Returns whatever ``PyDiGraph.to_dot`` returns.
        """
        return self.pruned_graph().to_dot(
            filename=filename, node_attr=node_viz, edge_attr=edge_viz
        )

    def write_img(self, filename: str, **kwargs):
        """Render the pruned graph to an image file with the ``dot`` layout.

        Extra keyword arguments are forwarded to
        ``rustworkx.visualization.graphviz_draw``.
        """
        from rustworkx.visualization import graphviz_draw

        graphviz_draw(
            graph=self.pruned_graph(),
            filename=filename,
            node_attr_fn=node_viz,
            edge_attr_fn=edge_viz,
            method="dot",
            **kwargs,
        )
def node_viz(node_data: Place | Transition):
    """Map a net node payload to Graphviz node attributes.

    Places and transitions are drawn differently:

    * ``Place``: a circle whose label shows one bullet per token, with the
      place type as an external label.
    * ``Transition``: a filled rectangle labelled with its command, green
      when its cached ``active`` flag is set and red otherwise.

    Returns:
        A ``dict[str, str]`` of attributes (empty for any other payload).
    """
    attrs: dict[str, str] = {}
    if isinstance(node_data, Place):
        attrs.update(
            {
                "label": "•" * len(node_data.tokens),
                "xlabel": f"{node_data.place_type}",
                "xlabelloc": "b",
                "shape": "circle",
            }
        )
    if isinstance(node_data, Transition):
        attrs.update(
            {
                "label": f"{node_data.command}",
                "shape": "rectangle",
                "style": "filled",
                "fillcolor": "green" if node_data.active else "red",
            }
        )
    return attrs
def edge_viz(edge_data: Arc | InhibitorArc | ResetArc | TimedArc | CustomArc):
    """Map an arc payload to Graphviz edge attributes.

    Styling by arc kind:

    * ``ResetArc``: red edge with a ``normalnormal`` arrowhead.
    * ``InhibitorArc``: ``dot`` arrowhead plus a weight label.
    * ``Arc`` (exact type only): weight label, left empty when the weight
      is 1.
    * ``TimedArc``: blue edge labelled ``[lower_bound,∞[``, with the weight
      on a second line when it exceeds 1.
    * ``CustomArc``: blue edge with a ``diamond`` arrowhead, labelled with
      its time constraint.

    Returns:
        A ``dict[str, str]`` of attributes for rustworkx's DOT/graphviz
        callbacks.
    """
    attrs: dict[str, str] = {}

    if isinstance(edge_data, ResetArc):
        attrs.update(color="red", arrowhead="normalnormal")

    if isinstance(edge_data, InhibitorArc):
        attrs["arrowhead"] = "dot"

    # Exact-type test for Arc: only plain arcs and inhibitor arcs get a bare
    # weight label.
    if type(edge_data) is Arc or isinstance(edge_data, InhibitorArc):
        weight = edge_data.weight
        attrs["label"] = f"{weight}" if weight > 1 else ""

    if isinstance(edge_data, TimedArc):
        attrs.update(color="blue", fontcolor="blue")
        timed_label = f"[{edge_data.lower_bound},∞["
        if edge_data.weight > 1:
            timed_label = f"{timed_label}\n{edge_data.weight}"
        attrs["label"] = timed_label

    if isinstance(edge_data, CustomArc):
        attrs.update(color="blue", fontcolor="blue", arrowhead="diamond")
        attrs["label"] = f"{edge_data.time_constraint}"

    return attrs