Skip to content

Commit 15ad996

Browse files
committed
Remove FloatWireVector
1 parent 94c4c68 commit 15ad996

5 files changed

Lines changed: 63 additions & 65 deletions

File tree

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode
2-
from .floatoperations import FloatOperations
3-
from .floatwirevector import Float16WireVector
2+
from .floatoperations import (
3+
BFloat16Operations,
4+
Float16Operations,
5+
Float32Operations,
6+
Float64Operations,
7+
FloatOperations,
8+
)
49

510
__all__ = [
611
"FloatingPointType",
712
"FPTypeProperties",
813
"PyrtlFloatConfig",
914
"RoundingMode",
1015
"FloatOperations",
11-
"Float16WireVector",
16+
"BFloat16Operations",
17+
"Float16Operations",
18+
"Float32Operations",
19+
"Float64Operations",
1220
]

pyrtl/rtllib/pyrtlfloat/floatoperations.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
from ._add_sub import AddSubHelper
44
from ._multiplication import MultiplicationHelper
5-
from ._types import PyrtlFloatConfig, RoundingMode
5+
from ._types import FloatingPointType, PyrtlFloatConfig, RoundingMode
66

77

88
class FloatOperations:
99
default_rounding_mode = RoundingMode.RNE
1010

1111
@staticmethod
12-
def multiply(
12+
def mul(
1313
config: PyrtlFloatConfig,
1414
operand_a: pyrtl.WireVector,
1515
operand_b: pyrtl.WireVector,
@@ -31,3 +31,47 @@ def sub(
3131
operand_b: pyrtl.WireVector,
3232
) -> pyrtl.WireVector:
3333
return AddSubHelper.sub(config, operand_a, operand_b)
34+
35+
36+
class _BaseTypedFloatOperations:
37+
_fp_type: FloatingPointType = None
38+
39+
@classmethod
40+
def mul(
41+
cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector
42+
) -> pyrtl.WireVector:
43+
return FloatOperations.mul(cls._get_config(), operand_a, operand_b)
44+
45+
@classmethod
46+
def add(
47+
cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector
48+
) -> pyrtl.WireVector:
49+
return FloatOperations.add(cls._get_config(), operand_a, operand_b)
50+
51+
@classmethod
52+
def sub(
53+
cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector
54+
) -> pyrtl.WireVector:
55+
return FloatOperations.sub(cls._get_config(), operand_a, operand_b)
56+
57+
@classmethod
58+
def _get_config(cls) -> PyrtlFloatConfig:
59+
return PyrtlFloatConfig(
60+
cls._fp_type.value, FloatOperations.default_rounding_mode
61+
)
62+
63+
64+
class BFloat16Operations(_BaseTypedFloatOperations):
65+
_fp_type = FloatingPointType.BFLOAT16
66+
67+
68+
class Float16Operations(_BaseTypedFloatOperations):
69+
_fp_type = FloatingPointType.FLOAT16
70+
71+
72+
class Float32Operations(_BaseTypedFloatOperations):
73+
_fp_type = FloatingPointType.FLOAT32
74+
75+
76+
class Float64Operations(_BaseTypedFloatOperations):
77+
_fp_type = FloatingPointType.FLOAT64

pyrtl/rtllib/pyrtlfloat/floatwirevector.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

tests/rtllib/pyrtlfloat/test_add_sub.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
import unittest
22

33
import pyrtl
4-
from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode
4+
from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode
55

66

77
class TestMultiplication(unittest.TestCase):
88
def setUp(self):
99
pyrtl.reset_working_block()
1010
a = pyrtl.Input(bitwidth=16, name="a")
1111
b = pyrtl.Input(bitwidth=16, name="b")
12-
a_floatwv = Float16WireVector()
13-
a_floatwv <<= a
14-
b_floatwv = Float16WireVector()
15-
b_floatwv <<= b
1612
FloatOperations.default_rounding_mode = RoundingMode.RNE
1713
result_add = pyrtl.Output(name="result_add")
18-
result_add <<= a_floatwv + b_floatwv
14+
result_add <<= Float16Operations.add(a, b)
1915
result_sub = pyrtl.Output(name="result_sub")
20-
result_sub <<= a_floatwv - b_floatwv
16+
result_sub <<= Float16Operations.sub(a, b)
2117
self.sim = pyrtl.Simulation()
2218

2319
def test_multiplication_simple(self):

tests/rtllib/pyrtlfloat/test_multiplication.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,20 @@
11
import unittest
22

33
import pyrtl
4-
from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode
4+
from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode
55

66

77
class TestMultiplication(unittest.TestCase):
88
def setUp(self):
99
pyrtl.reset_working_block()
1010
a = pyrtl.Input(bitwidth=16, name="a")
1111
b = pyrtl.Input(bitwidth=16, name="b")
12-
a_floatwv = Float16WireVector()
13-
a_floatwv <<= a
14-
b_floatwv = Float16WireVector()
15-
b_floatwv <<= b
1612
FloatOperations.default_rounding_mode = RoundingMode.RNE
1713
result_rne = pyrtl.Output(name="result_rne")
18-
result_rne <<= a_floatwv * b_floatwv
14+
result_rne <<= Float16Operations.mul(a, b)
1915
FloatOperations.default_rounding_mode = RoundingMode.RTZ
2016
result_rtz = pyrtl.Output(name="result_rtz")
21-
result_rtz <<= a_floatwv * b_floatwv
17+
result_rtz <<= Float16Operations.mul(a, b)
2218
self.sim = pyrtl.Simulation()
2319

2420
def test_multiplication_simple(self):

0 commit comments

Comments
 (0)