diff --git a/docs/rtllib.rst b/docs/rtllib.rst index 86317128..6b8204f7 100644 --- a/docs/rtllib.rst +++ b/docs/rtllib.rst @@ -23,6 +23,12 @@ Multipliers .. automodule:: pyrtl.rtllib.multipliers :members: +Floating Point +-------------- + +.. automodule:: pyrtl.rtllib.float + :members: + Barrel Shifter -------------- diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 055510e1..a7023617 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -1481,6 +1481,11 @@ def wire_struct(wire_struct_spec): Note how the class name (``Byte``) is used as a keyword arg for the constructor. + If the class name is not known, the special name ``_value`` can be used instead:: + + >>> UnknownDynamicType = Byte + >>> unknown_dynamic_type = UnknownDynamicType(_value=0xAB) + Accessing Slices ---------------- @@ -1651,164 +1656,167 @@ class CacheLine: # Name of the decorated class. class_name = wire_struct_spec.__name__ - class _WireStruct(WrappedWireVector): - """``wire_struct`` implementation: Concatenate or slice :class:`WireVector`. + def _init( + self, + name="", + block=None, + concatenated_type=WireVector, + component_type=WireVector, + **kwargs, + ): + """Concatenate or slice :class:`WireVector` components. - ``wire_struct`` works by either concatenating component :class:`WireVector` - to create the ``wire_struct``'s full value, *or* slicing a - ``wire_struct``s value to create component :class:`WireVectors`. A - ``wire_struct`` can only concatenate or slice, not both. The decision - to concatenate or slice is made in __init__. + The remaining keyword args specify values for all wires. If the concatenated + value is provided, its value must be provided with the keyword arg matching + the decorated class name. For example, if the decorated class is:: - """ + @wire_struct + class Byte: + high: 4 # high is the 4 most significant bits. + low: 4 # low is the 4 least significant bits. - _bitwidth = total_bitwidth - _class_name = class_name - _is_wire_struct = True - - def __init__( - self, - name="", - block=None, - concatenated_type=WireVector, - component_type=WireVector, - **kwargs, - ): - """Concatenate or slice :class:`WireVector` components. - - The remaining keyword args specify values for all wires. If the concatenated - value is provided, its value must be provided with the keyword arg matching - the decorated class name. For example, if the decorated class is:: - - @wire_struct - class Byte: - high: 4 # high is the 4 most significant bits. - low: 4 # low is the 4 least significant bits. - - then the concatenated value must be provided like this:: - - byte = Byte(Byte=0xAB) - - And if the component values are provided instead, their values are set by - keyword args matching the component names:: - - byte = Byte(low=0xA, high=0xB) - - :param str name: The name of the concatenated wire. Must be unique. If none - is provided, one will be autogenerated. If a name is provided, - components will be assigned names of the form "{name}.{component_name}". - :param Block block: The block containing the concatenated and component - wires. Defaults to the :ref:`working_block`. - :param type concatenated_type: Type for the concatenated - :class:`WireVector`. - :param type component_type: Type for each component. - """ - # The concatenated WireVector contains all the _WireStruct's wires. - # WrappedWireVector (base class) will forward all attribute and method - # accesses on this _WireStruct to the concatenated WireVector. - if ( - class_name in kwargs - and isinstance(kwargs[class_name], int) - and concatenated_type is WireVector - ): - # Special case: simplify the concatenated type to Const. - concatenated = Const( - bitwidth=self._bitwidth, - name=name, - block=block, - val=kwargs[class_name], - ) - else: - concatenated = concatenated_type( - bitwidth=self._bitwidth, name=name, block=block - ) - super().__init__(wire=concatenated) - - # self._components maps from component name to each component's WireVector. - components = {} - self.__dict__["_components"] = components - - # Handle Input and Register special cases. - if concatenated_type is Input or concatenated_type is Register: - kwargs = {class_name: None} - elif component_type is Input or component_type is Register: - kwargs = {component_meta.name: None for component_meta in schema} - - if class_name in kwargs: - # Check for unused kwargs. - for component_name in kwargs: - if component_name != class_name: - msg = ( - "Do not pass additional kwargs to @wire_struct when " - f'slicing. ("{class_name}" was passed so don\'t pass ' - f'"{component_name}")' - ) - raise PyrtlError(msg) - # Concatenated value was provided. Slice it into components. - _slice( - block=block, - schema=schema, - bitwidth=self._bitwidth, - component_type=component_type, - name=name, - concatenated=concatenated, - components=components, - concatenated_value=kwargs[class_name], - ) - else: - # Component values were provided; concatenate them. - # Check that values were provided for all components. - expected_component_names = [ - component_meta.name for component_meta in schema - ] - for expected_component_name in expected_component_names: - if expected_component_name not in kwargs: - msg = ( - "You must provide kwargs for all @wire_struct components " - "when concatenating (missing kwarg " - f'"{expected_component_name}")' - ) - raise PyrtlError(msg) - # Check for unused kwargs. - for component_name in kwargs: - if component_name not in expected_component_names: - msg = ( - "Do not pass additional kwargs to @wire_struct when " - f'concatenating (don\'t pass "{component_name}")' - ) - raise PyrtlError(msg) - - _concatenate( - block=block, - schema=schema, - component_type=component_type, - name=name, - concatenated=concatenated, - components=components, - component_map=kwargs, - ) + then the concatenated value must be provided like this:: - def __getattr__(self, component_name: str): - """Retrieve a component by name. + byte = Byte(Byte=0xAB) - Components are concatenated to form the concatenated :class:`WireVector`, or - sliced from the concatenated :class:`WireVector`. + And if the component values are provided instead, their values are set by + keyword args matching the component names:: - :param component_name: The name of the component wire. - """ - components = self.__dict__["_components"] - if component_name in components: - return components[component_name] - return super().__getattr__(component_name) + byte = Byte(low=0xA, high=0xB) - def __len__(self): - components = self.__dict__["_components"] - return len(components) + :param str name: The name of the concatenated wire. Must be unique. If none + is provided, one will be autogenerated. If a name is provided, + components will be assigned names of the form "{name}.{component_name}". + :param Block block: The block containing the concatenated and component + wires. Defaults to the :ref:`working_block`. + :param type concatenated_type: Type for the concatenated + :class:`WireVector`. + :param type component_type: Type for each component. + """ + # Special case: Replace a "_value" kwarg with the actual class name. This + # kwarg is useful when the class_name is not statically known. + if "_value" in kwargs: + kwargs[class_name] = kwargs["_value"] + del kwargs["_value"] + + # The concatenated WireVector contains all the _WireStruct's wires. + # WrappedWireVector (base class) will forward all attribute and method + # accesses on this _WireStruct to the concatenated WireVector. + if ( + class_name in kwargs + and isinstance(kwargs[class_name], int) + and concatenated_type is WireVector + ): + # Special case: simplify the concatenated type to Const. + concatenated = Const( + bitwidth=self._bitwidth, + name=name, + block=block, + val=kwargs[class_name], + ) + else: + concatenated = concatenated_type( + bitwidth=self._bitwidth, name=name, block=block + ) + WrappedWireVector.__init__(self, wire=concatenated) + + # self._components maps from component name to each component's WireVector. + components = {} + self.__dict__["_components"] = components + + # Handle Input and Register special cases. + if concatenated_type is Input or concatenated_type is Register: + kwargs = {class_name: None} + elif component_type is Input or component_type is Register: + kwargs = {component_meta.name: None for component_meta in schema} + + if class_name in kwargs: + # Check for unused kwargs. + for component_name in kwargs: + if component_name != class_name: + msg = ( + "Do not pass additional kwargs to @wire_struct when " + f'slicing. ("{class_name}" was passed so don\'t pass ' + f'"{component_name}")' + ) + raise PyrtlError(msg) + # Concatenated value was provided. Slice it into components. + _slice( + block=block, + schema=schema, + bitwidth=self._bitwidth, + component_type=component_type, + name=name, + concatenated=concatenated, + components=components, + concatenated_value=kwargs[class_name], + ) + else: + # Component values were provided; concatenate them. + # Check that values were provided for all components. + expected_component_names = [ + component_meta.name for component_meta in schema + ] + for expected_component_name in expected_component_names: + if expected_component_name not in kwargs: + msg = ( + "You must provide kwargs for all @wire_struct components " + "when concatenating (missing kwarg " + f'"{expected_component_name}")' + ) + raise PyrtlError(msg) + # Check for unused kwargs. + for component_name in kwargs: + if component_name not in expected_component_names: + msg = ( + "Do not pass additional kwargs to @wire_struct when " + f'concatenating (don\'t pass "{component_name}")' + ) + raise PyrtlError(msg) - return _WireStruct + _concatenate( + block=block, + schema=schema, + component_type=component_type, + name=name, + concatenated=concatenated, + components=components, + component_map=kwargs, + ) + + def _getattr(self, component_name: str): + """Retrieve a component by name. + Components are concatenated to form the concatenated :class:`WireVector`, or + sliced from the concatenated :class:`WireVector`. -def wire_matrix(component_schema, size: int): + :param component_name: The name of the component wire. + """ + components = self.__dict__["_components"] + if component_name in components: + return components[component_name] + return WrappedWireVector.__getattr__(self, component_name) + + def _len(self): + components = self.__dict__["_components"] + return len(components) + + return type( + class_name, + (WrappedWireVector,), + { + "_bitwidth": total_bitwidth, + "_class_name": class_name, + "_is_wire_struct": True, + "__init__": _init, + "__getattr__": _getattr, + "__len__": _len, + "__doc__": wire_struct_spec.__doc__, + }, + ) + + +def wire_matrix(component_schema, size: int, class_name: str | None = None): """Returns a class that assigns numbered indices to :class:`WireVector` slices. ``wire_matrix`` assigns numbered indices to *non-overlapping* :class:`WireVector` @@ -1948,6 +1956,9 @@ class Byte: No values are specified for ``input_word`` because its value is not known until simulation time. """ + if class_name is None: + class_name = "_WireMatrix" + # Determine each component's bitwidth. if hasattr(component_schema, "_is_wire_struct") or hasattr( component_schema, "_is_wire_matrix" @@ -1957,116 +1968,121 @@ class Byte: component_bitwidth = component_schema component_schema = None - class _WireMatrix(WrappedWireVector): - _component_bitwidth = component_bitwidth - _component_schema = component_schema - _size = size - - _bitwidth = component_bitwidth * size - _is_wire_matrix = True - - def __init__( - self, - name: str = "", - block: Block = None, - concatenated_type=WireVector, - component_type=WireVector, - values: list | None = None, + def _init( + self, + name: str = "", + block: Block = None, + concatenated_type=WireVector, + component_type=WireVector, + values: list | None = None, + ): + # The concatenated WireVector contains all the _WireMatrix's wires. + # WrappedWireVector (base class) will forward all attribute and method + # accesses on this _WireMatrix to the concatenated WireVector. + if values is None: + values = [] + if ( + len(values) == 1 + and isinstance(values[0], int) + and concatenated_type is WireVector ): - # The concatenated WireVector contains all the _WireMatrix's wires. - # WrappedWireVector (base class) will forward all attribute and method - # accesses on this _WireMatrix to the concatenated WireVector. - if values is None: - values = [] - if ( - len(values) == 1 - and isinstance(values[0], int) - and concatenated_type is WireVector - ): - # Special case: simplify the concatenated type to Const. - concatenated = Const( - bitwidth=self._bitwidth, name=name, block=block, val=values[0] - ) - else: - concatenated = concatenated_type( - bitwidth=self._bitwidth, name=name, block=block - ) - super().__init__(wire=concatenated) - - schema = [] - for component_name in range(self._size): - schema.append( - _ComponentMeta( - name=component_name, - bitwidth=self._component_bitwidth, - type=component_schema, - ) - ) + # Special case: simplify the concatenated type to Const. + concatenated = Const( + bitwidth=self._bitwidth, name=name, block=block, val=values[0] + ) + else: + concatenated = concatenated_type( + bitwidth=self._bitwidth, name=name, block=block + ) + WrappedWireVector.__init__(self, wire=concatenated) - # By default, slice the concatenated value into components iff exactly one - # value is provided. - slicing = len(values) == 1 - - # Handle Input and Register special cases. - if concatenated_type is Input or concatenated_type is Register: - # Slice the concatenated value. Override the default 'slicing' because - # 'values' is empty when slicing a concatenated Input or Register. - # - # Note that we can't just check len(values) == 1 after we set values to - # [None] because that doesn't work when there is only one element in the - # wire_matrix. We must distinguish between: - # - # 1. Slicing values to produce values[0] (this case). - # 2. Concatenating values[0] to produce values (next case). - # - # But len(values) == 1 in both cases. The slice in (1) and concatenate - # in (2) are both no-ops, but we have to get the direction right. In the - # first case, values[0] is driven by values, and in the second case, - # values is driven by values[0]. - slicing = True - values = [None] - elif component_type is Input or component_type is Register: - values = [None for _ in range(self._size)] - - self._components = [None for i in range(len(schema))] - if slicing: - # Concatenated value was provided. Slice it into components. - _slice( - block=block, - schema=schema, - bitwidth=self._bitwidth, - component_type=component_type, - name=name, - concatenated=concatenated, - components=self._components, - concatenated_value=values[0], - ) - else: - if len(values) != len(schema): - msg = ( - "wire_matrix constructor expects 1 value to slice, or " - f"{len(schema)} values to concatenate (received {len(values)} " - "values)" - ) - raise PyrtlError(msg) - # Component values were provided; concatenate them. - _concatenate( - block=block, - schema=schema, - component_type=component_type, - name=name, - concatenated=concatenated, - components=self._components, - component_map=values, + schema = [] + for component_name in range(self._size): + schema.append( + _ComponentMeta( + name=component_name, + bitwidth=self._component_bitwidth, + type=component_schema, ) + ) - def __getitem__(self, key): - return self._components[key] - - def __len__(self): - return len(self._components) + # By default, slice the concatenated value into components iff exactly one + # value is provided. + slicing = len(values) == 1 + + # Handle Input and Register special cases. + if concatenated_type is Input or concatenated_type is Register: + # Slice the concatenated value. Override the default 'slicing' because + # 'values' is empty when slicing a concatenated Input or Register. + # + # Note that we can't just check len(values) == 1 after we set values to + # [None] because that doesn't work when there is only one element in the + # wire_matrix. We must distinguish between: + # + # 1. Slicing values to produce values[0] (this case). + # 2. Concatenating values[0] to produce values (next case). + # + # But len(values) == 1 in both cases. The slice in (1) and concatenate + # in (2) are both no-ops, but we have to get the direction right. In the + # first case, values[0] is driven by values, and in the second case, + # values is driven by values[0]. + slicing = True + values = [None] + elif component_type is Input or component_type is Register: + values = [None for _ in range(self._size)] + + self._components = [None for i in range(len(schema))] + if slicing: + # Concatenated value was provided. Slice it into components. + _slice( + block=block, + schema=schema, + bitwidth=self._bitwidth, + component_type=component_type, + name=name, + concatenated=concatenated, + components=self._components, + concatenated_value=values[0], + ) + else: + if len(values) != len(schema): + msg = ( + "wire_matrix constructor expects 1 value to slice, or " + f"{len(schema)} values to concatenate (received {len(values)} " + "values)" + ) + raise PyrtlError(msg) + # Component values were provided; concatenate them. + _concatenate( + block=block, + schema=schema, + component_type=component_type, + name=name, + concatenated=concatenated, + components=self._components, + component_map=values, + ) - return _WireMatrix + def _getitem(self, key): + return self._components[key] + + def _len(self): + return len(self._components) + + return type( + class_name, + (WrappedWireVector,), + { + "_component_bitwidth": component_bitwidth, + "_component_schema": component_schema, + "_size": size, + "_bitwidth": component_bitwidth * size, + "_is_wire_matrix": True, + "__init__": _init, + "__getitem__": _getitem, + "__len__": _len, + }, + ) def one_hot_to_binary(w: WireVectorLike) -> WireVector: diff --git a/pyrtl/rtllib/float/__init__.py b/pyrtl/rtllib/float/__init__.py new file mode 100644 index 00000000..a1e8148f --- /dev/null +++ b/pyrtl/rtllib/float/__init__.py @@ -0,0 +1,29 @@ +""" +Add, subtract, and multiply floating point numbers. + +Several standard ``Float`` formats like :class:`Float16` and :class:`Float32` are +predefined. Users may also define custom floating point formats. + +The main operators are :func:`add`, :func:`sub`, and :func:`mult`. These operators all +accept and return one of these ``Float`` formats. Inputs to an operator must share +the same ``Float`` format. The operator's output will be in the same ``Float`` format as +its inputs. +""" + +from .add_sub import add, sub +from .multiplication import mult +from .types import BFloat16, Float16, Float32, Float64, RoundingMode +from .utils import get_default_rounding_mode, set_default_rounding_mode + +__all__ = [ + "BFloat16", + "Float16", + "Float32", + "Float64", + "RoundingMode", + "add", + "get_default_rounding_mode", + "mult", + "sub", + "set_default_rounding_mode", +] diff --git a/pyrtl/rtllib/pyrtlfloat/_add_sub.py b/pyrtl/rtllib/float/add_sub.py similarity index 54% rename from pyrtl/rtllib/pyrtlfloat/_add_sub.py rename to pyrtl/rtllib/float/add_sub.py index 74ff2d6c..786af73f 100644 --- a/pyrtl/rtllib/pyrtlfloat/_add_sub.py +++ b/pyrtl/rtllib/float/add_sub.py @@ -1,54 +1,117 @@ import pyrtl - -from ._float_utils import ( - _fp_wire_struct, +from pyrtl.rtllib.float.types import FloatType, RoundingMode +from pyrtl.rtllib.float.utils import ( _RawResult, _RawResultGRS, _round_rne, check_kinds, + get_default_rounding_mode, make_denormals_zero, - make_inf, - make_largest_finite_number, - make_nan, - make_zero, + make_inf_like, + make_largest_finite_number_like, + make_nan_like, ) -from ._types import FPTypeProperties, PyrtlFloatConfig, RoundingMode def add( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, -) -> pyrtl.WireVector: + operand_a: FloatType, operand_b: FloatType, rounding_mode: RoundingMode = None +) -> FloatType: + """Performs floating point addition. + + The two operands must share the same ``Float`` type. Adding different floating point + types is not supported. + + Denormalized numbers are not supported. Denormalized numbers will be flushed to + zero. + + The return value's ``Float`` type will match the operand ``Float`` type. For + example, if you ``add`` two :class:`~.Float16`, the result will be a + :class:`~.Float16`. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + The following example computes ``1.0 + 2.0``. This is a bare-metal example, directly + manipulating the raw ``sign``, ``exponent``, and ``mantissa`` in IEEE 754 16-bit + floating point representation. See `IEEE 754 Internal Representation + `_ + for more details:: + + >>> import pyrtl.rtllib.float as rtlfloat + + >>> a = rtlfloat.Float16(name="a", component_type=pyrtl.Input) + >>> b = rtlfloat.Float16(name="b", component_type=pyrtl.Input) + + >>> sum = rtlfloat.Float16(name="sum", Float16=None) + >>> sum <<= rtlfloat.add(a, b) + + >>> # IEEE 754 numbers are stored as a sign bit, mantissa, and exponent. The + >>> # represented number's absolute value is 1.{mantissa} * 2 ** {exponent} + >>> # + >>> # All mantissas have this implied `1` before the binary point. {mantissa} + >>> # only stores the bits after this implied `1` and binary point. + >>> # + >>> # IEEE 754 exponents are stored with a bias to simplify comparisons. An + >>> # exponent {x} is stored as {x + exponent_bias}. + >>> exponent_bias = 2 ** (sum.exponent.bitwidth - 1) - 1 + + >>> # Create a=1.0, represented as 1.0 * 2 ** 0. + >>> a_one = {"a.sign": 0, "a.exponent": 0 + exponent_bias, "a.mantissa": 0} + + >>> # Create b=2.0, represented as 1.0 * 2 ** 1. + >>> b_two = {"b.sign": 0, "b.exponent": 1 + exponent_bias, "b.mantissa": 0} + + >>> sim = pyrtl.Simulation() + >>> sim.step(a_one | b_two) + + >>> # The sum should be 3.0, represented as 0b1.1 * 2 ** 1. + >>> # Note that this 0b1.1 is in binary! Multiplying by 2 is equivalent to + >>> # left-shifting by 1, and 0b1.1 << 1 == 0b11, which is 3 in decimal. + >>> sim.inspect("sum.sign") + 0 + >>> sim.inspect("sum.exponent") - exponent_bias + 1 + >>> bin(sim.inspect("sum.mantissa")) + '0b1000000000' + >>> bin(1 << (sum.mantissa.bitwidth - 1)) + '0b1000000000' + + :param operand_a: + :param operand_b: + :param rounding_mode: Rounding mode, defaults to :attr:`~.RoundingMode.RNE`. The + default can be changed with :func:`.set_default_rounding_mode`. + + :return: The sum, as an instance of the operand ``Float`` type. """ - Performs floating point addition of two WireVectors. + if rounding_mode is None: + rounding_mode = get_default_rounding_mode() - :param config: Configuration for the floating point type and rounding mode. - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the addition as a WireVector. - """ - fp_type_props = config.fp_type_properties - rounding_mode = config.rounding_mode - num_exp_bits = fp_type_props.num_exponent_bits - num_mant_bits = fp_type_props.num_mantissa_bits - FP = _fp_wire_struct(num_exp_bits, num_mant_bits) + if type(operand_a) is not type(operand_b): + msg = ( + f"Different operand types ({type(operand_a)}, {type(operand_b)}) are not " + "supported." + ) + raise pyrtl.PyrtlError(msg) # Denormalized numbers are not supported, so we flush them to zero. - operands = tuple( - make_denormals_zero(fp_type_props, op) for op in (operand_a, operand_b) - ) - fps = _sort_operands(FP, operands) + operands = tuple(make_denormals_zero(op) for op in (operand_a, operand_b)) + sorted_operands = _sort_operands(operands) del operands # Align mantissas and compute both the addition and subtraction results. - smaller_mantissa_shifted_grs, larger_mantissa_extended = _align_mantissa(fps) + smaller_mantissa_shifted_grs, larger_mantissa_extended = _align_mantissa( + sorted_operands + ) sum_result, sum_carry = _add_operands( - fps[1].exponent, smaller_mantissa_shifted_grs, larger_mantissa_extended + sorted_operands[1].exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, ) difference_result, num_leading_zeros = _sub_operands( - num_mant_bits, - fps[1].exponent, + operand_a.mantissa.bitwidth, + sorted_operands[1].exponent, smaller_mantissa_shifted_grs, larger_mantissa_extended, ) @@ -56,8 +119,7 @@ def add( # Select the correct result based on operand signs, then round if needed. raw_result, rounding_exponent_incremented = _select_and_round( - fp_type_props, - fps, + sorted_operands, sum_result, difference_result, rounding_mode, @@ -65,9 +127,7 @@ def add( del sum_result, difference_result return _handle_special_cases( - FP, - fp_type_props, - fps, + sorted_operands, raw_result, sum_carry, num_leading_zeros, @@ -77,72 +137,122 @@ def add( def sub( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, -) -> pyrtl.WireVector: - """ - Performs floating point subtraction of two WireVectors. + operand_a: FloatType, operand_b: FloatType, rounding_mode: RoundingMode = None +) -> FloatType: + """Performs floating point subtraction. + + The two operands must share the same ``Float`` type. Subtracting different floating + point types is not supported. + + Denormalized numbers are not supported. Denormalized numbers will be flushed to + zero. + + The return value's ``Float`` type will match the operand ``Float`` type. For + example, if you ``sub`` two :class:`~.Float16`, the result will be a + :class:`~.Float16`. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + The following example computes ``1.0 - 2.0``. This is a bare-metal example, directly + manipulating the raw ``sign``, ``exponent``, and ``mantissa`` in IEEE 754 16-bit + floating point representation. See the documentation for :func:`add` and `IEEE 754 + Internal Representation + `_ + for more details:: + + >>> import pyrtl.rtllib.float as rtlfloat + + >>> a = rtlfloat.Float16(name="a", component_type=pyrtl.Input) + >>> b = rtlfloat.Float16(name="b", component_type=pyrtl.Input) + + >>> difference = rtlfloat.Float16(name="difference", Float16=None) + >>> difference <<= rtlfloat.sub(a, b) + + >>> # See the `add` example for IEEE 754 representation background. + >>> exponent_bias = 2 ** (difference.exponent.bitwidth - 1) - 1 + + >>> # Create a=1.0, represented as 1.0 * 2 ** 0. + >>> a_one = {"a.sign": 0, "a.exponent": 0 + exponent_bias, "a.mantissa": 0} + + >>> # Create b=2.0, represented as 1.0 * 2 ** 1. + >>> b_two = {"b.sign": 0, "b.exponent": 1 + exponent_bias, "b.mantissa": 0} - :param config: Configuration for the floating point type and rounding mode. - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the subtraction as a WireVector. + >>> sim = pyrtl.Simulation() + >>> sim.step(a_one | b_two) + + >>> # The difference should be -1.0, represented as -1.0 * 2 ** 0. + >>> sim.inspect("difference.sign") + 1 + >>> sim.inspect("difference.exponent") - exponent_bias + 0 + >>> sim.inspect("difference.mantissa") + 0 + + :param operand_a: + :param operand_b: + :param rounding_mode: Rounding mode, defaults to :attr:`~.RoundingMode.RNE`. The + default can be changed with :func:`.set_default_rounding_mode`. + + :return: The difference, as an instance of the operand ``Float`` type. """ - num_exp_bits = config.fp_type_properties.num_exponent_bits - num_mant_bits = config.fp_type_properties.num_mantissa_bits - operand_b_negated = operand_b ^ pyrtl.concat( - pyrtl.Const(1, bitwidth=1), - pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits), + operand_b_negated = type(operand_b)( + sign=~operand_b.sign, + exponent=operand_b.exponent, + mantissa=operand_b.mantissa, ) - return add(config, operand_a, operand_b_negated) + return add(operand_a, operand_b_negated, rounding_mode) def _sort_operands( - FP, - operands: tuple, -) -> tuple: - """ - Sorts operands by absolute value. + operands: tuple[FloatType, FloatType], +) -> tuple[FloatType, FloatType]: + """Sorts ``operands`` by absolute value. - :param FP: The FP wire_struct class for the current floating point type. - :param operands: Tuple of two operand WireVectors with denormals flushed to zero. - :return: Tuple of (smaller_fp, larger_fp) as FP wire_struct instances. + :param operands: Tuple of two operand ``Floats`` with denormals flushed to zero. + + :return: Tuple of ``(smaller_operand, larger_operand)``, as instances of the operand + type. """ - total_bits = operands[0].bitwidth - sorted_operands = [pyrtl.WireVector(bitwidth=total_bits) for _ in range(2)] + sorted_operands = [ + pyrtl.WireVector(bitwidth=operands[0].bitwidth) for _ in range(2) + ] with pyrtl.conditional_assignment: - # Compare the lower (total_bits - 1) bits, which excludes the sign bit, - # to determine the operand with the smaller absolute value. - with operands[0][: total_bits - 1] < operands[1][: total_bits - 1]: + # Compare the concatenated exponents and mantissas to determine the operand with + # the smaller absolute value. + with pyrtl.concat(operands[0].exponent, operands[0].mantissa) < pyrtl.concat( + operands[1].exponent, operands[1].mantissa + ): sorted_operands[0] |= operands[0] sorted_operands[1] |= operands[1] with pyrtl.otherwise: sorted_operands[0] |= operands[1] sorted_operands[1] |= operands[0] - return tuple(FP(FP=op) for op in sorted_operands) + return tuple(type(operands[0])(_value=op) for op in sorted_operands) def _align_mantissa( - fps: tuple, + sorted_operands: tuple[FloatType, FloatType], ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: - """ - Aligns the smaller mantissa to the larger operand's exponent and computes - the guard, round, and sticky (GRS) bits for RNE rounding. + """Aligns the smaller mantissa to the larger operand's exponent and computes the + guard, round, and sticky (``GRS``) bits for ``RNE`` rounding. - :param fps: Tuple of (smaller_fp, larger_fp) as FP wire_struct instances. - :return: Tuple of (smaller_mantissa_shifted_grs, larger_mantissa_extended). + :param sorted_operands: Tuple of ``(smaller_operand, larger_operand)`` ``Floats``. + + :return: Tuple of ``(smaller_mantissa_shifted_grs, larger_mantissa_extended)``. """ - num_mant_bits = fps[0].mantissa.bitwidth + num_mant_bits = sorted_operands[0].mantissa.bitwidth mantissas_with_leading_1 = tuple( - pyrtl.concat(pyrtl.Const(1), fp.mantissa) for fp in fps + pyrtl.concat(pyrtl.Const(1), operand.mantissa) for operand in sorted_operands ) # Align mantissas by shifting the smaller one to match the larger's exponent. # Shifting the mantissa right by one divides the value by two, while adding # one to the exponent multiplies the value by two. Doing both simultaneously # preserves the value while matching the operands' exponents for addition. - shift_amount = fps[1].exponent - fps[0].exponent + shift_amount = sorted_operands[1].exponent - sorted_operands[0].exponent smaller_mantissa_shifted = pyrtl.shift_right_logical( mantissas_with_leading_1[0], shift_amount ) @@ -192,26 +302,25 @@ def _align_mantissa( def _select_and_round( - fp_type_props: FPTypeProperties, - fps: tuple, + sorted_operands: tuple[FloatType, FloatType], sum_result: _RawResultGRS, diff_result: _RawResultGRS, rounding_mode: RoundingMode, ) -> tuple: - """ - Selects the addition or subtraction result based on operand signs, then - applies RNE rounding if configured. + """Selects the addition or subtraction result based on operand signs, then applies + ``RNE`` rounding if configured. - :param fp_type_props: Floating point type properties. - :param fps: Tuple of (smaller_fp, larger_fp) as FP wire_struct instances. - :param sum_result: _RawResultGRS from the addition operation. - :param diff_result: _RawResultGRS from the subtraction operation. + :param sorted_operands: Tuple of ``(smaller_operand, larger_operand)`` as + ``Floats``. + :param sum_result: ``_RawResultGRS`` from addition. + :param diff_result: ``_RawResultGRS`` from subtraction. :param rounding_mode: The rounding mode to apply. - :return: Tuple of (_RawResult, rounding_exponent_incremented). The second - element is None for RTZ rounding mode. + + :return: Tuple of ``(_RawResult, rounding_exponent_incremented)``. The second + element is ``None`` for ``RTZ`` rounding mode. """ - num_exp_bits = fp_type_props.num_exponent_bits - num_mant_bits = fp_type_props.num_mantissa_bits + num_exp_bits = sorted_operands[0].exponent.bitwidth + num_mant_bits = sorted_operands[0].mantissa.bitwidth raw_result = _RawResult( exponent=pyrtl.WireVector(bitwidth=num_exp_bits), @@ -224,7 +333,7 @@ def _select_and_round( with pyrtl.conditional_assignment: # If the operands have the same sign, we perform addition. # For example, (+a) + (+b) or (-a) + (-b). - with fps[0].sign == fps[1].sign: + with sorted_operands[0].sign == sorted_operands[1].sign: raw_result.exponent |= sum_result.exponent raw_result.mantissa |= sum_result.mantissa if rounding_mode == RoundingMode.RNE: @@ -244,39 +353,36 @@ def _select_and_round( def _handle_special_cases( - FP, - fp_type_props, - fps: tuple, + sorted_operands: tuple[FloatType, FloatType], raw_result: _RawResult, sum_carry: pyrtl.WireVector, num_leading_zeros: pyrtl.WireVector, rounding_mode: RoundingMode, rounding_exponent_incremented, -): - """ - Handles special cases: NaN, infinity, zero, overflow, and underflow. +) -> FloatType: + """Handles special cases: NaN, infinity, zero, overflow, and underflow. - :param FP: The FP wire_struct class for the current floating point type. - :param fp_type_props: Floating point type properties. - :param fps: Tuple of (smaller_fp, larger_fp) as FP wire_struct instances. - :param raw_result: Pre-rounding result as a _RawResult. - :param sum_carry: Carry bit from the addition operation. + :param sorted_operands: Tuple of ``(smaller_operand, larger_operand)`` as + ``Floats``. + :param raw_result: Pre-rounding result as a ``_RawResult``. + :param sum_carry: Carry bit from addition. :param num_leading_zeros: Leading zero count from the subtraction normalization. :param rounding_mode: The rounding mode being used. :param rounding_exponent_incremented: Whether rounding incremented the exponent - (None for RTZ mode). - :return: The final FP wire_struct result. + (``None`` for ``RTZ`` mode). + + :return: The final sum or difference, as an instance of the ``sorted_operand`` + ``Float`` type. """ - num_exp_bits = fp_type_props.num_exponent_bits + num_exp_bits = sorted_operands[0].exponent.bitwidth - operand_kinds = tuple(check_kinds(fp) for fp in fps) + operand_kinds = tuple(check_kinds(operand) for operand in sorted_operands) # Pre-compute special value constants for use inside conditional_assignment. - final_result = FP(sign=None, exponent=None, mantissa=None) - nan_exp, nan_mant = make_nan(fp_type_props) - inf_exp, inf_mant = make_inf(fp_type_props) - zero_exp, zero_mant = make_zero(fp_type_props) - largest_exp, largest_mant = make_largest_finite_number(fp_type_props) + final_result = type(sorted_operands[0])(sign=None, exponent=None, mantissa=None) + nan_exp, nan_mant = make_nan_like(sorted_operands[0]) + inf_exp, inf_mant = make_inf_like(sorted_operands[0]) + largest_exp, largest_mant = make_largest_finite_number_like(sorted_operands[0]) # Check for overflow on addition. # We check for overflow by calculating the max value of the larger @@ -339,46 +445,48 @@ def _handle_special_cases( | ( operand_kinds[0].is_inf & operand_kinds[1].is_inf - & (fps[1].sign != fps[0].sign) + & (sorted_operands[1].sign != sorted_operands[0].sign) ) ): - final_result.sign |= fps[1].sign + final_result.sign |= sorted_operands[1].sign final_result.exponent |= nan_exp final_result.mantissa |= nan_mant # If either operand is infinity, result is infinity with that sign. with operand_kinds[0].is_inf: - final_result.sign |= fps[1].sign + final_result.sign |= sorted_operands[1].sign final_result.exponent |= inf_exp final_result.mantissa |= inf_mant with operand_kinds[1].is_inf: - final_result.sign |= fps[1].sign + final_result.sign |= sorted_operands[1].sign final_result.exponent |= inf_exp final_result.mantissa |= inf_mant # If operands are equal in magnitude but opposite in sign, the result is +0. with ( - (fps[0].mantissa == fps[1].mantissa) - & (fps[0].exponent == fps[1].exponent) - & (fps[1].sign != fps[0].sign) + (sorted_operands[0].mantissa == sorted_operands[1].mantissa) + & (sorted_operands[0].exponent == sorted_operands[1].exponent) + & (sorted_operands[1].sign != sorted_operands[0].sign) ): final_result.sign |= 0 - final_result.exponent |= zero_exp - final_result.mantissa |= zero_mant + final_result.exponent |= 0 + final_result.mantissa |= 0 # If either operand is zero, the result is the other operand. with operand_kinds[0].is_zero: - final_result.sign |= fps[1].sign - final_result.mantissa |= fps[1].mantissa - final_result.exponent |= fps[1].exponent + final_result.sign |= sorted_operands[1].sign + final_result.mantissa |= sorted_operands[1].mantissa + final_result.exponent |= sorted_operands[1].exponent with operand_kinds[1].is_zero: - final_result.sign |= fps[0].sign - final_result.mantissa |= fps[0].mantissa - final_result.exponent |= fps[0].exponent + final_result.sign |= sorted_operands[0].sign + final_result.mantissa |= sorted_operands[0].mantissa + final_result.exponent |= sorted_operands[0].exponent # Checks if an addition was performed and the result overflowed. - with (fps[0].sign == fps[1].sign) & (fps[1].exponent > exponent_max_value): - final_result.sign |= fps[1].sign + with (sorted_operands[0].sign == sorted_operands[1].sign) & ( + sorted_operands[1].exponent > exponent_max_value + ): + final_result.sign |= sorted_operands[1].sign # IEEE 754 Section 7.4: On overflow, RNE rounds to infinity, # while truncation rounds to the largest finite number. if rounding_mode == RoundingMode.RNE: @@ -389,13 +497,15 @@ def _handle_special_cases( final_result.mantissa |= largest_mant # Checks if a subtraction was performed and the result underflowed. - with (fps[0].sign != fps[1].sign) & (fps[1].exponent < exponent_min_value): - final_result.sign |= fps[1].sign - final_result.exponent |= zero_exp - final_result.mantissa |= zero_mant + with (sorted_operands[0].sign != sorted_operands[1].sign) & ( + sorted_operands[1].exponent < exponent_min_value + ): + final_result.sign |= sorted_operands[1].sign + final_result.exponent |= 0 + final_result.mantissa |= 0 # Otherwise no special cases apply: this is the common case. with pyrtl.otherwise: - final_result.sign |= fps[1].sign + final_result.sign |= sorted_operands[1].sign final_result.exponent |= raw_result.exponent final_result.mantissa |= raw_result.mantissa @@ -407,14 +517,14 @@ def _add_operands( smaller_mantissa_shifted_grs: pyrtl.WireVector, larger_mantissa_extended: pyrtl.WireVector, ) -> tuple[_RawResultGRS, pyrtl.WireVector]: - """ - Helper function for performing addition of two floating point mantissas. + """Helper function for performing addition of two floating point mantissas. :param larger_operand_exponent: Exponent of the larger operand. - :param smaller_mantissa_shifted_grs: Mantissa of the smaller operand - shifted to align with the larger operand and concatenated with GRS. + :param smaller_mantissa_shifted_grs: Mantissa of the smaller operand shifted to + align with the larger operand and concatenated with ``GRS``. :param larger_mantissa_extended: Larger mantissa with three zeros. - :return: Tuple of (_RawResultGRS, carry bit). + + :return: Tuple of ``(_RawResultGRS, carry_bit)``. """ sum_mantissa_grs = pyrtl.WireVector() sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs @@ -430,7 +540,9 @@ def _add_operands( sum_exponent = pyrtl.select( sum_carry, larger_operand_exponent + 1, larger_operand_exponent ) - return _RawResultGRS(sum_exponent, sum_mantissa, sum_grs), sum_carry + return _RawResultGRS( + exponent=sum_exponent, mantissa=sum_mantissa, grs=sum_grs + ), sum_carry def _sub_operands( @@ -439,15 +551,15 @@ def _sub_operands( smaller_mantissa_shifted_grs: pyrtl.WireVector, larger_mantissa_extended: pyrtl.WireVector, ) -> tuple[_RawResultGRS, pyrtl.WireVector]: - """ - Helper function for performing subtraction of two floating point mantissas. + """Helper function for performing subtraction of two floating point mantissas. :param num_mant_bits: Number of mantissa bits. :param larger_operand_exponent: Exponent of the larger operand. - :param smaller_mantissa_shifted_grs: Mantissa of the smaller operand - shifted to align with the larger operand and concatenated with GRS. + :param smaller_mantissa_shifted_grs: Mantissa of the smaller operand shifted to + align with the larger operand and concatenated with ``GRS``. :param larger_mantissa_extended: Larger mantissa with three zeros. - :return: Tuple of (_RawResultGRS, num leading zeros). + + :return: Tuple of ``(_RawResultGRS, num_leading_zeros)``. """ # Priority encoder that counts the number of leading zeros in a WireVector. @@ -475,5 +587,5 @@ def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int): # Adjust the exponent by subtracting the number of leading zeros. difference_exponent = larger_operand_exponent - num_leading_zeros return _RawResultGRS( - difference_exponent, difference_mantissa, difference_grs + exponent=difference_exponent, mantissa=difference_mantissa, grs=difference_grs ), num_leading_zeros diff --git a/pyrtl/rtllib/pyrtlfloat/_multiplication.py b/pyrtl/rtllib/float/multiplication.py similarity index 65% rename from pyrtl/rtllib/pyrtlfloat/_multiplication.py rename to pyrtl/rtllib/float/multiplication.py index 7b3925b8..a938b27d 100644 --- a/pyrtl/rtllib/pyrtlfloat/_multiplication.py +++ b/pyrtl/rtllib/float/multiplication.py @@ -1,68 +1,117 @@ -import pyrtl +# This file should be named `mult.py` for symmetry with `add_sub.py`, but is instead +# named `multiplication.py` so the module has a name. + +# `__init__.py` creates `pyrtl.rtllib.float.mult`, collides with this module's name, if +# this file were named `mult.py`. +# +# This module currently needs to be named to run its `doctest`s. -from ._float_utils import ( - _fp_wire_struct, +import pyrtl +from pyrtl.rtllib.float.types import FloatType, RoundingMode +from pyrtl.rtllib.float.utils import ( _RawResult, _round_rne, check_kinds, + get_default_rounding_mode, make_denormals_zero, - make_inf, - make_largest_finite_number, - make_nan, - make_zero, + make_inf_like, + make_largest_finite_number_like, + make_nan_like, ) -from ._types import FPTypeProperties, PyrtlFloatConfig, RoundingMode -def mul( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, -) -> pyrtl.WireVector: - """ - Performs floating point multiplication of two WireVectors. +def mult( + operand_a: FloatType, operand_b: FloatType, rounding_mode: RoundingMode = None +) -> FloatType: + """Performs floating point multiplication. + + The two operands must share the same ``Float`` type. Multiplying different floating + point types is not supported. + + Denormalized numbers are not supported. Denormalized numbers will be flushed to + zero. + + The return value's ``Float`` type will match the operand ``Float`` type. For + example, if you ``mult`` two :class:`~.Float16`, the result will be a + :class:`~.Float16`. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + The following example computes ``2.0 * 4.0``. This is a bare-metal example, directly + manipulating the raw ``sign``, ``exponent``, and ``mantissa`` in IEEE 754 16-bit + floating point representation. See the documentation for :func:`add` and `IEEE 754 + Internal Representation + `_ + for more details:: + + >>> import pyrtl.rtllib.float as rtlfloat + + >>> a = rtlfloat.Float16(name="a", component_type=pyrtl.Input) + >>> b = rtlfloat.Float16(name="b", component_type=pyrtl.Input) + + >>> product = rtlfloat.Float16(name="product", Float16=None) + >>> product <<= rtlfloat.mult(a, b) + + >>> # See the `add` example for IEEE 754 representation background. + >>> exponent_bias = 2 ** (product.exponent.bitwidth - 1) - 1 + + >>> # Create a=2.0, represented as 1.0 * 2 ** 1. + >>> a_two = {"a.sign": 0, "a.exponent": 1 + exponent_bias, "a.mantissa": 0} + + >>> # Create b=4.0, represented as 1.0 * 2 ** 2. + >>> b_four = {"b.sign": 0, "b.exponent": 2 + exponent_bias, "b.mantissa": 0} + + >>> sim = pyrtl.Simulation() + >>> sim.step(a_two | b_four) - :param config: Configuration for the floating point type and rounding mode. - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the multiplication as a WireVector. + >>> # The product should be 8.0, represented as 1.0 * 2 ** 3. + >>> sim.inspect("product.sign") + 0 + >>> sim.inspect("product.exponent") - exponent_bias + 3 + >>> sim.inspect("product.mantissa") + 0 + + :param operand_a: + :param operand_b: + :param rounding_mode: Rounding mode, defaults to :attr:`~.RoundingMode.RNE`. The + default can be changed with :func:`.set_default_rounding_mode`. + + :return: The product, as an instance of the operand ``Float`` type. """ - fp_type_props = config.fp_type_properties - rounding_mode = config.rounding_mode - num_exp_bits = fp_type_props.num_exponent_bits - num_mant_bits = fp_type_props.num_mantissa_bits + if rounding_mode is None: + rounding_mode = get_default_rounding_mode() - # Denormalized numbers are not supported, so we flush them to zero. - operands = tuple( - make_denormals_zero(fp_type_props, op) for op in (operand_a, operand_b) - ) + if type(operand_a) is not type(operand_b): + msg = ( + f"Different operand types ({type(operand_a)}, {type(operand_b)}) are not " + "supported." + ) + raise pyrtl.PyrtlError(msg) - # Extract the sign, exponent, and mantissa of both operands. - FP = _fp_wire_struct(num_exp_bits, num_mant_bits) - fps = tuple(FP(FP=op) for op in operands) - del operands + # Denormalized numbers are not supported, so we flush them to zero. + operands = tuple(make_denormals_zero(op) for op in (operand_a, operand_b)) - result_sign = fps[0].sign ^ fps[1].sign + result_sign = operands[0].sign ^ operands[1].sign # Compute the product exponent and mantissa. - operand_exponent_sums, product_exponent, product_mantissa = _multiply( - fps, - num_exp_bits, - ) + operand_exponent_sums, product_exponent, product_mantissa = _multiply(operands) # Normalize the product and perform rounding. raw_result, need_to_normalize, exponent_incremented = _normalize_and_round( product_exponent, product_mantissa, - fp_type_props, + operand_a.exponent.bitwidth, + operand_a.mantissa.bitwidth, rounding_mode, ) del product_mantissa, product_exponent return _handle_special_cases( - FP, - fp_type_props, - fps, + operands, result_sign, raw_result, operand_exponent_sums, @@ -73,17 +122,14 @@ def mul( def _multiply( - fps: tuple, - num_exp_bits: int, + operands: tuple, ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: - """ - Computes the sum of operand exponents, the product exponent, and the raw - product mantissa. + """Computes the sum of operand exponents, the product exponent, and the raw product + mantissa. - :param fps: Tuple of FP wire_struct instances for the two operands. - :param num_exp_bits: Number of exponent bits. - :return: Tuple of (operand_exponent_sums, product_exponent, - product_mantissa). + :param operands: Tuple of ``Floats`` for the two operands. + + :return: Tuple of ``(operand_exponent_sums, product_exponent, product_mantissa)``. """ # IEEE-754 floating point numbers have a bias: # https://en.wikipedia.org/wiki/Exponent_bias @@ -91,12 +137,14 @@ def _multiply( # The sum of the stored exponents of the operands is (real0 + bias) + (real1 + bias) # = real0 + real1 + 2*bias. # Subtracting bias gives the stored exponent of the product: real0 + real1 + bias. - operand_exponent_sums = fps[0].exponent + fps[1].exponent - exponent_bias = 2 ** (num_exp_bits - 1) - 1 + operand_exponent_sums = operands[0].exponent + operands[1].exponent + exponent_bias = 2 ** (operands[0].exponent.bitwidth - 1) - 1 product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias) # Extract the mantissa of both operands and add the implicit leading 1. - mantissas = tuple(pyrtl.concat(pyrtl.Const(1), fp.mantissa) for fp in fps) + mantissas = tuple( + pyrtl.concat(pyrtl.Const(1), operand.mantissa) for operand in operands + ) product_mantissa = mantissas[0] * mantissas[1] return operand_exponent_sums, product_exponent, product_mantissa @@ -105,22 +153,19 @@ def _multiply( def _normalize_and_round( product_exponent: pyrtl.WireVector, product_mantissa: pyrtl.WireVector, - fp_type_props: FPTypeProperties, + num_exp_bits: int, + num_mant_bits: int, rounding_mode: RoundingMode, ) -> tuple: - """ - Normalizes the product mantissa and applies rounding if configured. + """Normalizes the product mantissa and applies rounding if configured. - :param product_exponent: The product exponent (sum of operand exponents - minus bias). + :param product_exponent: The product exponent (sum of operand exponents minus bias). :param product_mantissa: Raw product of the two mantissas (with implicit 1s). - :param fp_type_props: Floating point type properties. :param rounding_mode: The rounding mode to apply. - :return: Tuple of (_RawResult, need_to_normalize, exponent_incremented). - exponent_incremented is None for RTZ rounding mode. + + :return: Tuple of ``(_RawResult, need_to_normalize, exponent_incremented)``. + ``exponent_incremented`` is ``None`` for ``RTZ`` rounding mode. """ - num_exp_bits = fp_type_props.num_exponent_bits - num_mant_bits = fp_type_props.num_mantissa_bits # We're multiplying two numbers that both have the form 1. in # binary. The product's binary point sits just after its second-most # significant bit, giving the form ab.cdef... where each letter is one bit. @@ -185,9 +230,7 @@ def _normalize_and_round( def _handle_special_cases( - FP, - fp_type_props: FPTypeProperties, - fps: tuple, + operands: tuple, result_sign: pyrtl.WireVector, raw_result: _RawResult, operand_exponent_sums: pyrtl.WireVector, @@ -195,34 +238,31 @@ def _handle_special_cases( exponent_incremented, rounding_mode: RoundingMode, ): - """ - Handles special cases: NaN, infinity, zero, overflow, and underflow. + """Handles special cases: NaN, infinity, zero, overflow, and underflow. - :param FP: The FP wire_struct class for the current floating point type. - :param fp_type_props: Floating point type properties. - :param fps: Tuple of FP wire_struct instances for the two operands. + :param operands: Tuple of ``Floats`` for the two operands. :param result_sign: Sign bit of the result. - :param raw_result: Normalized (and possibly rounded) result as a _RawResult. + :param raw_result: Normalized (and possibly rounded) result as a ``_RawResult``. :param operand_exponent_sums: Sum of the two operand exponents. :param need_to_normalize: Whether the product mantissa required normalization. - :param exponent_incremented: Whether rounding incremented the exponent - (None for RTZ mode). + :param exponent_incremented: Whether rounding incremented the exponent (``None`` for + ``RTZ`` mode). :param rounding_mode: The rounding mode being used. - :return: The final FP wire_struct result. + + :return: The final product, as an instance of the operand ``Float`` type. """ - num_exp_bits = fp_type_props.num_exponent_bits + num_exp_bits = operands[0].exponent.bitwidth exponent_bias = 2 ** (num_exp_bits - 1) - 1 # Check whether operands are special: NaN, infinity, zero, or denormalized. - operand_kinds = tuple(check_kinds(fp) for fp in fps) + operand_kinds = tuple(check_kinds(operand) for operand in operands) # Pre-compute special value constants for use inside conditional_assignment. - result = FP(sign=None, exponent=None, mantissa=None) + result = type(operands[0])(sign=None, exponent=None, mantissa=None) result.sign <<= result_sign - nan_exp, nan_mant = make_nan(fp_type_props) - inf_exp, inf_mant = make_inf(fp_type_props) - zero_exp, zero_mant = make_zero(fp_type_props) - largest_exp, largest_mant = make_largest_finite_number(fp_type_props) + nan_exp, nan_mant = make_nan_like(operands[0]) + inf_exp, inf_mant = make_inf_like(operands[0]) + largest_exp, largest_mant = make_largest_finite_number_like(operands[0]) # We check for overflow and underflow by computing max and min exponent # values of the sum of operands' exponent before rounding and normalization. @@ -293,8 +333,8 @@ def _handle_special_cases( | operand_kinds[0].is_denormalized | operand_kinds[1].is_denormalized ): - result.exponent |= zero_exp - result.mantissa |= zero_mant + result.exponent |= 0 + result.mantissa |= 0 # Otherwise no special cases apply: this is the common case. with pyrtl.otherwise: result.exponent |= raw_result.exponent diff --git a/pyrtl/rtllib/float/types.py b/pyrtl/rtllib/float/types.py new file mode 100644 index 00000000..290c60c2 --- /dev/null +++ b/pyrtl/rtllib/float/types.py @@ -0,0 +1,77 @@ +from enum import Enum + +import pyrtl + + +class RoundingMode(Enum): + """Enum representing different rounding modes.""" + + RTZ = 1 + """Round towards zero (truncate).""" + + RNE = 2 + """Round to nearest, ties to even (default mode).""" + + +@pyrtl.wire_struct +class BFloat16: + """:class:`~pyrtl.wire_struct` representation of Google's ``bfloat16`` 16-bit + floating point format. + + :ivar sign: 1 bit + :ivar exponent: 8 bits + :ivar mantissa: 7 bits + """ + + sign: 1 + exponent: 8 + mantissa: 7 + + +@pyrtl.wire_struct +class Float16: + """:class:`~pyrtl.wire_struct` representation of IEEE 754 16-bit floating point + format. + + :ivar sign: 1 bit + :ivar exponent: 5 bits + :ivar mantissa: 10 bits + """ + + sign: 1 + exponent: 5 + mantissa: 10 + + +@pyrtl.wire_struct +class Float32: + """:class:`~pyrtl.wire_struct` representation of IEEE 754 32-bit floating point + format. + + :ivar sign: 1 bit + :ivar exponent: 8 bits + :ivar mantissa: 23 bits + """ + + sign: 1 + exponent: 8 + mantissa: 23 + + +@pyrtl.wire_struct +class Float64: + """:class:`~pyrtl.wire_struct` representation of IEEE 754 64-bit floating point + format. + + :ivar sign: 1 bit + :ivar exponent: 11 bits + :ivar mantissa: 52 bits + """ + + sign: 1 + exponent: 11 + mantissa: 52 + + +FloatType = BFloat16 | Float16 | Float32 | Float64 +"""Type alias for any floating point type.""" diff --git a/pyrtl/rtllib/float/utils.py b/pyrtl/rtllib/float/utils.py new file mode 100644 index 00000000..4c8f58ae --- /dev/null +++ b/pyrtl/rtllib/float/utils.py @@ -0,0 +1,190 @@ +from dataclasses import dataclass + +import pyrtl +from pyrtl.rtllib.float.types import FloatType, RoundingMode + +_default_rounding_mode = RoundingMode.RNE + + +def set_default_rounding_mode(rounding_mode: RoundingMode) -> None: + """Use ``rounding_mode`` by default for all future floating point operations. + + :param rounding_mode: + """ + global _default_rounding_mode + _default_rounding_mode = rounding_mode + + +def get_default_rounding_mode() -> RoundingMode: + """Return the current default rounding mode for floating point operations. + + :returns: + """ + return _default_rounding_mode + + +@pyrtl.wire_struct +class _GRS: + """Guard, round, and sticky (``GRS``) bits used for ``RNE`` rounding.""" + + guard: 1 + round: 1 + sticky: 1 + + +@pyrtl.wire_struct +class _FPKinds: + """Bits indicating the kind of a floating-point number.""" + + is_nan: 1 + is_inf: 1 + is_zero: 1 + is_denormalized: 1 + + +@dataclass +class _RawResult: + """Groups the ``exponent`` and ``mantissa`` ``WireVectors`` of a result.""" + + exponent: pyrtl.WireVector + mantissa: pyrtl.WireVector + + +@dataclass +class _RawResultGRS: + """ + Groups the ``exponent``, ``mantissa``, and ``GRS`` ``WireVectors`` of a result. + """ + + exponent: pyrtl.WireVector + mantissa: pyrtl.WireVector + grs: pyrtl.WireVector + + +def check_kinds(fp: FloatType) -> _FPKinds: + """Returns a ``_FPKinds`` ``wire_struct`` indicating the kind of the given floating + point number. + """ + max_exp = (1 << fp.exponent.bitwidth) - 1 + all_ones_exp = fp.exponent == max_exp + zero_exp = fp.exponent == 0 + zero_mant = fp.mantissa == 0 + return _FPKinds( + is_nan=all_ones_exp & ~zero_mant, + is_inf=all_ones_exp & zero_mant, + is_zero=zero_exp & zero_mant, + is_denormalized=zero_exp & ~zero_mant, + ) + + +def make_denormals_zero(operand: FloatType) -> FloatType: + """Returns zero if ``operand`` is denormalized, otherwise returns ``operand``. + https://en.wikipedia.org/wiki/Subnormal_number + """ + out = type(operand)(sign=operand.sign, exponent=operand.exponent, mantissa=None) + with pyrtl.conditional_assignment: + with operand.exponent == 0: + out.mantissa |= 0 + with pyrtl.otherwise: + out.mantissa |= operand.mantissa + + return out + + +def make_inf_like(operand: FloatType) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: + """Returns ``(exponent, mantissa)`` ``WireVectors`` representing infinity, with + bitwidths matching ``operand``. + + :param operand: A ``Float`` that determines the exponent and mantissa bitwidths. + + :return: Tuple of ``(exponent, mantissa)`` ``WireVectors``. + """ + num_exp_bits = operand.exponent.bitwidth + num_mant_bits = operand.mantissa.bitwidth + return ( + pyrtl.Const((1 << num_exp_bits) - 1, bitwidth=num_exp_bits), + pyrtl.Const(0, bitwidth=num_mant_bits), + ) + + +def make_nan_like(operand: FloatType) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: + """Returns ``(exponent, mantissa)`` ``WireVectors`` representing NaN, with + bitwidths matching ``operand``. + + :param operand: A ``Float`` that determines the exponent and mantissa bitwidths. + + :return: Tuple of ``(exponent, mantissa)`` ``WireVectors``. + """ + num_exp_bits = operand.exponent.bitwidth + num_mant_bits = operand.mantissa.bitwidth + return ( + pyrtl.Const((1 << num_exp_bits) - 1, bitwidth=num_exp_bits), + pyrtl.Const(1 << (num_mant_bits - 1), bitwidth=num_mant_bits), + ) + + +def make_largest_finite_number_like( + operand: FloatType, +) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: + """Returns ``(exponent, mantissa)`` ``WireVectors`` representing the largest finite + number, with bitwidths matching ``operand``. + + :param operand: A ``Float`` that determines the exponent and mantissa bitwidths. + + :return: Tuple of ``(exponent, mantissa)`` ``WireVectors``. + """ + num_exp_bits = operand.exponent.bitwidth + num_mant_bits = operand.mantissa.bitwidth + return ( + pyrtl.Const((1 << num_exp_bits) - 2, bitwidth=num_exp_bits), + pyrtl.Const((1 << num_mant_bits) - 1, bitwidth=num_mant_bits), + ) + + +def _round_rne( + raw_result: _RawResult, + raw_grs: pyrtl.WireVector, +) -> tuple: + """Round the floating point result using round to nearest, ties to even (``RNE``). + + Uses the ``GRS`` bits to determine if the result needs to be rounded up. + + :param raw_result: Pre-rounding result as a ``_RawResult``. + :param raw_grs: ``GRS`` bits of the raw result before rounding (guard=MSB, + sticky=LSB). + + :return: Tuple of ``(rounded _RawResult, rounding_exponent_incremented)``. + """ + num_mant_bits = raw_result.mantissa.bitwidth + num_exp_bits = raw_result.exponent.bitwidth + grs = _GRS(_GRS=raw_grs) + last = raw_result.mantissa[0] + # If guard bit is not set, number is closer to smaller value: no round up. + # If guard bit is set and round or sticky is set, round up. + # If guard bit is set but round and sticky are not set, value is exactly + # halfway. Following round-to-nearest ties-to-even, round up if last bit + # of mantissa is 1 (to make it even); otherwise do not round up. + # https://drilian.com/posts/2023.01.10-floating-point-numbers-and-rounding/ + round_up = grs.guard & (last | grs.round | grs.sticky) + rounded = _RawResult( + exponent=pyrtl.WireVector(bitwidth=num_exp_bits), + mantissa=pyrtl.WireVector(bitwidth=num_mant_bits), + ) + # Whether exponent was incremented due to rounding (for overflow check). + rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1) + with pyrtl.conditional_assignment: + with round_up: + # If rounding causes a mantissa overflow, we need to increment the exponent. + with raw_result.mantissa == (1 << num_mant_bits) - 1: + rounded.mantissa |= 0 + rounded.exponent |= raw_result.exponent + 1 + rounding_exponent_incremented |= 1 + with pyrtl.otherwise: + rounded.mantissa |= raw_result.mantissa + 1 + rounded.exponent |= raw_result.exponent + rounding_exponent_incremented |= 0 + with pyrtl.otherwise: + rounded.mantissa |= raw_result.mantissa + rounded.exponent |= raw_result.exponent + rounding_exponent_incremented |= 0 + return rounded, rounding_exponent_incremented diff --git a/pyrtl/rtllib/pyrtlfloat/__init__.py b/pyrtl/rtllib/pyrtlfloat/__init__.py deleted file mode 100644 index d9b64710..00000000 --- a/pyrtl/rtllib/pyrtlfloat/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode -from .floatoperations import ( - BFloat16Operations, - Float16Operations, - Float32Operations, - Float64Operations, - FloatOperations, -) - -__all__ = [ - "FloatingPointType", - "FPTypeProperties", - "PyrtlFloatConfig", - "RoundingMode", - "FloatOperations", - "BFloat16Operations", - "Float16Operations", - "Float32Operations", - "Float64Operations", -] diff --git a/pyrtl/rtllib/pyrtlfloat/_float_utils.py b/pyrtl/rtllib/pyrtlfloat/_float_utils.py deleted file mode 100644 index b9026653..00000000 --- a/pyrtl/rtllib/pyrtlfloat/_float_utils.py +++ /dev/null @@ -1,222 +0,0 @@ -import pyrtl - -from ._types import FPTypeProperties - - -def _fp_wire_struct(num_exp_bits, num_mant_bits): - """Creates a wire_struct class for an IEEE 754 floating point number. - - The returned class has three fields: sign (1 bit), exponent, and mantissa. - - :param num_exp_bits: Number of exponent bits. - :param num_mant_bits: Number of mantissa bits. - :return: A wire_struct class with sign, exponent, and mantissa fields. - """ - - @pyrtl.wire_struct - class FP: - sign: 1 - exponent: num_exp_bits - mantissa: num_mant_bits - - return FP - - -@pyrtl.wire_struct -class _GRS: - """Guard, round, and sticky bits used for RNE rounding.""" - - guard: 1 - round: 1 - sticky: 1 - - -@pyrtl.wire_struct -class _FPKinds: - """Bits indicating the kind of a floating-point number.""" - - is_nan: 1 - is_inf: 1 - is_zero: 1 - is_denormalized: 1 - - -class _RawResult: - """Groups the exponent and mantissa WireVectors of a result.""" - - def __init__( - self, - exponent: pyrtl.WireVector, - mantissa: pyrtl.WireVector, - ): - self.exponent = exponent - self.mantissa = mantissa - - -class _RawResultGRS(_RawResult): - """Groups the exponent, mantissa, and GRS WireVectors of a result.""" - - def __init__( - self, - exponent: pyrtl.WireVector, - mantissa: pyrtl.WireVector, - grs: pyrtl.WireVector, - ): - super().__init__(exponent, mantissa) - self.grs = grs - - -def check_kinds(fp) -> _FPKinds: - """ - Returns an _FPKinds wire struct indicating the kind of the given floating point - number. - - :param fp: FP wire_struct instance. - :return: _FPKinds instance. - """ - kinds = _FPKinds(is_nan=None, is_inf=None, is_zero=None, is_denormalized=None) - max_exp = (1 << fp.exponent.bitwidth) - 1 - all_ones_exp = fp.exponent == max_exp - zero_exp = fp.exponent == 0 - zero_mant = fp.mantissa == 0 - kinds.is_nan <<= all_ones_exp & ~zero_mant - kinds.is_inf <<= all_ones_exp & zero_mant - kinds.is_zero <<= zero_exp & zero_mant - kinds.is_denormalized <<= zero_exp & ~zero_mant - return kinds - - -def make_denormals_zero( - fp_prop: FPTypeProperties, wire: pyrtl.WireVector -) -> pyrtl.WireVector: - """ - Returns zero if denormalized, else original number. - https://en.wikipedia.org/wiki/Subnormal_number - - :param fp_prop: Floating point type properties. - :param wire: WireVector holding the floating point number. - :return: WireVector holding the resulting floating point number. - """ - FP = _fp_wire_struct(fp_prop.num_exponent_bits, fp_prop.num_mantissa_bits) - fp = FP(FP=wire) - out = pyrtl.WireVector( - bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1 - ) - with pyrtl.conditional_assignment: - with fp.exponent == 0: - out |= pyrtl.concat( - fp.sign, - fp.exponent, - pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits), - ) - with pyrtl.otherwise: - out |= wire - return out - - -def make_inf(fp_props: FPTypeProperties) -> tuple: - """ - Returns (exponent, mantissa) WireVectors representing infinity. - - :param fp_props: Floating point type properties. - :return: Tuple of (exponent, mantissa) WireVectors. - """ - num_exp_bits = fp_props.num_exponent_bits - num_mant_bits = fp_props.num_mantissa_bits - return ( - pyrtl.Const((1 << num_exp_bits) - 1, bitwidth=num_exp_bits), - pyrtl.Const(0, bitwidth=num_mant_bits), - ) - - -def make_nan(fp_props: FPTypeProperties) -> tuple: - """ - Returns (exponent, mantissa) WireVectors representing NaN. - - :param fp_props: Floating point type properties. - :return: Tuple of (exponent, mantissa) WireVectors. - """ - num_exp_bits = fp_props.num_exponent_bits - num_mant_bits = fp_props.num_mantissa_bits - return ( - pyrtl.Const((1 << num_exp_bits) - 1, bitwidth=num_exp_bits), - pyrtl.Const(1 << (num_mant_bits - 1), bitwidth=num_mant_bits), - ) - - -def make_zero(fp_props: FPTypeProperties) -> tuple: - """ - Returns (exponent, mantissa) WireVectors representing zero. - - :param fp_props: Floating point type properties. - :return: Tuple of (exponent, mantissa) WireVectors. - """ - num_exp_bits = fp_props.num_exponent_bits - num_mant_bits = fp_props.num_mantissa_bits - return ( - pyrtl.Const(0, bitwidth=num_exp_bits), - pyrtl.Const(0, bitwidth=num_mant_bits), - ) - - -def make_largest_finite_number(fp_props: FPTypeProperties) -> tuple: - """ - Returns (exponent, mantissa) WireVectors representing the largest finite number. - - :param fp_props: Floating point type properties. - :return: Tuple of (exponent, mantissa) WireVectors. - """ - num_exp_bits = fp_props.num_exponent_bits - num_mant_bits = fp_props.num_mantissa_bits - return ( - pyrtl.Const((1 << num_exp_bits) - 2, bitwidth=num_exp_bits), - pyrtl.Const((1 << num_mant_bits) - 1, bitwidth=num_mant_bits), - ) - - -def _round_rne( - raw_result: _RawResult, - raw_grs: pyrtl.WireVector, -) -> tuple: - """ - Round the floating point result using round to nearest, ties to even (RNE). - - Uses the GRS bits to determine if the result needs to be rounded up. - - :param raw_result: Pre-rounding result as a _RawResult. - :param raw_grs: GRS bits of the raw result before rounding (guard=MSB, sticky=LSB). - :return: Tuple of (rounded _RawResult, rounding_exponent_incremented). - """ - num_mant_bits = raw_result.mantissa.bitwidth - num_exp_bits = raw_result.exponent.bitwidth - grs = _GRS(_GRS=raw_grs) - last = raw_result.mantissa[0] - # If guard bit is not set, number is closer to smaller value: no round up. - # If guard bit is set and round or sticky is set, round up. - # If guard bit is set but round and sticky are not set, value is exactly - # halfway. Following round-to-nearest ties-to-even, round up if last bit - # of mantissa is 1 (to make it even); otherwise do not round up. - # https://drilian.com/posts/2023.01.10-floating-point-numbers-and-rounding/ - round_up = grs.guard & (last | grs.round | grs.sticky) - rounded = _RawResult( - exponent=pyrtl.WireVector(bitwidth=num_exp_bits), - mantissa=pyrtl.WireVector(bitwidth=num_mant_bits), - ) - # Whether exponent was incremented due to rounding (for overflow check). - rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1) - with pyrtl.conditional_assignment: - with round_up: - # If rounding causes a mantissa overflow, we need to increment the exponent. - with raw_result.mantissa == (1 << num_mant_bits) - 1: - rounded.mantissa |= 0 - rounded.exponent |= raw_result.exponent + 1 - rounding_exponent_incremented |= 1 - with pyrtl.otherwise: - rounded.mantissa |= raw_result.mantissa + 1 - rounded.exponent |= raw_result.exponent - rounding_exponent_incremented |= 0 - with pyrtl.otherwise: - rounded.mantissa |= raw_result.mantissa - rounded.exponent |= raw_result.exponent - rounding_exponent_incremented |= 0 - return rounded, rounding_exponent_incremented diff --git a/pyrtl/rtllib/pyrtlfloat/_types.py b/pyrtl/rtllib/pyrtlfloat/_types.py deleted file mode 100644 index a18df565..00000000 --- a/pyrtl/rtllib/pyrtlfloat/_types.py +++ /dev/null @@ -1,61 +0,0 @@ -from dataclasses import dataclass -from enum import Enum - - -class RoundingMode(Enum): - """ - Enum representing different rounding modes. - - Attributes: - RTZ (int): Round towards zero (truncate). - RNE (int): Round to nearest, ties to even (default mode). - """ - - RTZ = 1 - RNE = 2 - - -@dataclass(frozen=True) -class FPTypeProperties: - """ - Data class representing properties of a floating-point type. - - Attributes: - num_exponent_bits (int): Number of bits used for the exponent. - num_mantissa_bits (int): Number of bits used for the mantissa. - """ - - num_exponent_bits: int - num_mantissa_bits: int - - -class FloatingPointType(Enum): - """ - Enum representing different floating-point types. - - Attributes: - BFLOAT16 (FPTypeProperties): BFloat16 type properties. - FLOAT16 (FPTypeProperties): Float16 type properties. - FLOAT32 (FPTypeProperties): Float32 type properties. - FLOAT64 (FPTypeProperties): Float64 type properties. - """ - - BFLOAT16 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=7) - FLOAT16 = FPTypeProperties(num_exponent_bits=5, num_mantissa_bits=10) - FLOAT32 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=23) - FLOAT64 = FPTypeProperties(num_exponent_bits=11, num_mantissa_bits=52) - - -@dataclass(frozen=True) -class PyrtlFloatConfig: - """ - Data class representing the configuration for PyrtlFloat operations (floating point - type properties and rounding mode). - - Attributes: - fp_type_properties (FPTypeProperties): Properties of the floating-point type. - rounding_mode (RoundingMode): Rounding mode to be used. - """ - - fp_type_properties: FPTypeProperties - rounding_mode: RoundingMode diff --git a/pyrtl/rtllib/pyrtlfloat/floatoperations.py b/pyrtl/rtllib/pyrtlfloat/floatoperations.py deleted file mode 100644 index c3bdf98b..00000000 --- a/pyrtl/rtllib/pyrtlfloat/floatoperations.py +++ /dev/null @@ -1,180 +0,0 @@ -import pyrtl - -from ._add_sub import add, sub -from ._multiplication import mul -from ._types import FloatingPointType, PyrtlFloatConfig, RoundingMode - - -def _validate_operand_bitwidths( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, -) -> None: - """Validate that operand bitwidths match the floating point config.""" - fp_props = config.fp_type_properties - expected_bitwidth = fp_props.num_exponent_bits + fp_props.num_mantissa_bits + 1 - if operand_a.bitwidth != expected_bitwidth: - msg = ( - f"operand_a bitwidth {operand_a.bitwidth} does not match expected " - f"bitwidth {expected_bitwidth} for floating point type" - ) - raise pyrtl.PyrtlError(msg) - if operand_b.bitwidth != expected_bitwidth: - msg = ( - f"operand_b bitwidth {operand_b.bitwidth} does not match expected " - f"bitwidth {expected_bitwidth} for floating point type" - ) - raise pyrtl.PyrtlError(msg) - - -class FloatOperations: - """ - The rounding mode used for typed floating-point operations. - To change it, set this variable to the desired RoundingMode value. - """ - - default_rounding_mode = RoundingMode.RNE - - @staticmethod - def mul( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, - ) -> pyrtl.WireVector: - """ - Performs floating point multiplication of two WireVectors. The bitwidth of - the operands must be num_exponent_bits + num_mantissa_bits + 1, where - num_exponent_bits and num_mantissa_bits are defined in the config. - - :param config: Configuration for the floating point type and rounding mode. - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the multiplication as a WireVector. - :raises PyrtlError: If operand bitwidths don't match config. - """ - _validate_operand_bitwidths(config, operand_a, operand_b) - return mul(config, operand_a, operand_b) - - @staticmethod - def add( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, - ) -> pyrtl.WireVector: - """ - Performs floating point addition of two WireVectors. The bitwidth of - the operands must be num_exponent_bits + num_mantissa_bits + 1, where - num_exponent_bits and num_mantissa_bits are defined in the config. - - :param config: Configuration for the floating point type and rounding mode. - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the addition as a WireVector. - :raises PyrtlError: If operand bitwidths don't match config. - """ - _validate_operand_bitwidths(config, operand_a, operand_b) - return add(config, operand_a, operand_b) - - @staticmethod - def sub( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, - ) -> pyrtl.WireVector: - """ - Performs floating point subtraction of two WireVectors. The bitwidth of - the operands must be num_exponent_bits + num_mantissa_bits + 1, where - num_exponent_bits and num_mantissa_bits are defined in the config. - - :param config: Configuration for the floating point type and rounding mode. - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the subtraction as a WireVector. - :raises PyrtlError: If operand bitwidths don't match config. - """ - _validate_operand_bitwidths(config, operand_a, operand_b) - return sub(config, operand_a, operand_b) - - -class _BaseTypedFloatOperations: - _fp_type: FloatingPointType = None - - @classmethod - def mul( - cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector - ) -> pyrtl.WireVector: - """ - Performs floating point multiplication of two WireVectors. The bitwidth of - the operands must match the bitwidth of the floating point type of this class. - - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the multiplication as a WireVector. - """ - return FloatOperations.mul(cls._get_config(), operand_a, operand_b) - - @classmethod - def add( - cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector - ) -> pyrtl.WireVector: - """ - Performs floating point addition of two WireVectors. The bitwidth of - the operands must match the bitwidth of the floating point type of this class. - - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the addition as a WireVector. - """ - return FloatOperations.add(cls._get_config(), operand_a, operand_b) - - @classmethod - def sub( - cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector - ) -> pyrtl.WireVector: - """ - Performs floating point subtraction of two WireVectors. The bitwidth of - the operands must match the bitwidth of the floating point type of this class. - - :param operand_a: The first floating point operand as a WireVector. - :param operand_b: The second floating point operand as a WireVector. - :return: The result of the subtraction as a WireVector. - """ - return FloatOperations.sub(cls._get_config(), operand_a, operand_b) - - @classmethod - def _get_config(cls) -> PyrtlFloatConfig: - return PyrtlFloatConfig( - cls._fp_type.value, FloatOperations.default_rounding_mode - ) - - -class BFloat16Operations(_BaseTypedFloatOperations): - """ - Operations for BFloat16 floating point type. - """ - - _fp_type = FloatingPointType.BFLOAT16 - - -class Float16Operations(_BaseTypedFloatOperations): - """ - Operations for Float16 floating point type. - """ - - _fp_type = FloatingPointType.FLOAT16 - - -class Float32Operations(_BaseTypedFloatOperations): - """ - Operations for Float32 floating point type. - """ - - _fp_type = FloatingPointType.FLOAT32 - - -class Float64Operations(_BaseTypedFloatOperations): - """ - Operations for Float64 floating point type. - """ - - _fp_type = FloatingPointType.FLOAT64 diff --git a/tests/rtllib/test_pyrtlfloat.py b/tests/rtllib/test_float.py similarity index 92% rename from tests/rtllib/test_pyrtlfloat.py rename to tests/rtllib/test_float.py index 5a088784..4d6ce8da 100644 --- a/tests/rtllib/test_pyrtlfloat.py +++ b/tests/rtllib/test_float.py @@ -1,7 +1,8 @@ +import doctest import unittest import pyrtl -from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode +import pyrtl.rtllib.float as rtlfloat # IEEE 754 Float16 special values FLOAT16_POS_ZERO = 0x0000 @@ -24,6 +25,20 @@ FLOAT16_DENORMALIZED = 0x0001 # Smallest denormalized number +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_add_sub_doctests(self): + failures, tests = doctest.testmod(m=rtlfloat.add_sub) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + + def test_mult_doctests(self): + failures, tests = doctest.testmod(m=rtlfloat.multiplication) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + + def float16_parts(sign, exp, mant): """Construct Float16 from sign, exponent, and mantissa.""" assert sign in (0, 1), f"sign must be 0 or 1, got {sign}" @@ -63,14 +78,14 @@ class TestAddition(unittest.TestCase): def setUp(self): pyrtl.reset_working_block() - self.a = pyrtl.Input(bitwidth=16, name="a") - self.b = pyrtl.Input(bitwidth=16, name="b") - FloatOperations.default_rounding_mode = RoundingMode.RNE + self.a = rtlfloat.Float16(name="a", concatenated_type=pyrtl.Input) + self.b = rtlfloat.Float16(name="b", concatenated_type=pyrtl.Input) + rtlfloat.set_default_rounding_mode(rtlfloat.RoundingMode.RNE) result_rne = pyrtl.Output(name="result_rne") - result_rne <<= Float16Operations.add(self.a, self.b) - FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rne <<= rtlfloat.add(self.a, self.b) + rtlfloat.set_default_rounding_mode(rtlfloat.RoundingMode.RTZ) result_rtz = pyrtl.Output(name="result_rtz") - result_rtz <<= Float16Operations.add(self.a, self.b) + result_rtz <<= rtlfloat.add(self.a, self.b) self.sim = pyrtl.Simulation() def assertFloat16Equal(self, output_name, expected): @@ -115,6 +130,16 @@ def test_add_opposite_signs_equal_magnitude(self): self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO) self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO) + ############################ + # Error handling. + + def test_mismatched_types(self): + pyrtl.reset_working_block() + a = rtlfloat.Float16(name="a", concatenated_type=pyrtl.Input) + b = rtlfloat.Float32(name="b", concatenated_type=pyrtl.Input) + with self.assertRaises(pyrtl.PyrtlError): + rtlfloat.add(a, b) + ############################ # Rounding tests. @@ -303,14 +328,14 @@ class TestSubtraction(unittest.TestCase): def setUp(self): pyrtl.reset_working_block() - self.a = pyrtl.Input(bitwidth=16, name="a") - self.b = pyrtl.Input(bitwidth=16, name="b") - FloatOperations.default_rounding_mode = RoundingMode.RNE + self.a = rtlfloat.Float16(name="a", concatenated_type=pyrtl.Input) + self.b = rtlfloat.Float16(name="b", concatenated_type=pyrtl.Input) + rtlfloat.set_default_rounding_mode(rtlfloat.RoundingMode.RNE) result_rne = pyrtl.Output(name="result_rne") - result_rne <<= Float16Operations.sub(self.a, self.b) - FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rne <<= rtlfloat.sub(self.a, self.b) + rtlfloat.set_default_rounding_mode(rtlfloat.RoundingMode.RTZ) result_rtz = pyrtl.Output(name="result_rtz") - result_rtz <<= Float16Operations.sub(self.a, self.b) + result_rtz <<= rtlfloat.sub(self.a, self.b) self.sim = pyrtl.Simulation() def assertFloat16Equal(self, output_name, expected): @@ -459,14 +484,14 @@ class TestMultiplication(unittest.TestCase): def setUp(self): pyrtl.reset_working_block() - self.a = pyrtl.Input(bitwidth=16, name="a") - self.b = pyrtl.Input(bitwidth=16, name="b") - FloatOperations.default_rounding_mode = RoundingMode.RNE + self.a = rtlfloat.Float16(name="a", concatenated_type=pyrtl.Input) + self.b = rtlfloat.Float16(name="b", concatenated_type=pyrtl.Input) + rtlfloat.set_default_rounding_mode(rtlfloat.RoundingMode.RNE) result_rne = pyrtl.Output(name="result_rne") - result_rne <<= Float16Operations.mul(self.a, self.b) - FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rne <<= rtlfloat.mult(self.a, self.b) + rtlfloat.set_default_rounding_mode(rtlfloat.RoundingMode.RTZ) result_rtz = pyrtl.Output(name="result_rtz") - result_rtz <<= Float16Operations.mul(self.a, self.b) + result_rtz <<= rtlfloat.mult(self.a, self.b) self.sim = pyrtl.Simulation() def assertFloat16Equal(self, output_name, expected): @@ -506,6 +531,16 @@ def test_mul_one_point_five_times_one_point_five(self): self.assertFloat16Equal("result_rne", expected) self.assertFloat16Equal("result_rtz", expected) + ############################ + # Error handling. + + def test_mismatched_types(self): + pyrtl.reset_working_block() + a = rtlfloat.Float16(name="a", concatenated_type=pyrtl.Input) + b = rtlfloat.Float32(name="b", concatenated_type=pyrtl.Input) + with self.assertRaises(pyrtl.PyrtlError): + rtlfloat.mult(a, b) + ############################ # Rounding tests. diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 3ad5bfa6..6dab90a3 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -1350,6 +1350,8 @@ def test_concatenate(self): """Drive high and low, observe concatenated high+low.""" # Concatenates to 'byte'. byte = Byte(name="byte", high=0xA, low=0xB) + self.assertTrue(isinstance(byte, Byte)) + self.assertEqual(Byte.__name__, "Byte") self.assertEqual(len(byte), 2) self.assertEqual(byte.bitwidth, 8) self.assertEqual(len(pyrtl.as_wires(byte)), 8) @@ -1394,6 +1396,21 @@ def test_slice(self): self.assertTrue(isinstance(byte.low, pyrtl.Const)) self.assertEqual(byte.low.val, 0xD) + def test_underscore_value_slice(self): + """Drive concatenated high+low by setting the special _value name, observe high + and low. + """ + # Slices to 'byte.high' and 'byte.low'. + byte = Byte(name="byte", _value=0xCD) + + # Constants are sliced immediately. + self.assertTrue(isinstance(pyrtl.as_wires(byte), pyrtl.Const)) + self.assertEqual(byte.val, 0xCD) + self.assertTrue(isinstance(byte.high, pyrtl.Const)) + self.assertEqual(byte.high.val, 0xC) + self.assertTrue(isinstance(byte.low, pyrtl.Const)) + self.assertEqual(byte.low.val, 0xD) + def test_input_slice(self): """Given Input concatenated high+low, observe high and low.""" byte = Byte(name="byte", concatenated_type=pyrtl.Input) @@ -1588,7 +1605,7 @@ def test_anonymous_pixel_concatenate(self): # BitPair is an array of two single wires. -BitPair = pyrtl.wire_matrix(component_schema=1, size=2) +BitPair = pyrtl.wire_matrix(component_schema=1, size=2, class_name="BitPair") # Word is an array of two Bytes. This checks that a @wire_struct (Byte) can be a @@ -1619,6 +1636,8 @@ def setUp(self): def test_wire_matrix_slice(self): bitpair = BitPair(name="bitpair", values=[2]) + self.assertTrue(isinstance(bitpair, BitPair)) + self.assertEqual(BitPair.__name__, "BitPair") self.assertEqual(len(bitpair), 2) self.assertEqual(bitpair.bitwidth, 2) self.assertEqual(len(pyrtl.as_wires(bitpair)), 2)