|
| 1 | +import pyrtl |
| 2 | + |
| 3 | +from ._float_utills import FloatUtils |
| 4 | +from ._types import PyrtlFloatConfig, RoundingMode |
| 5 | + |
| 6 | + |
| 7 | +class AddSubHelper: |
| 8 | + @staticmethod |
| 9 | + def add( |
| 10 | + config: PyrtlFloatConfig, |
| 11 | + operand_a: pyrtl.WireVector, |
| 12 | + operand_b: pyrtl.WireVector, |
| 13 | + ) -> pyrtl.WireVector: |
| 14 | + fp_type_props = config.fp_type_properties |
| 15 | + rounding_mode = config.rounding_mode |
| 16 | + num_exp_bits = fp_type_props.num_exponent_bits |
| 17 | + num_mant_bits = fp_type_props.num_mantissa_bits |
| 18 | + total_bits = num_exp_bits + num_mant_bits + 1 |
| 19 | + |
| 20 | + operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a) |
| 21 | + operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) |
| 22 | + |
| 23 | + # operand_smaller is the operand with the smaller absolute value and |
| 24 | + # operand_larger is the operand with the larger absolute value |
| 25 | + operand_smaller = pyrtl.WireVector(bitwidth=total_bits) |
| 26 | + operand_larger = pyrtl.WireVector(bitwidth=total_bits) |
| 27 | + |
| 28 | + with pyrtl.conditional_assignment: |
| 29 | + exponent_and_mantissa_len = num_mant_bits + num_exp_bits |
| 30 | + with ( |
| 31 | + operand_a_daz[:exponent_and_mantissa_len] |
| 32 | + < operand_b_daz[:exponent_and_mantissa_len] |
| 33 | + ): |
| 34 | + operand_smaller |= operand_a_daz |
| 35 | + operand_larger |= operand_b_daz |
| 36 | + with pyrtl.otherwise: |
| 37 | + operand_smaller |= operand_b_daz |
| 38 | + operand_larger |= operand_a_daz |
| 39 | + |
| 40 | + smaller_operand_sign = FloatUtils.get_sign(fp_type_props, operand_smaller) |
| 41 | + larger_operand_sign = FloatUtils.get_sign(fp_type_props, operand_larger) |
| 42 | + smaller_operand_exponent = FloatUtils.get_exponent( |
| 43 | + fp_type_props, operand_smaller |
| 44 | + ) |
| 45 | + larger_operand_exponent = FloatUtils.get_exponent(fp_type_props, operand_larger) |
| 46 | + smaller_operand_mantissa = pyrtl.concat( |
| 47 | + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_smaller) |
| 48 | + ) |
| 49 | + larger_operand_mantissa = pyrtl.concat( |
| 50 | + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_larger) |
| 51 | + ) |
| 52 | + |
| 53 | + exponent_diff = larger_operand_exponent - smaller_operand_exponent |
| 54 | + smaller_mantissa_shifted = pyrtl.shift_right_logical( |
| 55 | + smaller_operand_mantissa, exponent_diff |
| 56 | + ) |
| 57 | + grs = pyrtl.WireVector(bitwidth=3) # guard, round, sticky bits for rounding |
| 58 | + with pyrtl.conditional_assignment: |
| 59 | + with exponent_diff >= 2: |
| 60 | + guard_and_round = pyrtl.shift_right_logical( |
| 61 | + smaller_operand_mantissa, exponent_diff - 2 |
| 62 | + )[:2] |
| 63 | + mask = ( |
| 64 | + pyrtl.shift_left_logical( |
| 65 | + pyrtl.Const(1, bitwidth=num_mant_bits), exponent_diff - 2 |
| 66 | + ) |
| 67 | + - 1 |
| 68 | + ) |
| 69 | + sticky = (smaller_operand_mantissa & mask) != 0 |
| 70 | + grs |= pyrtl.concat(guard_and_round, sticky) |
| 71 | + with exponent_diff == 1: |
| 72 | + grs |= pyrtl.concat( |
| 73 | + smaller_operand_mantissa[0], pyrtl.Const(0, bitwidth=2) |
| 74 | + ) |
| 75 | + with pyrtl.otherwise: |
| 76 | + grs |= 0 |
| 77 | + smaller_mantissa_shifted_grs = pyrtl.concat(smaller_mantissa_shifted, grs) |
| 78 | + larger_mantissa_extended = pyrtl.concat( |
| 79 | + larger_operand_mantissa, pyrtl.Const(0, bitwidth=3) |
| 80 | + ) |
| 81 | + |
| 82 | + sum_exponent, sum_mantissa, sum_grs, sum_carry = AddSubHelper._add_operands( |
| 83 | + larger_operand_exponent, |
| 84 | + smaller_mantissa_shifted_grs, |
| 85 | + larger_mantissa_extended, |
| 86 | + ) |
| 87 | + |
| 88 | + sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = ( |
| 89 | + AddSubHelper._sub_operands( |
| 90 | + num_mant_bits, |
| 91 | + larger_operand_exponent, |
| 92 | + smaller_mantissa_shifted_grs, |
| 93 | + larger_mantissa_extended, |
| 94 | + ) |
| 95 | + ) |
| 96 | + |
| 97 | + # WireVectors for the raw addition or subtraction result, before handling |
| 98 | + # special cases |
| 99 | + raw_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) |
| 100 | + raw_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) |
| 101 | + if rounding_mode == RoundingMode.RNE: |
| 102 | + raw_result_grs = pyrtl.WireVector(bitwidth=3) |
| 103 | + |
| 104 | + with pyrtl.conditional_assignment: |
| 105 | + with smaller_operand_sign == larger_operand_sign: # add |
| 106 | + raw_result_exponent |= sum_exponent |
| 107 | + raw_result_mantissa |= sum_mantissa |
| 108 | + if rounding_mode == RoundingMode.RNE: |
| 109 | + raw_result_grs |= sum_grs |
| 110 | + with pyrtl.otherwise: # sub |
| 111 | + raw_result_exponent |= sub_exponent |
| 112 | + raw_result_mantissa |= sub_mantissa |
| 113 | + if rounding_mode == RoundingMode.RNE: |
| 114 | + raw_result_grs |= sub_grs |
| 115 | + |
| 116 | + if rounding_mode == RoundingMode.RNE: |
| 117 | + ( |
| 118 | + raw_result_rounded_exponent, |
| 119 | + raw_result_rounded_mantissa, |
| 120 | + rounding_exponent_incremented, |
| 121 | + ) = AddSubHelper._round( |
| 122 | + num_mant_bits, |
| 123 | + num_exp_bits, |
| 124 | + raw_result_exponent, |
| 125 | + raw_result_mantissa, |
| 126 | + raw_result_grs, |
| 127 | + ) |
| 128 | + |
| 129 | + smaller_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_smaller) |
| 130 | + larger_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_larger) |
| 131 | + smaller_operand_inf = FloatUtils.is_inf(fp_type_props, operand_smaller) |
| 132 | + larger_operand_inf = FloatUtils.is_inf(fp_type_props, operand_larger) |
| 133 | + smaller_operand_zero = FloatUtils.is_zero(fp_type_props, operand_smaller) |
| 134 | + larger_operand_zero = FloatUtils.is_zero(fp_type_props, operand_larger) |
| 135 | + |
| 136 | + # WireVectors for the final result after handling special cases |
| 137 | + final_result_sign = pyrtl.WireVector(bitwidth=1) |
| 138 | + final_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) |
| 139 | + final_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) |
| 140 | + |
| 141 | + # handle special cases |
| 142 | + with pyrtl.conditional_assignment: |
| 143 | + # if either operand is NaN or both operands are infinity of opposite signs, |
| 144 | + # the result is NaN |
| 145 | + with ( |
| 146 | + smaller_operand_nan |
| 147 | + | larger_operand_nan |
| 148 | + | ( |
| 149 | + smaller_operand_inf |
| 150 | + & larger_operand_inf |
| 151 | + & (larger_operand_sign != smaller_operand_sign) |
| 152 | + ) |
| 153 | + ): |
| 154 | + final_result_sign |= larger_operand_sign |
| 155 | + FloatUtils.make_output_NaN( |
| 156 | + fp_type_props, final_result_exponent, final_result_mantissa |
| 157 | + ) |
| 158 | + # infinities |
| 159 | + with smaller_operand_inf: |
| 160 | + final_result_sign |= larger_operand_sign |
| 161 | + FloatUtils.make_output_inf( |
| 162 | + fp_type_props, final_result_exponent, final_result_mantissa |
| 163 | + ) |
| 164 | + with larger_operand_inf: |
| 165 | + final_result_sign |= larger_operand_sign |
| 166 | + FloatUtils.make_output_inf( |
| 167 | + fp_type_props, final_result_exponent, final_result_mantissa |
| 168 | + ) |
| 169 | + # +num + -num = +0 |
| 170 | + with ( |
| 171 | + (smaller_operand_mantissa == larger_operand_mantissa) |
| 172 | + & (smaller_operand_exponent == larger_operand_exponent) |
| 173 | + & (larger_operand_sign != smaller_operand_sign) |
| 174 | + ): |
| 175 | + final_result_sign |= 0 |
| 176 | + FloatUtils.make_output_zero( |
| 177 | + final_result_exponent, final_result_mantissa |
| 178 | + ) |
| 179 | + with smaller_operand_zero: |
| 180 | + final_result_sign |= larger_operand_sign |
| 181 | + final_result_mantissa |= larger_operand_mantissa |
| 182 | + final_result_exponent |= larger_operand_exponent |
| 183 | + with larger_operand_zero: |
| 184 | + final_result_sign |= smaller_operand_sign |
| 185 | + final_result_mantissa |= smaller_operand_mantissa |
| 186 | + final_result_exponent |= smaller_operand_exponent |
| 187 | + # overflow and underflow |
| 188 | + initial_larger_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2) |
| 189 | + if rounding_mode == RoundingMode.RNE: |
| 190 | + larger_exponent_max_value = ( |
| 191 | + initial_larger_exponent_max_value |
| 192 | + - sum_carry |
| 193 | + - rounding_exponent_incremented |
| 194 | + ) |
| 195 | + else: |
| 196 | + larger_exponent_max_value = ( |
| 197 | + initial_larger_exponent_max_value - sum_carry |
| 198 | + ) |
| 199 | + initial_larger_exponent_min_value = pyrtl.Const(1) |
| 200 | + if rounding_mode == RoundingMode.RNE: |
| 201 | + larger_exponent_min_value = ( |
| 202 | + initial_larger_exponent_min_value |
| 203 | + + num_leading_zeros |
| 204 | + - rounding_exponent_incremented |
| 205 | + ) |
| 206 | + else: |
| 207 | + larger_exponent_min_value = ( |
| 208 | + initial_larger_exponent_min_value + num_leading_zeros |
| 209 | + ) |
| 210 | + with (smaller_operand_sign == larger_operand_sign) & ( |
| 211 | + larger_operand_exponent > larger_exponent_max_value |
| 212 | + ): # detect overflow on addition |
| 213 | + final_result_sign |= larger_operand_sign |
| 214 | + if rounding_mode == RoundingMode.RNE: |
| 215 | + FloatUtils.make_output_inf( |
| 216 | + fp_type_props, final_result_exponent, final_result_mantissa |
| 217 | + ) |
| 218 | + else: |
| 219 | + FloatUtils.make_output_largest_finite_number( |
| 220 | + fp_type_props, final_result_exponent, final_result_mantissa |
| 221 | + ) |
| 222 | + with (smaller_operand_sign != larger_operand_sign) & ( |
| 223 | + larger_operand_exponent < larger_exponent_min_value |
| 224 | + ): # detect underflow on subtraction |
| 225 | + final_result_sign |= larger_operand_sign |
| 226 | + FloatUtils.make_output_zero( |
| 227 | + final_result_exponent, final_result_mantissa |
| 228 | + ) |
| 229 | + with pyrtl.otherwise: |
| 230 | + final_result_sign |= larger_operand_sign |
| 231 | + if rounding_mode == RoundingMode.RNE: |
| 232 | + final_result_exponent |= raw_result_rounded_exponent |
| 233 | + final_result_mantissa |= raw_result_rounded_mantissa |
| 234 | + else: |
| 235 | + final_result_exponent |= raw_result_exponent |
| 236 | + final_result_mantissa |= raw_result_mantissa |
| 237 | + |
| 238 | + return pyrtl.concat( |
| 239 | + final_result_sign, final_result_exponent, final_result_mantissa |
| 240 | + ) |
| 241 | + |
| 242 | + @staticmethod |
| 243 | + def sub( |
| 244 | + config: PyrtlFloatConfig, |
| 245 | + operand_a: pyrtl.WireVector, |
| 246 | + operand_b: pyrtl.WireVector, |
| 247 | + ) -> pyrtl.WireVector: |
| 248 | + num_exp_bits = config.fp_type_properties.num_exponent_bits |
| 249 | + num_mant_bits = config.fp_type_properties.num_mantissa_bits |
| 250 | + operand_b_negated = operand_b ^ pyrtl.concat( |
| 251 | + pyrtl.Const(1, bitwidth=1), |
| 252 | + pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits), |
| 253 | + ) |
| 254 | + return AddSubHelper.add(config, operand_a, operand_b_negated) |
| 255 | + |
| 256 | + @staticmethod |
| 257 | + def _add_operands( |
| 258 | + larger_operand_exponent: pyrtl.WireVector, |
| 259 | + smaller_mantissa_shifted_grs: pyrtl.WireVector, |
| 260 | + larger_mantissa_extended: pyrtl.WireVector, |
| 261 | + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: |
| 262 | + sum_mantissa_grs = pyrtl.WireVector() |
| 263 | + sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs |
| 264 | + sum_carry = sum_mantissa_grs[-1] |
| 265 | + sum_mantissa = pyrtl.select( |
| 266 | + sum_carry, sum_mantissa_grs[4:], sum_mantissa_grs[3:-1] |
| 267 | + ) |
| 268 | + sum_grs = pyrtl.select( |
| 269 | + sum_carry, |
| 270 | + pyrtl.concat(sum_mantissa_grs[2:4], sum_mantissa_grs[:2] != 0), |
| 271 | + sum_mantissa_grs[:3], |
| 272 | + ) |
| 273 | + sum_exponent = pyrtl.select( |
| 274 | + sum_carry, larger_operand_exponent + 1, larger_operand_exponent |
| 275 | + ) |
| 276 | + return sum_exponent, sum_mantissa, sum_grs, sum_carry |
| 277 | + |
| 278 | + @staticmethod |
| 279 | + def _sub_operands( |
| 280 | + num_mant_bits: int, |
| 281 | + larger_operand_exponent: pyrtl.WireVector, |
| 282 | + smaller_mantissa_shifted_grs: pyrtl.WireVector, |
| 283 | + larger_mantissa_extended: pyrtl.WireVector, |
| 284 | + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: |
| 285 | + def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int): |
| 286 | + out = pyrtl.WireVector( |
| 287 | + bitwidth=pyrtl.infer_val_and_bitwidth(length - 1).bitwidth |
| 288 | + ) |
| 289 | + with pyrtl.conditional_assignment: |
| 290 | + for i in range(wire.bitwidth - 1, wire.bitwidth - length - 1, -1): |
| 291 | + with wire[i]: |
| 292 | + out |= wire.bitwidth - i - 1 |
| 293 | + return out |
| 294 | + |
| 295 | + sub_mantissa_grs = pyrtl.WireVector(bitwidth=num_mant_bits + 4) |
| 296 | + sub_mantissa_grs <<= larger_mantissa_extended - smaller_mantissa_shifted_grs |
| 297 | + num_leading_zeros = leading_zero_priority_encoder( |
| 298 | + sub_mantissa_grs, num_mant_bits + 1 |
| 299 | + ) |
| 300 | + sub_mantissa_grs_shifted = pyrtl.shift_left_logical( |
| 301 | + sub_mantissa_grs, num_leading_zeros |
| 302 | + ) |
| 303 | + sub_mantissa = sub_mantissa_grs_shifted[3:] |
| 304 | + sub_grs = sub_mantissa_grs_shifted[:3] |
| 305 | + sub_exponent = larger_operand_exponent - num_leading_zeros |
| 306 | + return sub_exponent, sub_mantissa, sub_grs, num_leading_zeros |
| 307 | + |
| 308 | + @staticmethod |
| 309 | + def _round( |
| 310 | + num_mant_bits: int, |
| 311 | + num_exp_bits: int, |
| 312 | + raw_result_exponent: pyrtl.WireVector, |
| 313 | + raw_result_mantissa: pyrtl.WireVector, |
| 314 | + raw_result_grs: pyrtl.WireVector, |
| 315 | + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: |
| 316 | + last = raw_result_mantissa[0] |
| 317 | + guard = raw_result_grs[2] |
| 318 | + round = raw_result_grs[1] |
| 319 | + sticky = raw_result_grs[0] |
| 320 | + round_up = guard & (last | round | sticky) |
| 321 | + raw_result_rounded_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) |
| 322 | + raw_result_rounded_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) |
| 323 | + rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1) |
| 324 | + with pyrtl.conditional_assignment: |
| 325 | + with round_up: |
| 326 | + with raw_result_mantissa == (1 << num_mant_bits) - 1: |
| 327 | + raw_result_rounded_mantissa |= 0 |
| 328 | + raw_result_rounded_exponent |= raw_result_exponent + 1 |
| 329 | + rounding_exponent_incremented |= 1 |
| 330 | + with pyrtl.otherwise: |
| 331 | + raw_result_rounded_mantissa |= raw_result_mantissa + 1 |
| 332 | + raw_result_rounded_exponent |= raw_result_exponent |
| 333 | + rounding_exponent_incremented |= 0 |
| 334 | + with pyrtl.otherwise: |
| 335 | + raw_result_rounded_mantissa |= raw_result_mantissa |
| 336 | + raw_result_rounded_exponent |= raw_result_exponent |
| 337 | + rounding_exponent_incremented |= 0 |
| 338 | + return ( |
| 339 | + raw_result_rounded_exponent, |
| 340 | + raw_result_rounded_mantissa, |
| 341 | + rounding_exponent_incremented, |
| 342 | + ) |
0 commit comments