Skip to content

Commit 94c4c68

Browse files
committed
Pyrtl floating point library
1 parent 8c706f6 commit 94c4c68

9 files changed

Lines changed: 789 additions & 0 deletions

File tree

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode
2+
from .floatoperations import FloatOperations
3+
from .floatwirevector import Float16WireVector
4+
5+
__all__ = [
6+
"FloatingPointType",
7+
"FPTypeProperties",
8+
"PyrtlFloatConfig",
9+
"RoundingMode",
10+
"FloatOperations",
11+
"Float16WireVector",
12+
]
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
import pyrtl
2+
3+
from ._float_utills import FloatUtils
4+
from ._types import PyrtlFloatConfig, RoundingMode
5+
6+
7+
class AddSubHelper:
8+
@staticmethod
9+
def add(
10+
config: PyrtlFloatConfig,
11+
operand_a: pyrtl.WireVector,
12+
operand_b: pyrtl.WireVector,
13+
) -> pyrtl.WireVector:
14+
fp_type_props = config.fp_type_properties
15+
rounding_mode = config.rounding_mode
16+
num_exp_bits = fp_type_props.num_exponent_bits
17+
num_mant_bits = fp_type_props.num_mantissa_bits
18+
total_bits = num_exp_bits + num_mant_bits + 1
19+
20+
operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a)
21+
operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b)
22+
23+
# operand_smaller is the operand with the smaller absolute value and
24+
# operand_larger is the operand with the larger absolute value
25+
operand_smaller = pyrtl.WireVector(bitwidth=total_bits)
26+
operand_larger = pyrtl.WireVector(bitwidth=total_bits)
27+
28+
with pyrtl.conditional_assignment:
29+
exponent_and_mantissa_len = num_mant_bits + num_exp_bits
30+
with (
31+
operand_a_daz[:exponent_and_mantissa_len]
32+
< operand_b_daz[:exponent_and_mantissa_len]
33+
):
34+
operand_smaller |= operand_a_daz
35+
operand_larger |= operand_b_daz
36+
with pyrtl.otherwise:
37+
operand_smaller |= operand_b_daz
38+
operand_larger |= operand_a_daz
39+
40+
smaller_operand_sign = FloatUtils.get_sign(fp_type_props, operand_smaller)
41+
larger_operand_sign = FloatUtils.get_sign(fp_type_props, operand_larger)
42+
smaller_operand_exponent = FloatUtils.get_exponent(
43+
fp_type_props, operand_smaller
44+
)
45+
larger_operand_exponent = FloatUtils.get_exponent(fp_type_props, operand_larger)
46+
smaller_operand_mantissa = pyrtl.concat(
47+
pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_smaller)
48+
)
49+
larger_operand_mantissa = pyrtl.concat(
50+
pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_larger)
51+
)
52+
53+
exponent_diff = larger_operand_exponent - smaller_operand_exponent
54+
smaller_mantissa_shifted = pyrtl.shift_right_logical(
55+
smaller_operand_mantissa, exponent_diff
56+
)
57+
grs = pyrtl.WireVector(bitwidth=3) # guard, round, sticky bits for rounding
58+
with pyrtl.conditional_assignment:
59+
with exponent_diff >= 2:
60+
guard_and_round = pyrtl.shift_right_logical(
61+
smaller_operand_mantissa, exponent_diff - 2
62+
)[:2]
63+
mask = (
64+
pyrtl.shift_left_logical(
65+
pyrtl.Const(1, bitwidth=num_mant_bits), exponent_diff - 2
66+
)
67+
- 1
68+
)
69+
sticky = (smaller_operand_mantissa & mask) != 0
70+
grs |= pyrtl.concat(guard_and_round, sticky)
71+
with exponent_diff == 1:
72+
grs |= pyrtl.concat(
73+
smaller_operand_mantissa[0], pyrtl.Const(0, bitwidth=2)
74+
)
75+
with pyrtl.otherwise:
76+
grs |= 0
77+
smaller_mantissa_shifted_grs = pyrtl.concat(smaller_mantissa_shifted, grs)
78+
larger_mantissa_extended = pyrtl.concat(
79+
larger_operand_mantissa, pyrtl.Const(0, bitwidth=3)
80+
)
81+
82+
sum_exponent, sum_mantissa, sum_grs, sum_carry = AddSubHelper._add_operands(
83+
larger_operand_exponent,
84+
smaller_mantissa_shifted_grs,
85+
larger_mantissa_extended,
86+
)
87+
88+
sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = (
89+
AddSubHelper._sub_operands(
90+
num_mant_bits,
91+
larger_operand_exponent,
92+
smaller_mantissa_shifted_grs,
93+
larger_mantissa_extended,
94+
)
95+
)
96+
97+
# WireVectors for the raw addition or subtraction result, before handling
98+
# special cases
99+
raw_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits)
100+
raw_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits)
101+
if rounding_mode == RoundingMode.RNE:
102+
raw_result_grs = pyrtl.WireVector(bitwidth=3)
103+
104+
with pyrtl.conditional_assignment:
105+
with smaller_operand_sign == larger_operand_sign: # add
106+
raw_result_exponent |= sum_exponent
107+
raw_result_mantissa |= sum_mantissa
108+
if rounding_mode == RoundingMode.RNE:
109+
raw_result_grs |= sum_grs
110+
with pyrtl.otherwise: # sub
111+
raw_result_exponent |= sub_exponent
112+
raw_result_mantissa |= sub_mantissa
113+
if rounding_mode == RoundingMode.RNE:
114+
raw_result_grs |= sub_grs
115+
116+
if rounding_mode == RoundingMode.RNE:
117+
(
118+
raw_result_rounded_exponent,
119+
raw_result_rounded_mantissa,
120+
rounding_exponent_incremented,
121+
) = AddSubHelper._round(
122+
num_mant_bits,
123+
num_exp_bits,
124+
raw_result_exponent,
125+
raw_result_mantissa,
126+
raw_result_grs,
127+
)
128+
129+
smaller_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_smaller)
130+
larger_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_larger)
131+
smaller_operand_inf = FloatUtils.is_inf(fp_type_props, operand_smaller)
132+
larger_operand_inf = FloatUtils.is_inf(fp_type_props, operand_larger)
133+
smaller_operand_zero = FloatUtils.is_zero(fp_type_props, operand_smaller)
134+
larger_operand_zero = FloatUtils.is_zero(fp_type_props, operand_larger)
135+
136+
# WireVectors for the final result after handling special cases
137+
final_result_sign = pyrtl.WireVector(bitwidth=1)
138+
final_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits)
139+
final_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits)
140+
141+
# handle special cases
142+
with pyrtl.conditional_assignment:
143+
# if either operand is NaN or both operands are infinity of opposite signs,
144+
# the result is NaN
145+
with (
146+
smaller_operand_nan
147+
| larger_operand_nan
148+
| (
149+
smaller_operand_inf
150+
& larger_operand_inf
151+
& (larger_operand_sign != smaller_operand_sign)
152+
)
153+
):
154+
final_result_sign |= larger_operand_sign
155+
FloatUtils.make_output_NaN(
156+
fp_type_props, final_result_exponent, final_result_mantissa
157+
)
158+
# infinities
159+
with smaller_operand_inf:
160+
final_result_sign |= larger_operand_sign
161+
FloatUtils.make_output_inf(
162+
fp_type_props, final_result_exponent, final_result_mantissa
163+
)
164+
with larger_operand_inf:
165+
final_result_sign |= larger_operand_sign
166+
FloatUtils.make_output_inf(
167+
fp_type_props, final_result_exponent, final_result_mantissa
168+
)
169+
# +num + -num = +0
170+
with (
171+
(smaller_operand_mantissa == larger_operand_mantissa)
172+
& (smaller_operand_exponent == larger_operand_exponent)
173+
& (larger_operand_sign != smaller_operand_sign)
174+
):
175+
final_result_sign |= 0
176+
FloatUtils.make_output_zero(
177+
final_result_exponent, final_result_mantissa
178+
)
179+
with smaller_operand_zero:
180+
final_result_sign |= larger_operand_sign
181+
final_result_mantissa |= larger_operand_mantissa
182+
final_result_exponent |= larger_operand_exponent
183+
with larger_operand_zero:
184+
final_result_sign |= smaller_operand_sign
185+
final_result_mantissa |= smaller_operand_mantissa
186+
final_result_exponent |= smaller_operand_exponent
187+
# overflow and underflow
188+
initial_larger_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2)
189+
if rounding_mode == RoundingMode.RNE:
190+
larger_exponent_max_value = (
191+
initial_larger_exponent_max_value
192+
- sum_carry
193+
- rounding_exponent_incremented
194+
)
195+
else:
196+
larger_exponent_max_value = (
197+
initial_larger_exponent_max_value - sum_carry
198+
)
199+
initial_larger_exponent_min_value = pyrtl.Const(1)
200+
if rounding_mode == RoundingMode.RNE:
201+
larger_exponent_min_value = (
202+
initial_larger_exponent_min_value
203+
+ num_leading_zeros
204+
- rounding_exponent_incremented
205+
)
206+
else:
207+
larger_exponent_min_value = (
208+
initial_larger_exponent_min_value + num_leading_zeros
209+
)
210+
with (smaller_operand_sign == larger_operand_sign) & (
211+
larger_operand_exponent > larger_exponent_max_value
212+
): # detect overflow on addition
213+
final_result_sign |= larger_operand_sign
214+
if rounding_mode == RoundingMode.RNE:
215+
FloatUtils.make_output_inf(
216+
fp_type_props, final_result_exponent, final_result_mantissa
217+
)
218+
else:
219+
FloatUtils.make_output_largest_finite_number(
220+
fp_type_props, final_result_exponent, final_result_mantissa
221+
)
222+
with (smaller_operand_sign != larger_operand_sign) & (
223+
larger_operand_exponent < larger_exponent_min_value
224+
): # detect underflow on subtraction
225+
final_result_sign |= larger_operand_sign
226+
FloatUtils.make_output_zero(
227+
final_result_exponent, final_result_mantissa
228+
)
229+
with pyrtl.otherwise:
230+
final_result_sign |= larger_operand_sign
231+
if rounding_mode == RoundingMode.RNE:
232+
final_result_exponent |= raw_result_rounded_exponent
233+
final_result_mantissa |= raw_result_rounded_mantissa
234+
else:
235+
final_result_exponent |= raw_result_exponent
236+
final_result_mantissa |= raw_result_mantissa
237+
238+
return pyrtl.concat(
239+
final_result_sign, final_result_exponent, final_result_mantissa
240+
)
241+
242+
@staticmethod
243+
def sub(
244+
config: PyrtlFloatConfig,
245+
operand_a: pyrtl.WireVector,
246+
operand_b: pyrtl.WireVector,
247+
) -> pyrtl.WireVector:
248+
num_exp_bits = config.fp_type_properties.num_exponent_bits
249+
num_mant_bits = config.fp_type_properties.num_mantissa_bits
250+
operand_b_negated = operand_b ^ pyrtl.concat(
251+
pyrtl.Const(1, bitwidth=1),
252+
pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits),
253+
)
254+
return AddSubHelper.add(config, operand_a, operand_b_negated)
255+
256+
@staticmethod
257+
def _add_operands(
258+
larger_operand_exponent: pyrtl.WireVector,
259+
smaller_mantissa_shifted_grs: pyrtl.WireVector,
260+
larger_mantissa_extended: pyrtl.WireVector,
261+
) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]:
262+
sum_mantissa_grs = pyrtl.WireVector()
263+
sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs
264+
sum_carry = sum_mantissa_grs[-1]
265+
sum_mantissa = pyrtl.select(
266+
sum_carry, sum_mantissa_grs[4:], sum_mantissa_grs[3:-1]
267+
)
268+
sum_grs = pyrtl.select(
269+
sum_carry,
270+
pyrtl.concat(sum_mantissa_grs[2:4], sum_mantissa_grs[:2] != 0),
271+
sum_mantissa_grs[:3],
272+
)
273+
sum_exponent = pyrtl.select(
274+
sum_carry, larger_operand_exponent + 1, larger_operand_exponent
275+
)
276+
return sum_exponent, sum_mantissa, sum_grs, sum_carry
277+
278+
@staticmethod
279+
def _sub_operands(
280+
num_mant_bits: int,
281+
larger_operand_exponent: pyrtl.WireVector,
282+
smaller_mantissa_shifted_grs: pyrtl.WireVector,
283+
larger_mantissa_extended: pyrtl.WireVector,
284+
) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]:
285+
def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int):
286+
out = pyrtl.WireVector(
287+
bitwidth=pyrtl.infer_val_and_bitwidth(length - 1).bitwidth
288+
)
289+
with pyrtl.conditional_assignment:
290+
for i in range(wire.bitwidth - 1, wire.bitwidth - length - 1, -1):
291+
with wire[i]:
292+
out |= wire.bitwidth - i - 1
293+
return out
294+
295+
sub_mantissa_grs = pyrtl.WireVector(bitwidth=num_mant_bits + 4)
296+
sub_mantissa_grs <<= larger_mantissa_extended - smaller_mantissa_shifted_grs
297+
num_leading_zeros = leading_zero_priority_encoder(
298+
sub_mantissa_grs, num_mant_bits + 1
299+
)
300+
sub_mantissa_grs_shifted = pyrtl.shift_left_logical(
301+
sub_mantissa_grs, num_leading_zeros
302+
)
303+
sub_mantissa = sub_mantissa_grs_shifted[3:]
304+
sub_grs = sub_mantissa_grs_shifted[:3]
305+
sub_exponent = larger_operand_exponent - num_leading_zeros
306+
return sub_exponent, sub_mantissa, sub_grs, num_leading_zeros
307+
308+
@staticmethod
309+
def _round(
310+
num_mant_bits: int,
311+
num_exp_bits: int,
312+
raw_result_exponent: pyrtl.WireVector,
313+
raw_result_mantissa: pyrtl.WireVector,
314+
raw_result_grs: pyrtl.WireVector,
315+
) -> tuple[pyrtl.WireVector, pyrtl.WireVector]:
316+
last = raw_result_mantissa[0]
317+
guard = raw_result_grs[2]
318+
round = raw_result_grs[1]
319+
sticky = raw_result_grs[0]
320+
round_up = guard & (last | round | sticky)
321+
raw_result_rounded_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits)
322+
raw_result_rounded_exponent = pyrtl.WireVector(bitwidth=num_exp_bits)
323+
rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1)
324+
with pyrtl.conditional_assignment:
325+
with round_up:
326+
with raw_result_mantissa == (1 << num_mant_bits) - 1:
327+
raw_result_rounded_mantissa |= 0
328+
raw_result_rounded_exponent |= raw_result_exponent + 1
329+
rounding_exponent_incremented |= 1
330+
with pyrtl.otherwise:
331+
raw_result_rounded_mantissa |= raw_result_mantissa + 1
332+
raw_result_rounded_exponent |= raw_result_exponent
333+
rounding_exponent_incremented |= 0
334+
with pyrtl.otherwise:
335+
raw_result_rounded_mantissa |= raw_result_mantissa
336+
raw_result_rounded_exponent |= raw_result_exponent
337+
rounding_exponent_incremented |= 0
338+
return (
339+
raw_result_rounded_exponent,
340+
raw_result_rounded_mantissa,
341+
rounding_exponent_incremented,
342+
)

0 commit comments

Comments
 (0)