Skip to content

Commit c4d4fa8

Browse files
committed
Make gates, sources, and sinks sets instead of lists, so it's easier to perform unions and intersections. Add more useful sets of Gates like inputs, registers, etc.
1 parent 31c9e3f commit c4d4fa8

2 files changed

Lines changed: 276 additions & 18 deletions

File tree

pyrtl/gate_graph.py

Lines changed: 221 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ class Gate:
266266
['a', 'b', 'c']
267267
"""
268268

269-
name: str
269+
name: str | None
270270
"""Name of the operation's output :class:`.WireVector`.
271271
272272
Corresponds to :attr:`.WireVector.name`.
@@ -292,7 +292,7 @@ class Gate:
292292
'ab'
293293
"""
294294

295-
bitwidth: int
295+
bitwidth: int | None
296296
"""Bitwidth of the operation's output :class:`.WireVector`.
297297
298298
Corresponds to :attr:`.WireVector.bitwidth`.
@@ -588,6 +588,8 @@ class GateGraph:
588588
:class:`GateGraph`'s constructor creates :class:`Gates<Gate>` from a
589589
:class:`.Block`.
590590
591+
See :ref:`gate_motivation` for more background.
592+
591593
Users should generally construct :class:`GateGraphs<GateGraph>`, rather than
592594
attempting to directly construct individual :class:`Gates<Gate>`. :class:`Gate`
593595
construction is complex because they are doubly-linked, and the :class:`Gate` graph
@@ -669,11 +671,156 @@ class GateGraph:
669671
you started.
670672
"""
671673

672-
gates: list[Gate]
673-
"""A :class:`list` of all :class:`Gates<Gate>` in the ``GateGraph``."""
674+
gates: set[Gate]
675+
"""A :class:`set` of all :class:`Gates<Gate>` in the ``GateGraph``.
676+
677+
.. doctest only::
678+
679+
>>> import pyrtl
680+
>>> pyrtl.reset_working_block()
681+
682+
Example::
683+
684+
>>> a = pyrtl.Input(name="a", bitwidth=1)
685+
>>> b = pyrtl.Input(name="b", bitwidth=1)
686+
>>> c = pyrtl.Input(name="c", bitwidth=1)
687+
>>> x = a & b
688+
>>> x.name = "x"
689+
>>> y = x | c
690+
>>> y.name = "y"
691+
692+
>>> gate_graph = pyrtl.GateGraph()
693+
694+
>>> sorted(gate.name for gate in gate_graph.gates)
695+
['a', 'b', 'c', 'x', 'y']
696+
"""
697+
698+
consts: set[Gate]
699+
"""A :class:`set` of :class:`.Const` :class:`Gates<Gate>` in the ``GateGraph``.
700+
701+
:class:`Gates<Gate>` that provide constant values, with :attr:`~Gate.op` ``C``.
702+
703+
.. doctest only::
704+
705+
>>> import pyrtl
706+
>>> pyrtl.reset_working_block()
707+
708+
Example::
709+
710+
>>> c = pyrtl.Const(name="c", val=0)
711+
>>> d = pyrtl.Const(name="d", val=1)
712+
>>> _ = c + d
713+
714+
>>> gate_graph = pyrtl.GateGraph()
715+
716+
>>> sorted(gate.name for gate in gate_graph.consts)
717+
['c', 'd']
718+
"""
719+
720+
inputs: set[Gate]
721+
"""A :class:`set` of :class:`.Input` :class:`Gates<Gate>` in the ``GateGraph``.
722+
723+
:class:`Gates<Gate>` that provide :class:`.Input` values, with :attr:`~Gate.op`
724+
``I``.
725+
726+
.. doctest only::
727+
728+
>>> import pyrtl
729+
>>> pyrtl.reset_working_block()
730+
731+
Example::
732+
733+
>>> a = pyrtl.Input(name="a", bitwidth=1)
734+
>>> b = pyrtl.Input(name="b", bitwidth=1)
735+
>>> _ = a & b
736+
737+
>>> gate_graph = pyrtl.GateGraph()
738+
739+
>>> sorted(gate.name for gate in gate_graph.inputs)
740+
['a', 'b']
741+
"""
742+
743+
outputs: set[Gate]
744+
"""A :class:`set` of :class:`.Output` :class:`Gates<Gate>` in the ``GateGraph``.
674745
675-
sources: list[Gate]
676-
"""A :class:`list` of all ``source`` :class:`Gates<Gate>` in the ``GateGraph``.
746+
:class:`Gates<Gate>` that set :class:`.Output` values, with :attr:`~Gate.is_output`
747+
``True``.
748+
749+
.. doctest only::
750+
751+
>>> import pyrtl
752+
>>> pyrtl.reset_working_block()
753+
754+
Example::
755+
756+
>>> x = pyrtl.Output(name="x")
757+
>>> y = pyrtl.Output(name="y")
758+
>>> x <<= 42
759+
>>> y <<= 255
760+
761+
>>> gate_graph = pyrtl.GateGraph()
762+
763+
>>> sorted(gate.name for gate in gate_graph.outputs)
764+
['x', 'y']
765+
"""
766+
767+
registers: set[Gate]
768+
"""A :class:`set` of :class:`.Register` update :class:`Gates<Gate>` in the
769+
``GateGraph``.
770+
771+
:class:`Gates<Gate>` that set each :class:`.Register`'s value for the next cycle,
772+
with :attr:`~Gate.op` ``r``.
773+
774+
.. doctest only::
775+
776+
>>> import pyrtl
777+
>>> pyrtl.reset_working_block()
778+
779+
Example::
780+
781+
>>> r = pyrtl.Register(name="r", bitwidth=1)
782+
>>> s = pyrtl.Register(name="s", bitwidth=1)
783+
>>> r.next <<= r + 1
784+
>>> s.next <<= s + 2
785+
786+
>>> gate_graph = pyrtl.GateGraph()
787+
788+
>>> sorted(gate.name for gate in gate_graph.registers)
789+
['r', 's']
790+
"""
791+
792+
memories: set[Gate]
793+
"""A :class:`set` of :class:`.MemBlock` read or write :class:`Gates<Gate>` in the
794+
``GateGraph``.
795+
796+
:class:`Gates<Gate>` that read or write :class:`MemBlocks<.MemBlock`, with
797+
:attr:`~Gate.op` ``m`` or ``@``.
798+
799+
.. doctest only::
800+
801+
>>> import pyrtl
802+
>>> pyrtl.reset_working_block()
803+
804+
Example::
805+
806+
>>> mem = pyrtl.MemBlock(name="mem", bitwidth=4, addrwidth=2)
807+
>>> addr = pyrtl.Input(name="addr", bitwidth=2)
808+
>>> mem[addr] <<= 7
809+
>>> mem_read = mem[addr]
810+
>>> mem_read.name = "mem_read"
811+
812+
>>> gate_graph = pyrtl.GateGraph()
813+
814+
>>> # MemBlock writes have no name.
815+
>>> sorted(str(gate.name) for gate in gate_graph.memories)
816+
['None', 'mem_read']
817+
818+
>>> sorted(gate.op for gate in gate_graph.memories)
819+
['@', 'm']
820+
"""
821+
822+
sources: set[Gate]
823+
"""A :class:`set` of ``source`` :class:`Gates<Gate>` in the ``GateGraph``.
677824
678825
A ``source`` :class:`Gate`'s output value is known at the beginning of each clock
679826
cycle. :class:`Consts<.Const>`, :class:`Inputs<.Input>`, and
@@ -684,10 +831,27 @@ class GateGraph:
684831
:class:`Registers<.Register>` are both ``sources`` and ``sinks``. As a
685832
``source``, it provides the :class:`.Register`'s value for the current cycle. As
686833
a ``sink``, it determines the :class:`.Register`'s value for the next cycle.
834+
835+
.. doctest only::
836+
837+
>>> import pyrtl
838+
>>> pyrtl.reset_working_block()
839+
840+
Example::
841+
842+
>>> a = pyrtl.Input(name="a", bitwidth=1)
843+
>>> c = pyrtl.Const(name="c", bitwidth=1, val=0)
844+
>>> r = pyrtl.Register(name="r", bitwidth=1)
845+
>>> r.next <<= a + c
846+
847+
>>> gate_graph = pyrtl.GateGraph()
848+
849+
>>> sorted(gate.name for gate in gate_graph.sources)
850+
['a', 'c', 'r']
687851
"""
688852

689-
sinks: list[Gate]
690-
"""A list of all ``sink`` :class:`Gates<Gate>` in the ``GateGraph``.
853+
sinks: set[Gate]
854+
"""A :class:`set` of ``sink`` :class:`Gates<Gate>` in the ``GateGraph``.
691855
692856
A ``sink`` :class:`Gate`'s output value is known only at the end of each clock
693857
cycle. :class:`Registers<.Register>`, :class:`Outputs<.Output>` and any
@@ -698,6 +862,26 @@ class GateGraph:
698862
:class:`Registers<.Register>` are both ``sources`` and ``sinks``. As a
699863
``source``, it provides the :class:`.Register`'s value for the current cycle. As
700864
a ``sink``, it determines the :class:`.Register`'s value for the next cycle.
865+
866+
.. doctest only::
867+
868+
>>> import pyrtl
869+
>>> pyrtl.reset_working_block()
870+
871+
Example::
872+
873+
>>> a = pyrtl.Input(name="a", bitwidth=1)
874+
>>> r = pyrtl.Register(name="r", bitwidth=1)
875+
>>> o = pyrtl.Output(name="o", bitwidth=1)
876+
>>> r.next <<= a + 1
877+
>>> o <<= 1
878+
>>> sum = a + r
879+
>>> sum.name = "sum"
880+
881+
>>> gate_graph = pyrtl.GateGraph()
882+
883+
>>> sorted(gate.name for gate in gate_graph.sinks)
884+
['o', 'r', 'sum']
701885
"""
702886

703887
def __init__(self, block: Block = None):
@@ -709,9 +893,14 @@ def __init__(self, block: Block = None):
709893
:param block: :class:`.Block` to construct the :class:`GateGraph` from. Defaults
710894
to the :ref:`working_block`.
711895
"""
712-
self.gates = []
713-
self.sources = []
714-
self.sinks = []
896+
self.gates = set()
897+
self.consts = set()
898+
self.inputs = set()
899+
self.outputs = set()
900+
self.registers = set()
901+
self.memories = set()
902+
self.sources = set()
903+
self.sinks = set()
715904

716905
block = working_block(block)
717906
block.sanity_check()
@@ -730,10 +919,17 @@ def __init__(self, block: Block = None):
730919
wire_vector_map: dict[WireVector, Gate] = {}
731920
for wire_vector in block.wirevector_subset((Const, Input, Register)):
732921
gate = Gate(wire_vector=wire_vector)
733-
self.gates.append(gate)
734-
self.sources.append(gate)
922+
self.gates.add(gate)
923+
self.sources.add(gate)
735924
wire_vector_map[wire_vector] = gate
736925

926+
if gate.op == "C":
927+
self.consts.add(gate)
928+
elif gate.op == "I":
929+
self.inputs.add(gate)
930+
elif gate.op == "R":
931+
self.registers.add(gate)
932+
737933
# In the second phase, we construct all remaining ``Gates`` from ``LogicNets``.
738934
# ``Block``'s iterator returns ``LogicNets`` in topological order, so we can be
739935
# sure that each ``LogicNet``'s ``args`` are all in ``wire_vector_map``.
@@ -752,10 +948,10 @@ def __init__(self, block: Block = None):
752948
gate = wire_vector_map[logic_net.dests[0]]
753949
gate.op = "r"
754950
gate.args = gate_args
755-
self.sinks.append(gate)
951+
self.sinks.add(gate)
756952
else:
757953
gate = Gate(logic_net=logic_net, args=gate_args)
758-
self.gates.append(gate)
954+
self.gates.add(gate)
759955

760956
# Add the new ``Gate`` as a ``dest`` for its ``args``.
761957
for gate_arg in gate_args:
@@ -771,11 +967,16 @@ def __init__(self, block: Block = None):
771967
dest = logic_net.dests[0]
772968
wire_vector_map[dest] = gate
773969

970+
if gate.is_output:
971+
self.outputs.add(gate)
972+
if gate.op in "m@":
973+
self.memories.add(gate)
974+
774975
for gate in self.gates:
775976
if len(gate.dests) == 0:
776-
self.sinks.append(gate)
977+
self.sinks.add(gate)
777978

778-
def get_gate(self, name: str) -> Gate:
979+
def get_gate(self, name: str) -> Gate | None:
779980
"""Return the :class:`Gate` whose :attr:`~Gate.name` is ``name``, or ``None`` if
780981
no such :class:`Gate` exists.
781982
@@ -800,5 +1001,7 @@ def __str__(self) -> str:
8001001
This returns a string representation of each :class:`Gate` in the ``GateGraph``,
8011002
one :class:`Gate` per line. The :class:`Gates<Gate>` will be sorted by name.
8021003
"""
803-
sorted_gates = sorted(self.gates, key=lambda gate: gate.name)
1004+
sorted_gates = sorted(
1005+
self.gates, key=lambda gate: gate.name if gate.name else "~~~"
1006+
)
8041007
return "\n".join([str(gate) for gate in sorted_gates])

tests/test_gate_graph.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,61 @@ def test_memblock(self):
254254
f"[memid={mem.id} mem=mem]",
255255
)
256256

257+
def test_gate_sets(self):
258+
a = pyrtl.Input(name="a", bitwidth=1)
259+
b = pyrtl.Input(name="b", bitwidth=1)
260+
261+
c = pyrtl.Const(name="c", bitwidth=1, val=0)
262+
d = pyrtl.Const(name="d", bitwidth=1, val=1)
263+
264+
x = pyrtl.Output(name="x", bitwidth=1)
265+
y = pyrtl.Output(name="y", bitwidth=1)
266+
267+
r = pyrtl.Register(name="r", bitwidth=1)
268+
s = pyrtl.Register(name="s", bitwidth=1)
269+
270+
mem = pyrtl.MemBlock(name="mem", bitwidth=1, addrwidth=1)
271+
272+
x <<= a + c
273+
274+
r.next <<= r + c
275+
s.next <<= r + d
276+
277+
mem[a] <<= pyrtl.MemBlock.EnabledWrite(data=c, enable=d)
278+
read = mem[b]
279+
read.name = "read"
280+
y <<= read + d
281+
282+
gate_graph = pyrtl.GateGraph()
283+
284+
self.assertEqual(sorted(gate.name for gate in gate_graph.inputs), ["a", "b"])
285+
self.assertEqual(sorted(gate.name for gate in gate_graph.consts), ["c", "d"])
286+
self.assertEqual(sorted(gate.name for gate in gate_graph.outputs), ["x", "y"])
287+
self.assertEqual(sorted(gate.name for gate in gate_graph.registers), ["r", "s"])
288+
memories = gate_graph.memories
289+
self.assertEqual(
290+
sorted(str(gate.name) for gate in memories),
291+
# MemBlock write has no name.
292+
["None", "read"],
293+
)
294+
# Check the MemBlock write.
295+
write_gate = None
296+
for mem_gate in memories:
297+
if not mem_gate.name:
298+
write_gate = mem_gate
299+
break
300+
self.assertTrue(write_gate is not None)
301+
self.assertEqual(write_gate.op, "@")
302+
303+
self.assertEqual(
304+
sorted(gate.name for gate in gate_graph.sources),
305+
["a", "b", "c", "d", "r", "s"],
306+
)
307+
sinks = set(gate_graph.sinks)
308+
self.assertEqual(
309+
sorted(str(gate.name) for gate in sinks), ["None", "r", "s", "x", "y"]
310+
)
311+
257312

258313
if __name__ == "__main__":
259314
unittest.main()

0 commit comments

Comments
 (0)