Skip to content

Commit 506fe96

Browse files
committed
Initial StateRegister implementation
1 parent 2a1aa8a commit 506fe96

9 files changed

Lines changed: 175 additions & 42 deletions

File tree

docs/regmem.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ Registers
99
:show-inheritance:
1010
:special-members: __init__
1111

12+
.. autoclass:: pyrtl.StateRegister
13+
:members:
14+
:show-inheritance:
15+
:special-members: __init__
16+
1217
Memories
1318
--------
1419

examples/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
PYTHON=python
1+
PYTHON=uv run
22
PY_FILES=$(wildcard *.py)
33
IPYNB_FILES=$(addprefix ../ipynb-examples/, $(PY_FILES:.py=.ipynb))
44

examples/example3-statemachine.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
dispense = pyrtl.Output(1, "dispense")
1616
refund = pyrtl.Output(1, "refund")
1717

18-
state = pyrtl.Register(3, "state")
1918

20-
21-
# First new step, let's enumerate a set of constants to serve as our states
19+
# First new step, let's enumerate a set of constants for all possible states.
2220
class State(enum.IntEnum):
2321
WAIT = 0 # Waiting for first token.
2422
TOK1 = 1 # Received first token, waiting for second token.
@@ -28,6 +26,11 @@ class State(enum.IntEnum):
2826
RFND = 5 # Issue refund.
2927

3028

29+
# Define a `StateRegister`, which is just like a `Register`, except that it calculates
30+
# the `Register`'s bitwidth from the largest possible `State`. `StateRegister`s also
31+
# display state names in traces by default.
32+
state = pyrtl.StateRegister(State, "state")
33+
3134
# Now we could build a state machine using just the `Registers` and logic discussed in
3235
# prior examples, but doing operations **conditionally** on some input is a pretty
3336
# fundamental operation in hardware design. PyRTL provides `conditional_assignment` to
@@ -114,11 +117,9 @@ class State(enum.IntEnum):
114117
sim.step_multiple(sim_inputs)
115118

116119
# Also, to make our input/output easy to reason about let's specify an order to the
117-
# traces with `trace_list`. We also use `enum_name` to display the state names (`WAIT`,
118-
# `TOK1`, ...) rather than their numbers (0, 1, ...).
120+
# traces with `trace_list`.
119121
sim.tracer.render_trace(
120-
trace_list=["token_in", "req_refund", "state", "dispense", "refund"],
121-
repr_per_name={"state": pyrtl.enum_name(State)},
122+
trace_list=["token_in", "req_refund", "state", "dispense", "refund"]
122123
)
123124

124125
# Finally, suppose you want to simulate your design and verify its output matches your

ipynb-examples/example3-statemachine.ipynb

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,14 @@
4848
"req_refund = pyrtl.Input(1, \"req_refund\")\n",
4949
"\n",
5050
"dispense = pyrtl.Output(1, \"dispense\")\n",
51-
"refund = pyrtl.Output(1, \"refund\")\n",
52-
"\n",
53-
"state = pyrtl.Register(3, \"state\")\n"
51+
"refund = pyrtl.Output(1, \"refund\")\n"
5452
]
5553
},
5654
{
5755
"cell_type": "markdown",
5856
"metadata": {},
5957
"source": [
60-
" First new step, let's enumerate a set of constants to serve as our states\n"
58+
" First new step, let's enumerate a set of constants for all possible states.\n"
6159
]
6260
},
6361
{
@@ -77,6 +75,26 @@
7775
" RFND = 5 # Issue refund.\n"
7876
]
7977
},
78+
{
79+
"cell_type": "markdown",
80+
"metadata": {},
81+
"source": [
82+
" Define a `StateRegister`, which is just like a [Register](https://pyrtl.readthedocs.io/en/latest/basic.html#pyrtl.Register), except that it calculates\n",
83+
" the [Register](https://pyrtl.readthedocs.io/en/latest/basic.html#pyrtl.Register)'s bitwidth from the largest possible `State`. `StateRegister`s also\n",
84+
" display state names in traces by default.\n"
85+
]
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": null,
90+
"metadata": {
91+
"collapsed": true
92+
},
93+
"outputs": [],
94+
"source": [
95+
"state = pyrtl.StateRegister(State, \"state\")\n"
96+
]
97+
},
8098
{
8199
"cell_type": "markdown",
82100
"metadata": {},
@@ -216,8 +234,7 @@
216234
"metadata": {},
217235
"source": [
218236
" Also, to make our input/output easy to reason about let's specify an order to the\n",
219-
" traces with `trace_list`. We also use `enum_name` to display the state names (`WAIT`,\n",
220-
" `TOK1`, ...) rather than their numbers (0, 1, ...).\n"
237+
" traces with `trace_list`.\n"
221238
]
222239
},
223240
{
@@ -229,8 +246,7 @@
229246
"outputs": [],
230247
"source": [
231248
"sim.tracer.render_trace(\n",
232-
" trace_list=[\"token_in\", \"req_refund\", \"state\", \"dispense\", \"refund\"],\n",
233-
" repr_per_name={\"state\": pyrtl.enum_name(State)},\n",
249+
" trace_list=[\"token_in\", \"req_refund\", \"state\", \"dispense\", \"refund\"]\n",
234250
")\n"
235251
]
236252
},

pyrtl/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818

1919
# convenience classes for building hardware
20-
from .wire import WireVector, Input, Output, Const, Register
20+
from .wire import WireVector, Input, Output, Const, Register, StateRegister
2121

2222
from .gate_graph import GateGraph, Gate
2323

@@ -161,6 +161,7 @@
161161
"Output",
162162
"Const",
163163
"Register",
164+
"StateRegister",
164165
# gate_graph
165166
"GateGraph",
166167
"Gate",

pyrtl/simulation.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pyrtl.importexport import _VerilogSanitizer
2121
from pyrtl.memory import MemBlock, RomBlock
2222
from pyrtl.pyrtlexceptions import PyrtlError, PyrtlInternalError
23-
from pyrtl.wire import Const, Input, Output, Register, WireVector
23+
from pyrtl.wire import Const, Input, Output, Register, StateRegister, WireVector
2424

2525
# ----------------------------------------------------------------
2626
# __ ___ __
@@ -1106,6 +1106,8 @@ def invoke_f(f, value):
11061106

11071107
if f is not None:
11081108
return invoke_f(f, value)
1109+
if isinstance(wire, StateRegister):
1110+
return invoke_f(enum_name(wire.States), value)
11091111
return invoke_f(repr_func, value)
11101112

11111113
def render_val(
@@ -1141,20 +1143,18 @@ def render_val(
11411143
_prev_line* fields in RendererConstants.
11421144
:param is_last: If True, current_val is in the last cycle.
11431145
"""
1144-
if len(w) > 1 or w.name in repr_per_name:
1146+
if len(w) > 1 or w.name in repr_per_name or isinstance(w, StateRegister):
11451147
# Render values in boxes for multi-bit wires ("bus"), or single-bit wires
11461148
# with a specific representation.
11471149
#
11481150
# We display multi-wire zero values as a centered horizontal line when a
11491151
# specific `repr_per_name` is not requested for this trace, and a standard
11501152
# numeric format is requested.
1151-
flat_zero = w.name not in repr_per_name and (
1152-
repr_func is hex
1153-
or repr_func is oct
1154-
or repr_func is int
1155-
or repr_func is str
1156-
or repr_func is bin
1157-
or repr_func is val_to_signed_integer
1153+
numeric_formats = [hex, oct, int, str, bin, val_to_signed_integer]
1154+
flat_zero = (
1155+
w.name not in repr_per_name
1156+
and not isinstance(w, StateRegister)
1157+
and repr_func in numeric_formats
11581158
)
11591159
if prev_line:
11601160
# Bus wires are currently never rendered across multiple lines.
@@ -1980,6 +1980,9 @@ def enum_name(EnumClass: type) -> Callable[[int], str]:
19801980
19811981
state FOO│BAR
19821982
1983+
When using ``enum_name`` with a :class:`.Register`, consider using
1984+
:class:`.StateRegister` instead.
1985+
19831986
:param EnumClass: ``enum`` to convert. This is the enum class, like ``State``, not
19841987
an enum value, like ``State.FOO`` or ``1``.
19851988

pyrtl/wire.py

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import enum
1920
import numbers
2021
import re
2122
import traceback
@@ -1661,15 +1662,15 @@ def __ior__(self, _):
16611662

16621663

16631664
class Register(WireVector):
1664-
"""A WireVector with an embedded register state element.
1665+
"""A :class:`WireVector` with an embedded register state element.
16651666
1666-
Registers only update their outputs on the rising edges of an implicit clock signal.
1667-
The "value" in the current cycle can be accessed by referencing the Register itself.
1668-
To set the value for the next cycle (after the next rising clock edge), set the
1669-
:attr:`Register.next` property with the ``<<=`` (:meth:`~WireVector.__ilshift__`)
1670-
operator.
1667+
``Registers`` only update their outputs on the rising edges of an implicit clock
1668+
signal. The "value" in the current cycle can be accessed by referencing the
1669+
``Register`` itself. To set the value for the next cycle (after the next rising
1670+
clock edge), set the :attr:`Register.next` property with the ``<<=``
1671+
(:meth:`~WireVector.__ilshift__`) operator.
16711672
1672-
Registers reset to zero by default, and reside in the same clock domain.
1673+
``Registers`` reset to zero by default, and reside in the same clock domain.
16731674
16741675
.. doctest only::
16751676
@@ -1689,6 +1690,10 @@ class Register(WireVector):
16891690
This builds a zero-initialized 2-bit counter. The second line sets the counter's
16901691
value in the next cycle (``counter.next``) to the counter's value in the current
16911692
cycle (``counter``), plus one.
1693+
1694+
.. note::
1695+
1696+
Consider using :class:`StateRegister` for state machine ``Registers``.
16921697
"""
16931698

16941699
_code = "R"
@@ -1811,20 +1816,20 @@ def __init__(
18111816
reset_value: int | None = None,
18121817
block: Block = None,
18131818
):
1814-
"""Construct a register.
1819+
"""Construct a ``Register``.
18151820
18161821
It is an error if the ``reset_value`` cannot fit into the specified ``bitwidth``
18171822
for this register.
18181823
1819-
:param bitwidth: Number of bits to represent this register.
1820-
:param name: The name of the register's current value (``reg``, not
1824+
:param bitwidth: Number of bits to represent this ``Register``.
1825+
:param name: The name of the ``Register``'s current value (``reg``, not
18211826
``reg.next``). Must be unique. If none is provided, one will be
18221827
autogenerated.
1823-
:param reset_value: Value to initialize this register to during simulation and
1824-
in any code (e.g. Verilog) that is exported. Defaults to 0. Can be
1828+
:param reset_value: Value to initialize this ``Register`` to during simulation
1829+
and in any code (e.g. Verilog) that is exported. Defaults to 0. Can be
18251830
overridden at simulation time.
1826-
:param block: The block under which the wire should be placed. Defaults to the
1827-
:ref:`working_block`.
1831+
:param block: The :class:`Block` under which the wire should be placed. Defaults
1832+
to the :ref:`working_block`.
18281833
"""
18291834
from pyrtl.helperfuncs import infer_val_and_bitwidth
18301835

@@ -1871,6 +1876,74 @@ def _build(self, next):
18711876
working_block().add_net(net)
18721877

18731878

1879+
class StateRegister(Register):
1880+
"""A :class:`Register` containing an :class:`~enum.IntEnum` state.
1881+
1882+
``StateRegister`` functions identically to :class:`Register`, except that the
1883+
:class:`Register`'s bitwidth is calculated from an :class:`~enum.IntEnum`'s largest
1884+
value, and :meth:`~.SimulationTrace.render_trace` displays state names by
1885+
default.
1886+
1887+
.. doctest only::
1888+
1889+
>>> import pyrtl
1890+
>>> pyrtl.reset_working_block()
1891+
1892+
Example::
1893+
1894+
>>> import enum
1895+
>>> class MyStates(enum.IntEnum):
1896+
... ZERO = 0
1897+
... ONE = 1
1898+
... TWO = 2
1899+
... THREE = 3
1900+
1901+
>>> state = pyrtl.StateRegister(name="state", States=MyStates)
1902+
>>> state.bitwidth
1903+
2
1904+
1905+
>>> state.next <<= state + 1
1906+
1907+
>>> sim = pyrtl.Simulation()
1908+
>>> sim.step_multiple(nsteps=4)
1909+
>>> sim.tracer.render_trace()
1910+
1911+
Which prints::
1912+
1913+
│0 │1 │2 │3
1914+
state ZERO │ONE │TWO │THREE
1915+
"""
1916+
1917+
def __init__(
1918+
self,
1919+
States: enum.IntEnum,
1920+
name: str = "",
1921+
reset_value: int | None = None,
1922+
block: Block = None,
1923+
):
1924+
"""Constructs a :class:`Register` containing an :class:`~enum.IntEnum` value.
1925+
1926+
:param States: An :class:`~enum.IntEnum` containing all possible states for the
1927+
``StateRegister``. The largest value in the :class:`~enum.IntEnum`
1928+
determines the :class:`Register`'s :attr:`~WireVector.bitwidth`.
1929+
:param name: The name of the ``StateRegister``'s current value (``state``, not
1930+
``state.next``). Must be unique. If none is provided, one will be
1931+
autogenerated.
1932+
:param reset_value: Value to initialize this ``StateRegister`` to during
1933+
simulation and in any code (e.g. Verilog) that is exported. Defaults to 0.
1934+
Can be overridden at simulation time.
1935+
:param block: The :class:`Block` under which the wire should be placed. Defaults
1936+
to the :ref:`working_block`.
1937+
"""
1938+
from pyrtl.helperfuncs import infer_val_and_bitwidth
1939+
1940+
self.States = States
1941+
bitwidth = infer_val_and_bitwidth(max(States)).bitwidth
1942+
super().__init__(
1943+
bitwidth=bitwidth, name=name, reset_value=reset_value, block=block
1944+
)
1945+
1946+
18741947
class WrappedWireVector:
18751948
"""Wraps a WireVector. Forwards all method calls and attribute accesses.
18761949

tests/test_simulation.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,11 @@ class State(enum.IntEnum):
273273
renderer=self.renderer,
274274
repr_per_name={state.name: pyrtl.enum_name(State)},
275275
)
276-
expected = " |0 |1 \n \nstate FOO|BAR\n"
276+
expected = (
277+
" |0 |1 \n"
278+
" \n"
279+
"state FOO|BAR\n"
280+
) # fmt: skip
277281
self.assertEqual(buff.getvalue(), expected)
278282

279283
def test_val_to_signed_integer(self):
@@ -286,7 +290,11 @@ def test_val_to_signed_integer(self):
286290
sim.tracer.render_trace(
287291
file=buff, renderer=self.renderer, repr_func=pyrtl.val_to_signed_integer
288292
)
289-
expected = " |0 |1 |2 |3 \n \ncounter --|1 |-2|-1\n"
293+
expected = (
294+
" |0 |1 |2 |3 \n"
295+
" \n"
296+
"counter --|1 |-2|-1\n"
297+
) # fmt: skip
290298
self.assertEqual(buff.getvalue(), expected)
291299

292300
def test_custom_repr_per_wire(self):

0 commit comments

Comments
 (0)