From 2ff6283384d888f39c8d04ce1b83f930406a72e6 Mon Sep 17 00:00:00 2001 From: Jeremy Lau <30300826+fdxmw@users.noreply.github.com> Date: Sun, 18 Jan 2026 16:45:11 -0800 Subject: [PATCH 1/2] Initial `StateRegister` implementation. A `StateRegister` is just like a `Register`, except its bitwidth is defined by an `IntEnum`, and `render_trace` displays the `IntEnum`'s name instead of its `int` value. This also changes `enum_name` so it converts unknown `int` values to string with `hex`, instead of throwing a `ValueError`. --- docs/basic.rst | 1 + docs/index.rst | 1 + docs/regmem.rst | 5 + examples/Makefile | 2 +- examples/example3-statemachine.py | 15 +-- ipynb-examples/example3-statemachine.ipynb | 32 +++++-- pyrtl/__init__.py | 3 +- pyrtl/simulation.py | 56 ++++++----- pyrtl/wire.py | 102 ++++++++++++++++++--- tests/test_simulation.py | 39 +++++++- tests/test_wire.py | 26 ++++++ 11 files changed, 226 insertions(+), 56 deletions(-) diff --git a/docs/basic.rst b/docs/basic.rst index 04dd40d2..0d42a4e6 100644 --- a/docs/basic.rst +++ b/docs/basic.rst @@ -23,6 +23,7 @@ hard-wired values and :class:`.Register` is how sequential elements are created pyrtl.Output pyrtl.Const pyrtl.Register + pyrtl.StateRegister :parts: 1 WireVector diff --git a/docs/index.rst b/docs/index.rst index 30b3c8f4..542ca5e7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -92,6 +92,7 @@ this is all a user needs to create a functional hardware design. pyrtl.Output pyrtl.Const pyrtl.Register + pyrtl.StateRegister :parts: 1 After specifying a hardware design, there are then options to simulate your diff --git a/docs/regmem.rst b/docs/regmem.rst index 1e3195a7..b85754cd 100644 --- a/docs/regmem.rst +++ b/docs/regmem.rst @@ -9,6 +9,11 @@ Registers :show-inheritance: :special-members: __init__ +.. autoclass:: pyrtl.StateRegister + :members: + :show-inheritance: + :special-members: __init__ + Memories -------- diff --git a/examples/Makefile b/examples/Makefile index 3f0dd36f..71b0c703 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -1,4 +1,4 @@ -PYTHON=uv run python3 +PYTHON=uv run PY_FILES=$(wildcard *.py) IPYNB_FILES=$(addprefix ../ipynb-examples/, $(PY_FILES:.py=.ipynb)) diff --git a/examples/example3-statemachine.py b/examples/example3-statemachine.py index 78a78215..11cd3816 100644 --- a/examples/example3-statemachine.py +++ b/examples/example3-statemachine.py @@ -15,10 +15,8 @@ dispense = pyrtl.Output(1, "dispense") refund = pyrtl.Output(1, "refund") -state = pyrtl.Register(3, "state") - -# First new step, let's enumerate a set of constants to serve as our states +# First new step, let's enumerate a set of constants for all possible states. class State(enum.IntEnum): WAIT = 0 # Waiting for first token. TOK1 = 1 # Received first token, waiting for second token. @@ -28,6 +26,11 @@ class State(enum.IntEnum): RFND = 5 # Issue refund. +# Define a `StateRegister`, which is just like a `Register`, except that it calculates +# the `Register`'s bitwidth from the largest possible `State`. `StateRegister`s also +# display state names in traces by default. +state = pyrtl.StateRegister(State, "state") + # Now we could build a state machine using just the `Registers` and logic discussed in # prior examples, but doing operations **conditionally** on some input is a pretty # fundamental operation in hardware design. PyRTL provides `conditional_assignment` to @@ -114,11 +117,9 @@ class State(enum.IntEnum): sim.step_multiple(sim_inputs) # Also, to make our input/output easy to reason about let's specify an order to the -# traces with `trace_list`. We also use `enum_name` to display the state names (`WAIT`, -# `TOK1`, ...) rather than their numbers (0, 1, ...). +# traces with `trace_list`. sim.tracer.render_trace( - trace_list=["token_in", "req_refund", "state", "dispense", "refund"], - repr_per_name={"state": pyrtl.enum_name(State)}, + trace_list=["token_in", "req_refund", "state", "dispense", "refund"] ) # Finally, suppose you want to simulate your design and verify its output matches your diff --git a/ipynb-examples/example3-statemachine.ipynb b/ipynb-examples/example3-statemachine.ipynb index b2cd00c2..52f18473 100644 --- a/ipynb-examples/example3-statemachine.ipynb +++ b/ipynb-examples/example3-statemachine.ipynb @@ -48,16 +48,14 @@ "req_refund = pyrtl.Input(1, \"req_refund\")\n", "\n", "dispense = pyrtl.Output(1, \"dispense\")\n", - "refund = pyrtl.Output(1, \"refund\")\n", - "\n", - "state = pyrtl.Register(3, \"state\")\n" + "refund = pyrtl.Output(1, \"refund\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - " First new step, let's enumerate a set of constants to serve as our states\n" + " First new step, let's enumerate a set of constants for all possible states.\n" ] }, { @@ -77,6 +75,26 @@ " RFND = 5 # Issue refund.\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Define a `StateRegister`, which is just like a [Register](https://pyrtl.readthedocs.io/en/latest/basic.html#pyrtl.Register), except that it calculates\n", + " the [Register](https://pyrtl.readthedocs.io/en/latest/basic.html#pyrtl.Register)'s bitwidth from the largest possible `State`. `StateRegister`s also\n", + " display state names in traces by default.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "state = pyrtl.StateRegister(State, \"state\")\n" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -216,8 +234,7 @@ "metadata": {}, "source": [ " Also, to make our input/output easy to reason about let's specify an order to the\n", - " traces with `trace_list`. We also use `enum_name` to display the state names (`WAIT`,\n", - " `TOK1`, ...) rather than their numbers (0, 1, ...).\n" + " traces with `trace_list`.\n" ] }, { @@ -229,8 +246,7 @@ "outputs": [], "source": [ "sim.tracer.render_trace(\n", - " trace_list=[\"token_in\", \"req_refund\", \"state\", \"dispense\", \"refund\"],\n", - " repr_per_name={\"state\": pyrtl.enum_name(State)},\n", + " trace_list=[\"token_in\", \"req_refund\", \"state\", \"dispense\", \"refund\"]\n", ")\n" ] }, diff --git a/pyrtl/__init__.py b/pyrtl/__init__.py index ecb6a2b2..de6d3fc8 100644 --- a/pyrtl/__init__.py +++ b/pyrtl/__init__.py @@ -17,7 +17,7 @@ ) # convenience classes for building hardware -from .wire import WireVector, Input, Output, Const, Register +from .wire import WireVector, Input, Output, Const, Register, StateRegister from .gate_graph import GateGraph, Gate @@ -161,6 +161,7 @@ "Output", "Const", "Register", + "StateRegister", # gate_graph "GateGraph", "Gate", diff --git a/pyrtl/simulation.py b/pyrtl/simulation.py index 045aa3b8..bbf7669b 100644 --- a/pyrtl/simulation.py +++ b/pyrtl/simulation.py @@ -20,7 +20,7 @@ from pyrtl.importexport import _VerilogSanitizer from pyrtl.memory import MemBlock, RomBlock from pyrtl.pyrtlexceptions import PyrtlError, PyrtlInternalError -from pyrtl.wire import Const, Input, Output, Register, WireVector +from pyrtl.wire import Const, Input, Output, Register, StateRegister, WireVector # ---------------------------------------------------------------- # __ ___ __ @@ -1106,6 +1106,8 @@ def invoke_f(f, value): if f is not None: return invoke_f(f, value) + if isinstance(wire, StateRegister): + return invoke_f(enum_name(wire.States), value) return invoke_f(repr_func, value) def render_val( @@ -1141,20 +1143,18 @@ def render_val( _prev_line* fields in RendererConstants. :param is_last: If True, current_val is in the last cycle. """ - if len(w) > 1 or w.name in repr_per_name: + if len(w) > 1 or w.name in repr_per_name or isinstance(w, StateRegister): # Render values in boxes for multi-bit wires ("bus"), or single-bit wires # with a specific representation. # # We display multi-wire zero values as a centered horizontal line when a # specific `repr_per_name` is not requested for this trace, and a standard # numeric format is requested. - flat_zero = w.name not in repr_per_name and ( - repr_func is hex - or repr_func is oct - or repr_func is int - or repr_func is str - or repr_func is bin - or repr_func is val_to_signed_integer + numeric_formats = [hex, oct, int, str, bin, val_to_signed_integer] + flat_zero = ( + w.name not in repr_per_name + and not isinstance(w, StateRegister) + and repr_func in numeric_formats ) if prev_line: # Bus wires are currently never rendered across multiple lines. @@ -1956,7 +1956,7 @@ def print_perf_counters(self, *trace_names: str, file=sys.stdout): def enum_name(EnumClass: type) -> Callable[[int], str]: - """Returns a function that returns the name of an :class:`enum.IntEnum` value. + """Returns a function that returns the name of an :class:`~enum.IntEnum` value. .. doctest only:: @@ -1965,32 +1965,44 @@ def enum_name(EnumClass: type) -> Callable[[int], str]: >>> pyrtl.reset_working_block() Use ``enum_name`` as a ``repr_func`` or ``repr_per_name`` for - :meth:`SimulationTrace.render_trace` to display :class:`enum.IntEnum` names in + :meth:`~SimulationTrace.render_trace` to display :class:`~enum.IntEnum` names in traces, instead of their numeric value. Example:: - >>> class State(enum.IntEnum): + >>> class Option(enum.IntEnum): ... FOO = 0 ... BAR = 1 - >>> state = pyrtl.Input(name="state", bitwidth=1) + >>> pyrtl.enum_name(Option)(1) + 'BAR' + + >>> option = pyrtl.Input(name="option", bitwidth=1) >>> sim = pyrtl.Simulation() - >>> sim.step_multiple({"state": [State.FOO, State.BAR]}) - >>> sim.tracer.render_trace(repr_per_name={"state": pyrtl.enum_name(State)}) + >>> sim.step_multiple({"option": [Option.FOO, Option.BAR]}) + >>> sim.tracer.render_trace(repr_per_name={"option": pyrtl.enum_name(Option)}) Which prints:: - │0 │1 + │0 │1 + + option FOO│BAR + + .. note:: - state FOO│BAR + When using ``enum_name`` with a :class:`.Register`, consider using + :class:`.StateRegister` instead. - :param EnumClass: ``enum`` to convert. This is the enum class, like ``State``, not - an enum value, like ``State.FOO`` or ``1``. + :param EnumClass: ``enum`` to convert. This is the enum class, like ``Option``, not + an enum value, like ``Option.FOO`` or ``1``. - :return: A function that accepts an enum value, like ``State.FOO`` or ``1``, and - returns the value's name as a string, like ``"FOO"``. + :return: A function that accepts an enum value, like ``Option.FOO`` or ``1``, and + returns the value's name as a string, like ``"FOO"``. Unknown values will + be converted to string with :class:`hex`. """ def value_to_name(value: int) -> str: - return EnumClass(value).name + try: + return EnumClass(value).name + except ValueError: + return hex(value) return value_to_name diff --git a/pyrtl/wire.py b/pyrtl/wire.py index 8c93e297..22b325e2 100644 --- a/pyrtl/wire.py +++ b/pyrtl/wire.py @@ -16,6 +16,7 @@ from __future__ import annotations +import enum import numbers import re import traceback @@ -1661,15 +1662,15 @@ def __ior__(self, _): class Register(WireVector): - """A WireVector with an embedded register state element. + """A :class:`WireVector` with an embedded register state element. - Registers only update their outputs on the rising edges of an implicit clock signal. - The "value" in the current cycle can be accessed by referencing the Register itself. - To set the value for the next cycle (after the next rising clock edge), set the - :attr:`Register.next` property with the ``<<=`` (:meth:`~WireVector.__ilshift__`) - operator. + ``Registers`` only update their outputs on the rising edges of an implicit clock + signal. The "value" in the current cycle can be accessed by referencing the + ``Register`` itself. To set the value for the next cycle (after the next rising + clock edge), set the :attr:`Register.next` property with the ``<<=`` + (:meth:`~WireVector.__ilshift__`) operator. - Registers reset to zero by default, and reside in the same clock domain. + ``Registers`` reset to zero by default, and reside in the same clock domain. .. doctest only:: @@ -1689,6 +1690,10 @@ class Register(WireVector): This builds a zero-initialized 2-bit counter. The second line sets the counter's value in the next cycle (``counter.next``) to the counter's value in the current cycle (``counter``), plus one. + + .. note:: + + Consider using :class:`StateRegister` for state machine ``Registers``. """ _code = "R" @@ -1811,20 +1816,20 @@ def __init__( reset_value: int | None = None, block: Block = None, ): - """Construct a register. + """Construct a ``Register``. It is an error if the ``reset_value`` cannot fit into the specified ``bitwidth`` for this register. - :param bitwidth: Number of bits to represent this register. - :param name: The name of the register's current value (``reg``, not + :param bitwidth: Number of bits to represent this ``Register``. + :param name: The name of the ``Register``'s current value (``reg``, not ``reg.next``). Must be unique. If none is provided, one will be autogenerated. - :param reset_value: Value to initialize this register to during simulation and - in any code (e.g. Verilog) that is exported. Defaults to 0. Can be + :param reset_value: Value to initialize this ``Register`` to during simulation + and in any code (e.g. Verilog) that is exported. Defaults to 0. Can be overridden at simulation time. - :param block: The block under which the wire should be placed. Defaults to the - :ref:`working_block`. + :param block: The :class:`Block` under which the wire should be placed. Defaults + to the :ref:`working_block`. """ from pyrtl.helperfuncs import infer_val_and_bitwidth @@ -1871,6 +1876,75 @@ def _build(self, next): working_block().add_net(net) +class StateRegister(Register): + """A :class:`Register` containing an :class:`~enum.IntEnum` state. + + ``StateRegister`` functions identically to :class:`Register`, except that the + :class:`Register`'s bitwidth is calculated from the :class:`~enum.IntEnum`'s largest + value, and :meth:`~.SimulationTrace.render_trace` displays state names by default. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> import enum + >>> class MyStates(enum.IntEnum): + ... ZERO = 0 + ... ONE = 1 + ... TWO = 2 + ... THREE = 3 + + >>> state = pyrtl.StateRegister( + ... name="state", States=MyStates, reset_value=MyStates.ONE + ... ) + >>> state.bitwidth + 2 + + >>> state.next <<= state + 1 + + >>> sim = pyrtl.Simulation() + >>> sim.step_multiple(nsteps=4) + >>> sim.tracer.render_trace() + + Which prints:: + + │0 │1 │2 │3 + state ONE │TWO │THREE│ZERO + """ + + def __init__( + self, + States: enum.IntEnum, + name: str = "", + reset_value: enum.IntEnum | int | None = None, + block: Block = None, + ): + """Constructs a :class:`Register` containing an :class:`~enum.IntEnum` value. + + :param States: An :class:`~enum.IntEnum` containing all possible states for the + ``StateRegister``. The largest value in the :class:`~enum.IntEnum` + determines the :class:`Register`'s :attr:`~WireVector.bitwidth`. + :param name: The name of the ``StateRegister``'s current value (``state``, not + ``state.next``). Must be unique. If none is provided, one will be + autogenerated. + :param reset_value: Value to initialize this ``StateRegister`` to during + simulation and in any code (e.g. Verilog) that is exported. Defaults to 0. + Can be overridden at simulation time. + :param block: The :class:`Block` under which the wire should be placed. Defaults + to the :ref:`working_block`. + """ + from pyrtl.helperfuncs import infer_val_and_bitwidth + + self.States = States + bitwidth = infer_val_and_bitwidth(max(States)).bitwidth + super().__init__( + bitwidth=bitwidth, name=name, reset_value=reset_value, block=block + ) + + class WrappedWireVector: """Wraps a ``WireVector``. Forwards all method calls and attribute accesses. diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 63e1e183..0c408901 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -264,6 +264,11 @@ class State(enum.IntEnum): FOO = 0 BAR = 1 + state_name = pyrtl.enum_name(State) + self.assertEqual(state_name(0), "FOO") + self.assertEqual(state_name(1), "BAR") + self.assertEqual(state_name(2), "0x2") + state = pyrtl.Input(name="state", bitwidth=1) sim = pyrtl.Simulation() sim.step_multiple({state.name: [State.FOO, State.BAR]}) @@ -271,9 +276,33 @@ class State(enum.IntEnum): sim.tracer.render_trace( file=buff, renderer=self.renderer, - repr_per_name={state.name: pyrtl.enum_name(State)}, + repr_per_name={state.name: state_name}, ) - expected = " |0 |1 \n \nstate FOO|BAR\n" + expected = ( + " |0 |1 \n" + " \n" + "state FOO|BAR\n" + ) # fmt: skip + self.assertEqual(buff.getvalue(), expected) + + def test_state_register(self): + class State(enum.IntEnum): + A = 0 + B = 1 + C = 2 + D = 3 + + state = pyrtl.StateRegister(name="state", States=State, reset_value=State.B) + state.next <<= state + 1 + sim = pyrtl.Simulation() + sim.step_multiple(nsteps=4) + buff = io.StringIO() + sim.tracer.render_trace(file=buff, renderer=self.renderer) + expected = ( + " |0|1|2|3\n" + " \n" + "state B|C|D|A\n" + ) # fmt: skip self.assertEqual(buff.getvalue(), expected) def test_val_to_signed_integer(self): @@ -286,7 +315,11 @@ def test_val_to_signed_integer(self): sim.tracer.render_trace( file=buff, renderer=self.renderer, repr_func=pyrtl.val_to_signed_integer ) - expected = " |0 |1 |2 |3 \n \ncounter --|1 |-2|-1\n" + expected = ( + " |0 |1 |2 |3 \n" + " \n" + "counter --|1 |-2|-1\n" + ) # fmt: skip self.assertEqual(buff.getvalue(), expected) def test_custom_repr_per_wire(self): diff --git a/tests/test_wire.py b/tests/test_wire.py index a076596d..9a0c6467 100644 --- a/tests/test_wire.py +++ b/tests/test_wire.py @@ -1,4 +1,5 @@ import doctest +import enum import unittest import pyrtl @@ -298,6 +299,31 @@ def test_invalid_reset_value_not_an_integer(self): pyrtl.Register(4, reset_value="hello") +class TestStateRegister(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + + def test_bitwidth(self): + class PowerOfTwoStates(enum.IntEnum): + ZERO = 0 + ONE = 1 + TWO = 2 + THREE = 3 + + state = pyrtl.StateRegister(States=PowerOfTwoStates) + self.assertEqual(state.bitwidth, 2) + + class NotPowerOfTwoStates(enum.IntEnum): + ZERO = 0 + ONE = 1 + TWO = 2 + THREE = 3 + FOUR = 4 + + state = pyrtl.StateRegister(States=NotPowerOfTwoStates) + self.assertEqual(state.bitwidth, 3) + + class TestConst(unittest.TestCase): def setUp(self): pyrtl.reset_working_block() From 05aae0d7ff0050146d3c62cfcf2ea73a96c06a32 Mon Sep 17 00:00:00 2001 From: Jeremy Lau <30300826+fdxmw@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:30:43 -0700 Subject: [PATCH 2/2] Add a `State` `IntEnum` option to `Register`'s constructor. When a `Register` is constructed with a `State`, the `Register`'s bitwidth is determined by the `IntEnum`'s maximum value, and `render_trace` displays enumeration names instead of hex values. --- docs/basic.rst | 1 - docs/index.rst | 1 - docs/regmem.rst | 5 - examples/example3-statemachine.py | 8 +- ipynb-examples/example3-statemachine.ipynb | 8 +- pyrtl/__init__.py | 3 +- pyrtl/simulation.py | 13 +- pyrtl/wire.py | 134 +++++++++------------ tests/test_simulation.py | 2 +- tests/test_wire.py | 10 +- 10 files changed, 81 insertions(+), 104 deletions(-) diff --git a/docs/basic.rst b/docs/basic.rst index 0d42a4e6..04dd40d2 100644 --- a/docs/basic.rst +++ b/docs/basic.rst @@ -23,7 +23,6 @@ hard-wired values and :class:`.Register` is how sequential elements are created pyrtl.Output pyrtl.Const pyrtl.Register - pyrtl.StateRegister :parts: 1 WireVector diff --git a/docs/index.rst b/docs/index.rst index 542ca5e7..30b3c8f4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -92,7 +92,6 @@ this is all a user needs to create a functional hardware design. pyrtl.Output pyrtl.Const pyrtl.Register - pyrtl.StateRegister :parts: 1 After specifying a hardware design, there are then options to simulate your diff --git a/docs/regmem.rst b/docs/regmem.rst index b85754cd..1e3195a7 100644 --- a/docs/regmem.rst +++ b/docs/regmem.rst @@ -9,11 +9,6 @@ Registers :show-inheritance: :special-members: __init__ -.. autoclass:: pyrtl.StateRegister - :members: - :show-inheritance: - :special-members: __init__ - Memories -------- diff --git a/examples/example3-statemachine.py b/examples/example3-statemachine.py index 11cd3816..1b535671 100644 --- a/examples/example3-statemachine.py +++ b/examples/example3-statemachine.py @@ -26,10 +26,10 @@ class State(enum.IntEnum): RFND = 5 # Issue refund. -# Define a `StateRegister`, which is just like a `Register`, except that it calculates -# the `Register`'s bitwidth from the largest possible `State`. `StateRegister`s also -# display state names in traces by default. -state = pyrtl.StateRegister(State, "state") +# Define a `Register`, that calculates its bitwidth from the largest possible `State`. +# By default, `State` names like `WAIT` will display in traces, instead of state numbers +# like `0`. +state = pyrtl.Register(name="state", States=State) # Now we could build a state machine using just the `Registers` and logic discussed in # prior examples, but doing operations **conditionally** on some input is a pretty diff --git a/ipynb-examples/example3-statemachine.ipynb b/ipynb-examples/example3-statemachine.ipynb index 52f18473..a5003693 100644 --- a/ipynb-examples/example3-statemachine.ipynb +++ b/ipynb-examples/example3-statemachine.ipynb @@ -79,9 +79,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - " Define a `StateRegister`, which is just like a [Register](https://pyrtl.readthedocs.io/en/latest/basic.html#pyrtl.Register), except that it calculates\n", - " the [Register](https://pyrtl.readthedocs.io/en/latest/basic.html#pyrtl.Register)'s bitwidth from the largest possible `State`. `StateRegister`s also\n", - " display state names in traces by default.\n" + " Define a [Register](https://pyrtl.readthedocs.io/en/latest/basic.html#pyrtl.Register), that calculates its bitwidth from the largest possible `State`.\n", + " By default, `State` names like `WAIT` will display in traces, instead of state numbers\n", + " like `0`.\n" ] }, { @@ -92,7 +92,7 @@ }, "outputs": [], "source": [ - "state = pyrtl.StateRegister(State, \"state\")\n" + "state = pyrtl.Register(name=\"state\", States=State)\n" ] }, { diff --git a/pyrtl/__init__.py b/pyrtl/__init__.py index de6d3fc8..ecb6a2b2 100644 --- a/pyrtl/__init__.py +++ b/pyrtl/__init__.py @@ -17,7 +17,7 @@ ) # convenience classes for building hardware -from .wire import WireVector, Input, Output, Const, Register, StateRegister +from .wire import WireVector, Input, Output, Const, Register from .gate_graph import GateGraph, Gate @@ -161,7 +161,6 @@ "Output", "Const", "Register", - "StateRegister", # gate_graph "GateGraph", "Gate", diff --git a/pyrtl/simulation.py b/pyrtl/simulation.py index bbf7669b..9a5346e8 100644 --- a/pyrtl/simulation.py +++ b/pyrtl/simulation.py @@ -20,7 +20,7 @@ from pyrtl.importexport import _VerilogSanitizer from pyrtl.memory import MemBlock, RomBlock from pyrtl.pyrtlexceptions import PyrtlError, PyrtlInternalError -from pyrtl.wire import Const, Input, Output, Register, StateRegister, WireVector +from pyrtl.wire import Const, Input, Output, Register, WireVector # ---------------------------------------------------------------- # __ ___ __ @@ -1106,7 +1106,7 @@ def invoke_f(f, value): if f is not None: return invoke_f(f, value) - if isinstance(wire, StateRegister): + if isinstance(wire, Register) and wire.States is not None: return invoke_f(enum_name(wire.States), value) return invoke_f(repr_func, value) @@ -1143,7 +1143,8 @@ def render_val( _prev_line* fields in RendererConstants. :param is_last: If True, current_val is in the last cycle. """ - if len(w) > 1 or w.name in repr_per_name or isinstance(w, StateRegister): + is_state_register = isinstance(w, Register) and w.States is not None + if len(w) > 1 or w.name in repr_per_name or is_state_register: # Render values in boxes for multi-bit wires ("bus"), or single-bit wires # with a specific representation. # @@ -1153,7 +1154,7 @@ def render_val( numeric_formats = [hex, oct, int, str, bin, val_to_signed_integer] flat_zero = ( w.name not in repr_per_name - and not isinstance(w, StateRegister) + and not is_state_register and repr_func in numeric_formats ) if prev_line: @@ -1988,8 +1989,8 @@ def enum_name(EnumClass: type) -> Callable[[int], str]: .. note:: - When using ``enum_name`` with a :class:`.Register`, consider using - :class:`.StateRegister` instead. + When using ``enum_name`` with a :class:`.Register`, consider constructing + :class:`.Register` with a ``State`` instead. See :meth:`.Register.__init__`. :param EnumClass: ``enum`` to convert. This is the enum class, like ``Option``, not an enum value, like ``Option.FOO`` or ``1``. diff --git a/pyrtl/wire.py b/pyrtl/wire.py index 22b325e2..43af56b2 100644 --- a/pyrtl/wire.py +++ b/pyrtl/wire.py @@ -1690,10 +1690,6 @@ class Register(WireVector): This builds a zero-initialized 2-bit counter. The second line sets the counter's value in the next cycle (``counter.next``) to the counter's value in the current cycle (``counter``), plus one. - - .. note:: - - Consider using :class:`StateRegister` for state machine ``Registers``. """ _code = "R" @@ -1811,15 +1807,44 @@ def __bool__(self): def __init__( self, - bitwidth: int, + bitwidth: int | None = None, name: str = "", reset_value: int | None = None, - block: Block = None, + block: Block | None = None, + States: type[enum.IntEnum] | None = None, ): """Construct a ``Register``. - It is an error if the ``reset_value`` cannot fit into the specified ``bitwidth`` - for this register. + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example with ``States``:: + + >>> import enum + >>> class MyStates(enum.IntEnum): + ... ZERO = 0 + ... ONE = 1 + ... TWO = 2 + ... THREE = 3 + + >>> state = pyrtl.Register( + ... name="state", States=MyStates, reset_value=MyStates.ONE + ... ) + >>> state.bitwidth + 2 + + >>> state.next <<= state + 1 + + >>> sim = pyrtl.Simulation() + >>> sim.step_multiple(nsteps=4) + >>> sim.tracer.render_trace() + + Which prints:: + + │0 │1 │2 │3 + state ONE │TWO │THREE│ZERO :param bitwidth: Number of bits to represent this ``Register``. :param name: The name of the ``Register``'s current value (``reg``, not @@ -1830,9 +1855,33 @@ def __init__( overridden at simulation time. :param block: The :class:`Block` under which the wire should be placed. Defaults to the :ref:`working_block`. + :param States: An :class:`~enum.IntEnum` defining all possible states for the + ``Register``. This should be an :class:`~enum.IntEnum` class, like + ``MyStates`` in the example above. If ``bitwidth`` is ``None``, the largest + value in the :class:`~enum.IntEnum` determines the ``Register``'s + ``bitwidth``. When ``States`` is not ``None``, + :meth:`~.SimulationTrace.render_trace` defaults to displaying enumeration + names rather than hex values. + + :raises PyrtlError: If the ``reset_value`` or ``States`` cannot fit into the + specified ``bitwidth`` for this register. """ from pyrtl.helperfuncs import infer_val_and_bitwidth + self.States = States + if States is not None: + largest_state = max(States) + inferred_bitwidth = infer_val_and_bitwidth(largest_state).bitwidth + if bitwidth is None: + bitwidth = inferred_bitwidth + + if bitwidth < inferred_bitwidth: + msg = ( + f"The largest State {largest_state.name} ({largest_state}) cannot " + f"fit in the specified {bitwidth} bits for this register" + ) + raise PyrtlError(msg) + super().__init__(bitwidth=bitwidth, name=name, block=block) self.reg_in = None # wire vector setting self.next if reset_value is not None: @@ -1876,75 +1925,6 @@ def _build(self, next): working_block().add_net(net) -class StateRegister(Register): - """A :class:`Register` containing an :class:`~enum.IntEnum` state. - - ``StateRegister`` functions identically to :class:`Register`, except that the - :class:`Register`'s bitwidth is calculated from the :class:`~enum.IntEnum`'s largest - value, and :meth:`~.SimulationTrace.render_trace` displays state names by default. - - .. doctest only:: - - >>> import pyrtl - >>> pyrtl.reset_working_block() - - Example:: - - >>> import enum - >>> class MyStates(enum.IntEnum): - ... ZERO = 0 - ... ONE = 1 - ... TWO = 2 - ... THREE = 3 - - >>> state = pyrtl.StateRegister( - ... name="state", States=MyStates, reset_value=MyStates.ONE - ... ) - >>> state.bitwidth - 2 - - >>> state.next <<= state + 1 - - >>> sim = pyrtl.Simulation() - >>> sim.step_multiple(nsteps=4) - >>> sim.tracer.render_trace() - - Which prints:: - - │0 │1 │2 │3 - state ONE │TWO │THREE│ZERO - """ - - def __init__( - self, - States: enum.IntEnum, - name: str = "", - reset_value: enum.IntEnum | int | None = None, - block: Block = None, - ): - """Constructs a :class:`Register` containing an :class:`~enum.IntEnum` value. - - :param States: An :class:`~enum.IntEnum` containing all possible states for the - ``StateRegister``. The largest value in the :class:`~enum.IntEnum` - determines the :class:`Register`'s :attr:`~WireVector.bitwidth`. - :param name: The name of the ``StateRegister``'s current value (``state``, not - ``state.next``). Must be unique. If none is provided, one will be - autogenerated. - :param reset_value: Value to initialize this ``StateRegister`` to during - simulation and in any code (e.g. Verilog) that is exported. Defaults to 0. - Can be overridden at simulation time. - :param block: The :class:`Block` under which the wire should be placed. Defaults - to the :ref:`working_block`. - """ - from pyrtl.helperfuncs import infer_val_and_bitwidth - - self.States = States - bitwidth = infer_val_and_bitwidth(max(States)).bitwidth - super().__init__( - bitwidth=bitwidth, name=name, reset_value=reset_value, block=block - ) - - class WrappedWireVector: """Wraps a ``WireVector``. Forwards all method calls and attribute accesses. diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 0c408901..35e57a0a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -292,7 +292,7 @@ class State(enum.IntEnum): C = 2 D = 3 - state = pyrtl.StateRegister(name="state", States=State, reset_value=State.B) + state = pyrtl.Register(name="state", States=State, reset_value=State.B) state.next <<= state + 1 sim = pyrtl.Simulation() sim.step_multiple(nsteps=4) diff --git a/tests/test_wire.py b/tests/test_wire.py index 9a0c6467..52511afd 100644 --- a/tests/test_wire.py +++ b/tests/test_wire.py @@ -310,19 +310,23 @@ class PowerOfTwoStates(enum.IntEnum): TWO = 2 THREE = 3 - state = pyrtl.StateRegister(States=PowerOfTwoStates) + state = pyrtl.Register(States=PowerOfTwoStates) self.assertEqual(state.bitwidth, 2) class NotPowerOfTwoStates(enum.IntEnum): ZERO = 0 ONE = 1 TWO = 2 - THREE = 3 FOUR = 4 + THREE = 3 - state = pyrtl.StateRegister(States=NotPowerOfTwoStates) + state = pyrtl.Register(States=NotPowerOfTwoStates) self.assertEqual(state.bitwidth, 3) + with self.assertRaises(pyrtl.PyrtlError): + # Bitwidth 1 is too small to fit PowerOfTwoStates.THREE. + state = pyrtl.Register(bitwidth=1, States=PowerOfTwoStates) + class TestConst(unittest.TestCase): def setUp(self):