Skip to content

Commit c970cae

Browse files
committed
Add doctest examples for floating point operations.
1 parent 52ff2e1 commit c970cae

4 files changed

Lines changed: 157 additions & 3 deletions

File tree

pyrtl/rtllib/float/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"""
1212

1313
from .add_sub import add, sub
14-
from .mult import mult
14+
from .multiplication import mult
1515
from .types import BFloat16, Float16, Float32, Float64, RoundingMode
1616
from .utils import get_default_rounding_mode, set_default_rounding_mode
1717

pyrtl/rtllib/float/add_sub.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,59 @@ def add(
2424
Denormalized numbers are not supported. Denormalized numbers will be flushed to
2525
zero.
2626
27-
The return value's ``Float``type will match the operand ``Float`` type. For example,
28-
if you ``add`` two :class:`~.Float16`, the result will be a :class:`~.Float16`.
27+
The return value's ``Float`` type will match the operand ``Float`` type. For
28+
example, if you ``add`` two :class:`~.Float16`, the result will be a
29+
:class:`~.Float16`.
30+
31+
.. doctest only::
32+
33+
>>> import pyrtl
34+
>>> pyrtl.reset_working_block()
35+
36+
The following example computes ``1.0 + 2.0``. This is a bare-metal example, directly
37+
manipulating the raw ``sign``, ``exponent``, and ``mantissa`` in IEEE 754 16-bit
38+
floating point representation. See `IEEE 754 Internal Representation
39+
<https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation>`_
40+
for more details::
41+
42+
>>> import pyrtl.rtllib.float as rtlfloat
43+
44+
>>> a = rtlfloat.Float16(name="a", component_type=pyrtl.Input)
45+
>>> b = rtlfloat.Float16(name="b", component_type=pyrtl.Input)
46+
47+
>>> sum = rtlfloat.Float16(name="sum", Float16=None)
48+
>>> sum <<= rtlfloat.add(a, b)
49+
50+
>>> # IEEE 754 numbers are stored as a sign bit, mantissa, and exponent. The
51+
>>> # represented number's absolute value is 1.{mantissa} * 2 ** {exponent}
52+
>>> #
53+
>>> # All mantissas have this implied `1` before the binary point. {mantissa}
54+
>>> # only stores the bits after this implied `1` and binary point.
55+
>>> #
56+
>>> # IEEE 754 exponents are stored with a bias to simplify comparisons. An
57+
>>> # exponent {x} is stored as {x + exponent_bias}.
58+
>>> exponent_bias = 2 ** (sum.exponent.bitwidth - 1) - 1
59+
60+
>>> # Create a=1.0, represented as 1.0 * 2 ** 0.
61+
>>> a_one = {"a.sign": 0, "a.exponent": 0 + exponent_bias, "a.mantissa": 0}
62+
63+
>>> # Create b=2.0, represented as 1.0 * 2 ** 1.
64+
>>> b_two = {"b.sign": 0, "b.exponent": 1 + exponent_bias, "b.mantissa": 0}
65+
66+
>>> sim = pyrtl.Simulation()
67+
>>> sim.step(a_one | b_two)
68+
69+
>>> # The sum should be 3.0, represented as 0b1.1 * 2 ** 1.
70+
>>> # Note that this 0b1.1 is in binary! Multiplying by 2 is equivalent to
71+
>>> # left-shifting by 1, and 0b1.1 << 1 == 0b11, which is 3 in decimal.
72+
>>> sim.inspect("sum.sign")
73+
0
74+
>>> sim.inspect("sum.exponent") - exponent_bias
75+
1
76+
>>> bin(sim.inspect("sum.mantissa"))
77+
'0b1000000000'
78+
>>> bin(1 << (sum.mantissa.bitwidth - 1))
79+
'0b1000000000'
2980
3081
:param operand_a:
3182
:param operand_b:
@@ -100,6 +151,46 @@ def sub(
100151
example, if you ``sub`` two :class:`~.Float16`, the result will be a
101152
:class:`~.Float16`.
102153
154+
.. doctest only::
155+
156+
>>> import pyrtl
157+
>>> pyrtl.reset_working_block()
158+
159+
The following example computes ``1.0 - 2.0``. This is a bare-metal example, directly
160+
manipulating the raw ``sign``, ``exponent``, and ``mantissa`` in IEEE 754 16-bit
161+
floating point representation. See the documentation for :func:`add` and `IEEE 754
162+
Internal Representation
163+
<https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation>`_
164+
for more details::
165+
166+
>>> import pyrtl.rtllib.float as rtlfloat
167+
168+
>>> a = rtlfloat.Float16(name="a", component_type=pyrtl.Input)
169+
>>> b = rtlfloat.Float16(name="b", component_type=pyrtl.Input)
170+
171+
>>> difference = rtlfloat.Float16(name="difference", Float16=None)
172+
>>> difference <<= rtlfloat.sub(a, b)
173+
174+
>>> # See the `add` example for IEEE 754 representation background.
175+
>>> exponent_bias = 2 ** (difference.exponent.bitwidth - 1) - 1
176+
177+
>>> # Create a=1.0, represented as 1.0 * 2 ** 0.
178+
>>> a_one = {"a.sign": 0, "a.exponent": 0 + exponent_bias, "a.mantissa": 0}
179+
180+
>>> # Create b=2.0, represented as 1.0 * 2 ** 1.
181+
>>> b_two = {"b.sign": 0, "b.exponent": 1 + exponent_bias, "b.mantissa": 0}
182+
183+
>>> sim = pyrtl.Simulation()
184+
>>> sim.step(a_one | b_two)
185+
186+
>>> # The difference should be -1.0, represented as -1.0 * 2 ** 0.
187+
>>> sim.inspect("difference.sign")
188+
1
189+
>>> sim.inspect("difference.exponent") - exponent_bias
190+
0
191+
>>> sim.inspect("difference.mantissa")
192+
0
193+
103194
:param operand_a:
104195
:param operand_b:
105196
:param rounding_mode: Rounding mode, defaults to :attr:`~.RoundingMode.RNE`. The
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# This file should be named `mult.py` for symmetry with `add_sub.py`, but is instead
2+
# named `multiplication.py` so the module has a unique name.
3+
4+
# `__init__.py` creates the name `pyrtl.rtllib.float.mult`, which would collide with
5+
# this module's name if this file were named `mult.py`.
6+
#
7+
# This module currently needs a unique importable name so that its `doctest`s can run.
8+
19
import pyrtl
210
from pyrtl.rtllib.float.types import FloatType, RoundingMode
311
from pyrtl.rtllib.float.utils import (
@@ -27,6 +35,46 @@ def mult(
2735
example, if you ``mult`` two :class:`~.Float16`, the result will be a
2836
:class:`~.Float16`.
2937
38+
.. doctest only::
39+
40+
>>> import pyrtl
41+
>>> pyrtl.reset_working_block()
42+
43+
The following example computes ``2.0 * 4.0``. This is a bare-metal example, directly
44+
manipulating the raw ``sign``, ``exponent``, and ``mantissa`` in IEEE 754 16-bit
45+
floating point representation. See the documentation for :func:`add` and `IEEE 754
46+
Internal Representation
47+
<https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation>`_
48+
for more details::
49+
50+
>>> import pyrtl.rtllib.float as rtlfloat
51+
52+
>>> a = rtlfloat.Float16(name="a", component_type=pyrtl.Input)
53+
>>> b = rtlfloat.Float16(name="b", component_type=pyrtl.Input)
54+
55+
>>> product = rtlfloat.Float16(name="product", Float16=None)
56+
>>> product <<= rtlfloat.mult(a, b)
57+
58+
>>> # See the `add` example for IEEE 754 representation background.
59+
>>> exponent_bias = 2 ** (product.exponent.bitwidth - 1) - 1
60+
61+
>>> # Create a=2.0, represented as 1.0 * 2 ** 1.
62+
>>> a_two = {"a.sign": 0, "a.exponent": 1 + exponent_bias, "a.mantissa": 0}
63+
64+
>>> # Create b=4.0, represented as 1.0 * 2 ** 2.
65+
>>> b_four = {"b.sign": 0, "b.exponent": 2 + exponent_bias, "b.mantissa": 0}
66+
67+
>>> sim = pyrtl.Simulation()
68+
>>> sim.step(a_two | b_four)
69+
70+
>>> # The product should be 8.0, represented as 1.0 * 2 ** 3.
71+
>>> sim.inspect("product.sign")
72+
0
73+
>>> sim.inspect("product.exponent") - exponent_bias
74+
3
75+
>>> sim.inspect("product.mantissa")
76+
0
77+
3078
:param operand_a:
3179
:param operand_b:
3280
:param rounding_mode: Rounding mode, defaults to :attr:`~.RoundingMode.RNE`. The

tests/rtllib/test_float.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import doctest
12
import unittest
23

34
import pyrtl
@@ -24,6 +25,20 @@
2425
FLOAT16_DENORMALIZED = 0x0001 # Smallest denormalized number
2526

2627

28+
class TestDocTests(unittest.TestCase):
29+
"""Test documentation examples."""
30+
31+
def test_add_sub_doctests(self):
32+
failures, tests = doctest.testmod(m=rtlfloat.add_sub)
33+
self.assertGreater(tests, 0)
34+
self.assertEqual(failures, 0)
35+
36+
def test_mult_doctests(self):
37+
failures, tests = doctest.testmod(m=rtlfloat.multiplication)
38+
self.assertGreater(tests, 0)
39+
self.assertEqual(failures, 0)
40+
41+
2742
def float16_parts(sign, exp, mant):
2843
"""Construct Float16 from sign, exponent, and mantissa."""
2944
assert sign in (0, 1), f"sign must be 0 or 1, got {sign}"

0 commit comments

Comments
 (0)