diff --git a/examples/Makefile b/examples/Makefile index 5246dfce..3f0dd36f 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -1,4 +1,4 @@ -PYTHON=python +PYTHON=uv run python3 PY_FILES=$(wildcard *.py) IPYNB_FILES=$(addprefix ../ipynb-examples/, $(PY_FILES:.py=.ipynb)) diff --git a/examples/example1.2-wire-struct.py b/examples/example1.2-wire-struct.py index 66c84932..e16bcd9e 100644 --- a/examples/example1.2-wire-struct.py +++ b/examples/example1.2-wire-struct.py @@ -154,7 +154,7 @@ class Pixel: # `PixelPair` is a pair of two `Pixels`. This also shows how `wire_matrix()` and # `wire_struct` work together - `PixelPair` is a `wire_struct` nested in a # `wire_matrix()`. -PixelPair = pyrtl.wire_matrix(component_schema=Pixel, size=2) +PixelPair = pyrtl.wire_matrix(component_schema=Pixel, size=2, class_name="PixelPair") # `wire_matrix()` returns a class! assert inspect.isclass(PixelPair) diff --git a/ipynb-examples/example1.2-wire-struct.ipynb b/ipynb-examples/example1.2-wire-struct.ipynb index 95345c4f..f1bd3e21 100644 --- a/ipynb-examples/example1.2-wire-struct.ipynb +++ b/ipynb-examples/example1.2-wire-struct.ipynb @@ -384,7 +384,7 @@ }, "outputs": [], "source": [ - "PixelPair = pyrtl.wire_matrix(component_schema=Pixel, size=2)\n" + "PixelPair = pyrtl.wire_matrix(component_schema=Pixel, size=2, class_name=\"PixelPair\")\n" ] }, { diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 055510e1..0e710566 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -1470,14 +1470,14 @@ def wire_struct(wire_struct_spec): 1. Provide a driver for *each* component wire, for example:: - >>> byte = Byte(high=0xA, low=0xB) + >>> byte = Byte(high=0xA, low=0xB) Note how the component names (``high``, ``low``) are used as keyword args for the constructor. Drivers must be provided for *all* components. 2. Provide a driver for the entire ``@wire_struct``, for example:: - >>> byte = Byte(Byte=0xAB) + >>> byte = Byte(Byte=0xAB) Note how the class name (``Byte``) is used as a keyword arg for the constructor. @@ -1576,7 +1576,7 @@ class Pixel: ``@wire_struct`` can be composed with :func:`wire_matrix`:: - Word = pyrtl.wire_matrix(component_schema=8, size=4) + Word = pyrtl.wire_matrix(component_schema=8, size=4, class_name="Word") @pyrtl.wire_struct class CacheLine: @@ -1615,6 +1615,41 @@ class CacheLine: No values are specified for ``input_byte`` because its value is not known until simulation time. + + Generic Usage + ------------- + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Functions can work with ``@wire_struct`` instances generically. For example, we can + define a function that accepts any ``@wire_struct`` and returns the same type of + ``@wire_struct``, with all of the argument's fields bitwise-inverted:: + + >>> def invert_all_fields(output_name, any_wire_struct): + ... # Retrieve the argument wire_struct's class. + ... OutputClass = type(any_wire_struct) + ... # Instantiate that class, specifying its full value. + ... return OutputClass(name=output_name, _value=~any_wire_struct) + + >>> input_byte = Byte(name="input_byte", concatenated_type=pyrtl.Input) + >>> output_byte = invert_all_fields("output_byte", input_byte) + + >>> sim = pyrtl.Simulation() + >>> sim.step({"input_byte": 0}) + >>> hex(sim.inspect("output_byte")) + '0xff' + + ``invert_all_fields`` uses ``type(any_wire_struct)`` to retrieve the argument + ``@wire_struct``'s class, and instantiates that class for the function's return + value. + + This uses the special constructor ``kwarg`` ``_value``, rather than the name of the + class, to specify the full value for the returned ``@wire_struct`` object. In this + example, we must use ``_value`` instead of the name of the class (``Byte``) because + ``any_wire_struct`` might not be a ``Byte``. """ # Convert the decorated class' annotations (dict of attr_name: attr_value) # to a list of _ComponentMetas. @@ -1651,164 +1686,167 @@ class CacheLine: # Name of the decorated class. class_name = wire_struct_spec.__name__ - class _WireStruct(WrappedWireVector): - """``wire_struct`` implementation: Concatenate or slice :class:`WireVector`. + def _init( + self, + name="", + block=None, + concatenated_type=WireVector, + component_type=WireVector, + **kwargs, + ): + """Concatenate or slice :class:`WireVector` components. - ``wire_struct`` works by either concatenating component :class:`WireVector` - to create the ``wire_struct``'s full value, *or* slicing a - ``wire_struct``s value to create component :class:`WireVectors`. A - ``wire_struct`` can only concatenate or slice, not both. The decision - to concatenate or slice is made in __init__. + The remaining keyword args specify values for all wires. If the concatenated + value is provided, its value must be provided with the keyword arg matching + the decorated class name. For example, if the decorated class is:: - """ + @wire_struct + class Byte: + high: 4 # high is the 4 most significant bits. + low: 4 # low is the 4 least significant bits. - _bitwidth = total_bitwidth - _class_name = class_name - _is_wire_struct = True - - def __init__( - self, - name="", - block=None, - concatenated_type=WireVector, - component_type=WireVector, - **kwargs, - ): - """Concatenate or slice :class:`WireVector` components. - - The remaining keyword args specify values for all wires. If the concatenated - value is provided, its value must be provided with the keyword arg matching - the decorated class name. For example, if the decorated class is:: - - @wire_struct - class Byte: - high: 4 # high is the 4 most significant bits. - low: 4 # low is the 4 least significant bits. - - then the concatenated value must be provided like this:: - - byte = Byte(Byte=0xAB) - - And if the component values are provided instead, their values are set by - keyword args matching the component names:: - - byte = Byte(low=0xA, high=0xB) - - :param str name: The name of the concatenated wire. Must be unique. If none - is provided, one will be autogenerated. If a name is provided, - components will be assigned names of the form "{name}.{component_name}". - :param Block block: The block containing the concatenated and component - wires. Defaults to the :ref:`working_block`. - :param type concatenated_type: Type for the concatenated - :class:`WireVector`. - :param type component_type: Type for each component. - """ - # The concatenated WireVector contains all the _WireStruct's wires. - # WrappedWireVector (base class) will forward all attribute and method - # accesses on this _WireStruct to the concatenated WireVector. - if ( - class_name in kwargs - and isinstance(kwargs[class_name], int) - and concatenated_type is WireVector - ): - # Special case: simplify the concatenated type to Const. - concatenated = Const( - bitwidth=self._bitwidth, - name=name, - block=block, - val=kwargs[class_name], - ) - else: - concatenated = concatenated_type( - bitwidth=self._bitwidth, name=name, block=block - ) - super().__init__(wire=concatenated) - - # self._components maps from component name to each component's WireVector. - components = {} - self.__dict__["_components"] = components - - # Handle Input and Register special cases. - if concatenated_type is Input or concatenated_type is Register: - kwargs = {class_name: None} - elif component_type is Input or component_type is Register: - kwargs = {component_meta.name: None for component_meta in schema} - - if class_name in kwargs: - # Check for unused kwargs. - for component_name in kwargs: - if component_name != class_name: - msg = ( - "Do not pass additional kwargs to @wire_struct when " - f'slicing. ("{class_name}" was passed so don\'t pass ' - f'"{component_name}")' - ) - raise PyrtlError(msg) - # Concatenated value was provided. Slice it into components. - _slice( - block=block, - schema=schema, - bitwidth=self._bitwidth, - component_type=component_type, - name=name, - concatenated=concatenated, - components=components, - concatenated_value=kwargs[class_name], - ) - else: - # Component values were provided; concatenate them. - # Check that values were provided for all components. - expected_component_names = [ - component_meta.name for component_meta in schema - ] - for expected_component_name in expected_component_names: - if expected_component_name not in kwargs: - msg = ( - "You must provide kwargs for all @wire_struct components " - "when concatenating (missing kwarg " - f'"{expected_component_name}")' - ) - raise PyrtlError(msg) - # Check for unused kwargs. - for component_name in kwargs: - if component_name not in expected_component_names: - msg = ( - "Do not pass additional kwargs to @wire_struct when " - f'concatenating (don\'t pass "{component_name}")' - ) - raise PyrtlError(msg) - - _concatenate( - block=block, - schema=schema, - component_type=component_type, - name=name, - concatenated=concatenated, - components=components, - component_map=kwargs, - ) + then the concatenated value must be provided like this:: - def __getattr__(self, component_name: str): - """Retrieve a component by name. + byte = Byte(Byte=0xAB) + + And if the component values are provided instead, their values are set by + keyword args matching the component names:: + + byte = Byte(low=0xA, high=0xB) + + :param str name: The name of the concatenated wire. Must be unique. If none + is provided, one will be autogenerated. If a name is provided, + components will be assigned names of the form "{name}.{component_name}". + :param Block block: The block containing the concatenated and component + wires. Defaults to the :ref:`working_block`. + :param type concatenated_type: Type for the concatenated + :class:`WireVector`. + :param type component_type: Type for each component. + """ + # Special case: Replace a "_value" kwarg with the actual class name. This + # kwarg is useful when the class_name is not statically known. + if "_value" in kwargs: + kwargs[class_name] = kwargs["_value"] + del kwargs["_value"] + + # The concatenated WireVector contains all the _WireStruct's wires. + # WrappedWireVector (base class) will forward all attribute and method + # accesses on this _WireStruct to the concatenated WireVector. + if ( + class_name in kwargs + and isinstance(kwargs[class_name], int) + and concatenated_type is WireVector + ): + # Special case: simplify the concatenated type to Const. + concatenated = Const( + bitwidth=self._bitwidth, + name=name, + block=block, + val=kwargs[class_name], + ) + else: + concatenated = concatenated_type( + bitwidth=self._bitwidth, name=name, block=block + ) + WrappedWireVector.__init__(self, wire=concatenated) + + # self._components maps from component name to each component's WireVector. + components = {} + self.__dict__["_components"] = components + + # Handle Input and Register special cases. + if concatenated_type is Input or concatenated_type is Register: + kwargs = {class_name: None} + elif component_type is Input or component_type is Register: + kwargs = {component_meta.name: None for component_meta in schema} + + if class_name in kwargs: + # Check for unused kwargs. + for component_name in kwargs: + if component_name != class_name: + msg = ( + "Do not pass additional kwargs to @wire_struct when " + f'slicing. ("{class_name}" was passed so don\'t pass ' + f'"{component_name}")' + ) + raise PyrtlError(msg) + # Concatenated value was provided. Slice it into components. + _slice( + block=block, + schema=schema, + bitwidth=self._bitwidth, + component_type=component_type, + name=name, + concatenated=concatenated, + components=components, + concatenated_value=kwargs[class_name], + ) + else: + # Component values were provided; concatenate them. + # Check that values were provided for all components. + expected_component_names = [ + component_meta.name for component_meta in schema + ] + for expected_component_name in expected_component_names: + if expected_component_name not in kwargs: + msg = ( + "You must provide kwargs for all @wire_struct components " + "when concatenating (missing kwarg " + f'"{expected_component_name}")' + ) + raise PyrtlError(msg) + # Check for unused kwargs. + for component_name in kwargs: + if component_name not in expected_component_names: + msg = ( + "Do not pass additional kwargs to @wire_struct when " + f'concatenating (don\'t pass "{component_name}")' + ) + raise PyrtlError(msg) - Components are concatenated to form the concatenated :class:`WireVector`, or - sliced from the concatenated :class:`WireVector`. + _concatenate( + block=block, + schema=schema, + component_type=component_type, + name=name, + concatenated=concatenated, + components=components, + component_map=kwargs, + ) - :param component_name: The name of the component wire. - """ - components = self.__dict__["_components"] - if component_name in components: - return components[component_name] - return super().__getattr__(component_name) + def _getattr(self, component_name: str): + """Retrieve a component by name. - def __len__(self): - components = self.__dict__["_components"] - return len(components) + Components are concatenated to form the concatenated :class:`WireVector`, or + sliced from the concatenated :class:`WireVector`. - return _WireStruct + :param component_name: The name of the component wire. + """ + components = self.__dict__["_components"] + if component_name in components: + return components[component_name] + return WrappedWireVector.__getattr__(self, component_name) + + def _len(self): + components = self.__dict__["_components"] + return len(components) + + return type( + class_name, + (WrappedWireVector,), + { + "_bitwidth": total_bitwidth, + "_class_name": class_name, + "_is_wire_struct": True, + "__init__": _init, + "__getattr__": _getattr, + "__len__": _len, + "__doc__": wire_struct_spec.__doc__, + }, + ) -def wire_matrix(component_schema, size: int): +def wire_matrix(component_schema, size: int, class_name: str | None = None): """Returns a class that assigns numbered indices to :class:`WireVector` slices. ``wire_matrix`` assigns numbered indices to *non-overlapping* :class:`WireVector` @@ -1829,7 +1867,7 @@ def wire_matrix(component_schema, size: int): An example 32-bit ``Word`` ``wire_matrix``, which represents a group of four bytes, can be defined as:: - >>> Word = wire_matrix(component_schema=8, size=4) + >>> Word = wire_matrix(component_schema=8, size=4, class_name="Word") .. NOTE:: @@ -1885,8 +1923,8 @@ def wire_matrix(component_schema, size: int): ``wire_matrix`` can be composed with itself and :func:`wire_struct`. For example, we can define some multi-dimensional byte arrays:: - Array1D = wire_matrix(component_schema=8, size=2) - Array2D = wire_matrix(component_schema=Array1D, size=2) + Array1D = wire_matrix(component_schema=8, size=2, class_name="Array1D") + Array2D = wire_matrix(component_schema=Array1D, size=2, class_name="Array2D") Drivers must be specified for all components, but they can be specified at any level. All these examples construct an equivalent ``wire_matrix``:: @@ -1912,7 +1950,7 @@ def wire_matrix(component_schema, size: int): class Byte: high: 4 low: 4 - Array1D = wire_matrix(component_schema=Byte, size=2) + Array1D = wire_matrix(component_schema=Byte, size=2, class_name="Array1D") array_1d = Array1D(values=[0xAB, 0xCD]) print(array_1d[0].high.bitwidth) # Prints 4. @@ -1948,6 +1986,10 @@ class Byte: No values are specified for ``input_word`` because its value is not known until simulation time. """ + # Users may specify the name of the returned class. + if class_name is None: + class_name = "_WireMatrix" + # Determine each component's bitwidth. if hasattr(component_schema, "_is_wire_struct") or hasattr( component_schema, "_is_wire_matrix" @@ -1957,116 +1999,121 @@ class Byte: component_bitwidth = component_schema component_schema = None - class _WireMatrix(WrappedWireVector): - _component_bitwidth = component_bitwidth - _component_schema = component_schema - _size = size - - _bitwidth = component_bitwidth * size - _is_wire_matrix = True - - def __init__( - self, - name: str = "", - block: Block = None, - concatenated_type=WireVector, - component_type=WireVector, - values: list | None = None, + def _init( + self, + name: str = "", + block: Block = None, + concatenated_type=WireVector, + component_type=WireVector, + values: list | None = None, + ): + # The concatenated WireVector contains all the _WireMatrix's wires. + # WrappedWireVector (base class) will forward all attribute and method + # accesses on this _WireMatrix to the concatenated WireVector. + if values is None: + values = [] + if ( + len(values) == 1 + and isinstance(values[0], int) + and concatenated_type is WireVector ): - # The concatenated WireVector contains all the _WireMatrix's wires. - # WrappedWireVector (base class) will forward all attribute and method - # accesses on this _WireMatrix to the concatenated WireVector. - if values is None: - values = [] - if ( - len(values) == 1 - and isinstance(values[0], int) - and concatenated_type is WireVector - ): - # Special case: simplify the concatenated type to Const. - concatenated = Const( - bitwidth=self._bitwidth, name=name, block=block, val=values[0] - ) - else: - concatenated = concatenated_type( - bitwidth=self._bitwidth, name=name, block=block - ) - super().__init__(wire=concatenated) - - schema = [] - for component_name in range(self._size): - schema.append( - _ComponentMeta( - name=component_name, - bitwidth=self._component_bitwidth, - type=component_schema, - ) - ) + # Special case: simplify the concatenated type to Const. + concatenated = Const( + bitwidth=self._bitwidth, name=name, block=block, val=values[0] + ) + else: + concatenated = concatenated_type( + bitwidth=self._bitwidth, name=name, block=block + ) + WrappedWireVector.__init__(self, wire=concatenated) - # By default, slice the concatenated value into components iff exactly one - # value is provided. - slicing = len(values) == 1 - - # Handle Input and Register special cases. - if concatenated_type is Input or concatenated_type is Register: - # Slice the concatenated value. Override the default 'slicing' because - # 'values' is empty when slicing a concatenated Input or Register. - # - # Note that we can't just check len(values) == 1 after we set values to - # [None] because that doesn't work when there is only one element in the - # wire_matrix. We must distinguish between: - # - # 1. Slicing values to produce values[0] (this case). - # 2. Concatenating values[0] to produce values (next case). - # - # But len(values) == 1 in both cases. The slice in (1) and concatenate - # in (2) are both no-ops, but we have to get the direction right. In the - # first case, values[0] is driven by values, and in the second case, - # values is driven by values[0]. - slicing = True - values = [None] - elif component_type is Input or component_type is Register: - values = [None for _ in range(self._size)] - - self._components = [None for i in range(len(schema))] - if slicing: - # Concatenated value was provided. Slice it into components. - _slice( - block=block, - schema=schema, - bitwidth=self._bitwidth, - component_type=component_type, - name=name, - concatenated=concatenated, - components=self._components, - concatenated_value=values[0], - ) - else: - if len(values) != len(schema): - msg = ( - "wire_matrix constructor expects 1 value to slice, or " - f"{len(schema)} values to concatenate (received {len(values)} " - "values)" - ) - raise PyrtlError(msg) - # Component values were provided; concatenate them. - _concatenate( - block=block, - schema=schema, - component_type=component_type, - name=name, - concatenated=concatenated, - components=self._components, - component_map=values, + schema = [] + for component_name in range(self._size): + schema.append( + _ComponentMeta( + name=component_name, + bitwidth=self._component_bitwidth, + type=component_schema, ) + ) - def __getitem__(self, key): - return self._components[key] - - def __len__(self): - return len(self._components) + # By default, slice the concatenated value into components iff exactly one + # value is provided. + slicing = len(values) == 1 + + # Handle Input and Register special cases. + if concatenated_type is Input or concatenated_type is Register: + # Slice the concatenated value. Override the default 'slicing' because + # 'values' is empty when slicing a concatenated Input or Register. + # + # Note that we can't just check len(values) == 1 after we set values to + # [None] because that doesn't work when there is only one element in the + # wire_matrix. We must distinguish between: + # + # 1. Slicing values to produce values[0] (this case). + # 2. Concatenating values[0] to produce values (next case). + # + # But len(values) == 1 in both cases. The slice in (1) and concatenate + # in (2) are both no-ops, but we have to get the direction right. In the + # first case, values[0] is driven by values, and in the second case, + # values is driven by values[0]. + slicing = True + values = [None] + elif component_type is Input or component_type is Register: + values = [None for _ in range(self._size)] + + self._components = [None for i in range(len(schema))] + if slicing: + # Concatenated value was provided. Slice it into components. + _slice( + block=block, + schema=schema, + bitwidth=self._bitwidth, + component_type=component_type, + name=name, + concatenated=concatenated, + components=self._components, + concatenated_value=values[0], + ) + else: + if len(values) != len(schema): + msg = ( + "wire_matrix constructor expects 1 value to slice, or " + f"{len(schema)} values to concatenate (received {len(values)} " + "values)" + ) + raise PyrtlError(msg) + # Component values were provided; concatenate them. + _concatenate( + block=block, + schema=schema, + component_type=component_type, + name=name, + concatenated=concatenated, + components=self._components, + component_map=values, + ) - return _WireMatrix + def _getitem(self, key): + return self._components[key] + + def _len(self): + return len(self._components) + + return type( + class_name, + (WrappedWireVector,), + { + "_component_bitwidth": component_bitwidth, + "_component_schema": component_schema, + "_size": size, + "_bitwidth": component_bitwidth * size, + "_is_wire_matrix": True, + "__init__": _init, + "__getitem__": _getitem, + "__len__": _len, + }, + ) def one_hot_to_binary(w: WireVectorLike) -> WireVector: diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 3ad5bfa6..06ff8e01 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -1350,6 +1350,8 @@ def test_concatenate(self): """Drive high and low, observe concatenated high+low.""" # Concatenates to 'byte'. byte = Byte(name="byte", high=0xA, low=0xB) + self.assertTrue(isinstance(byte, Byte)) + self.assertEqual(Byte.__name__, "Byte") self.assertEqual(len(byte), 2) self.assertEqual(byte.bitwidth, 8) self.assertEqual(len(pyrtl.as_wires(byte)), 8) @@ -1394,6 +1396,21 @@ def test_slice(self): self.assertTrue(isinstance(byte.low, pyrtl.Const)) self.assertEqual(byte.low.val, 0xD) + def test_underscore_value_slice(self): + """Drive concatenated high+low by setting the special _value name, observe high + and low. + """ + # Slices to 'byte.high' and 'byte.low'. + byte = Byte(name="byte", _value=0xCD) + + # Constants are sliced immediately. + self.assertTrue(isinstance(pyrtl.as_wires(byte), pyrtl.Const)) + self.assertEqual(byte.val, 0xCD) + self.assertTrue(isinstance(byte.high, pyrtl.Const)) + self.assertEqual(byte.high.val, 0xC) + self.assertTrue(isinstance(byte.low, pyrtl.Const)) + self.assertEqual(byte.low.val, 0xD) + def test_input_slice(self): """Given Input concatenated high+low, observe high and low.""" byte = Byte(name="byte", concatenated_type=pyrtl.Input) @@ -1588,21 +1605,21 @@ def test_anonymous_pixel_concatenate(self): # BitPair is an array of two single wires. -BitPair = pyrtl.wire_matrix(component_schema=1, size=2) +BitPair = pyrtl.wire_matrix(component_schema=1, size=2, class_name="BitPair") # Word is an array of two Bytes. This checks that a @wire_struct (Byte) can be a # component of a wire_matrix (Word). -Word = pyrtl.wire_matrix(component_schema=Byte, size=2) +Word = pyrtl.wire_matrix(component_schema=Byte, size=2, class_name="Word") # ByteMatrix tests the corner case of a single-element wire_matrix. -ByteMatrix = pyrtl.wire_matrix(component_schema=Byte, size=1) +ByteMatrix = pyrtl.wire_matrix(component_schema=Byte, size=1, class_name="ByteMatrix") # DWord is an array of two Words, or effectively a 2x2 array of Bytes. This checks that # a wire_matrix (Word) can be a component of a wire_matrix (DWord). -DWord = pyrtl.wire_matrix(component_schema=Word, size=2) +DWord = pyrtl.wire_matrix(component_schema=Word, size=2, class_name="DWord") # CachedData is valid bit paired with some data. This checks that a wire_matrix (Word) @@ -1619,6 +1636,8 @@ def setUp(self): def test_wire_matrix_slice(self): bitpair = BitPair(name="bitpair", values=[2]) + self.assertTrue(isinstance(bitpair, BitPair)) + self.assertEqual(BitPair.__name__, "BitPair") self.assertEqual(len(bitpair), 2) self.assertEqual(bitpair.bitwidth, 2) self.assertEqual(len(pyrtl.as_wires(bitpair)), 2)