From 94c4c68531c3eb29de726407ba9da6b8bab52b1a Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Wed, 29 Oct 2025 22:30:52 -0700 Subject: [PATCH 1/3] Pyrtl floating point library --- pyrtl/rtllib/pyrtlfloat/__init__.py | 12 + pyrtl/rtllib/pyrtlfloat/_add_sub.py | 342 ++++++++++++++++++ pyrtl/rtllib/pyrtlfloat/_float_utills.py | 104 ++++++ pyrtl/rtllib/pyrtlfloat/_multiplication.py | 161 +++++++++ pyrtl/rtllib/pyrtlfloat/_types.py | 30 ++ pyrtl/rtllib/pyrtlfloat/floatoperations.py | 33 ++ pyrtl/rtllib/pyrtlfloat/floatwirevector.py | 46 +++ tests/rtllib/pyrtlfloat/test_add_sub.py | 30 ++ .../rtllib/pyrtlfloat/test_multiplication.py | 31 ++ 9 files changed, 789 insertions(+) create mode 100644 pyrtl/rtllib/pyrtlfloat/__init__.py create mode 100644 pyrtl/rtllib/pyrtlfloat/_add_sub.py create mode 100644 pyrtl/rtllib/pyrtlfloat/_float_utills.py create mode 100644 pyrtl/rtllib/pyrtlfloat/_multiplication.py create mode 100644 pyrtl/rtllib/pyrtlfloat/_types.py create mode 100644 pyrtl/rtllib/pyrtlfloat/floatoperations.py create mode 100644 pyrtl/rtllib/pyrtlfloat/floatwirevector.py create mode 100644 tests/rtllib/pyrtlfloat/test_add_sub.py create mode 100644 tests/rtllib/pyrtlfloat/test_multiplication.py diff --git a/pyrtl/rtllib/pyrtlfloat/__init__.py b/pyrtl/rtllib/pyrtlfloat/__init__.py new file mode 100644 index 00000000..df407d93 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/__init__.py @@ -0,0 +1,12 @@ +from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode +from .floatoperations import FloatOperations +from .floatwirevector import Float16WireVector + +__all__ = [ + "FloatingPointType", + "FPTypeProperties", + "PyrtlFloatConfig", + "RoundingMode", + "FloatOperations", + "Float16WireVector", +] diff --git a/pyrtl/rtllib/pyrtlfloat/_add_sub.py b/pyrtl/rtllib/pyrtlfloat/_add_sub.py new file mode 100644 index 00000000..419d647f --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_add_sub.py @@ -0,0 +1,342 @@ +import pyrtl + +from ._float_utills import FloatUtils +from ._types import PyrtlFloatConfig, RoundingMode + + +class AddSubHelper: + @staticmethod + def add( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + fp_type_props = config.fp_type_properties + rounding_mode = config.rounding_mode + num_exp_bits = fp_type_props.num_exponent_bits + num_mant_bits = fp_type_props.num_mantissa_bits + total_bits = num_exp_bits + num_mant_bits + 1 + + operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a) + operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) + + # operand_smaller is the operand with the smaller absolute value and + # operand_larger is the operand with the larger absolute value + operand_smaller = pyrtl.WireVector(bitwidth=total_bits) + operand_larger = pyrtl.WireVector(bitwidth=total_bits) + + with pyrtl.conditional_assignment: + exponent_and_mantissa_len = num_mant_bits + num_exp_bits + with ( + operand_a_daz[:exponent_and_mantissa_len] + < operand_b_daz[:exponent_and_mantissa_len] + ): + operand_smaller |= operand_a_daz + operand_larger |= operand_b_daz + with pyrtl.otherwise: + operand_smaller |= operand_b_daz + operand_larger |= operand_a_daz + + smaller_operand_sign = FloatUtils.get_sign(fp_type_props, operand_smaller) + larger_operand_sign = FloatUtils.get_sign(fp_type_props, operand_larger) + smaller_operand_exponent = FloatUtils.get_exponent( + fp_type_props, operand_smaller + ) + larger_operand_exponent = FloatUtils.get_exponent(fp_type_props, 
operand_larger) + smaller_operand_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_smaller) + ) + larger_operand_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_larger) + ) + + exponent_diff = larger_operand_exponent - smaller_operand_exponent + smaller_mantissa_shifted = pyrtl.shift_right_logical( + smaller_operand_mantissa, exponent_diff + ) + grs = pyrtl.WireVector(bitwidth=3) # guard, round, sticky bits for rounding + with pyrtl.conditional_assignment: + with exponent_diff >= 2: + guard_and_round = pyrtl.shift_right_logical( + smaller_operand_mantissa, exponent_diff - 2 + )[:2] + mask = ( + pyrtl.shift_left_logical( + pyrtl.Const(1, bitwidth=num_mant_bits), exponent_diff - 2 + ) + - 1 + ) + sticky = (smaller_operand_mantissa & mask) != 0 + grs |= pyrtl.concat(guard_and_round, sticky) + with exponent_diff == 1: + grs |= pyrtl.concat( + smaller_operand_mantissa[0], pyrtl.Const(0, bitwidth=2) + ) + with pyrtl.otherwise: + grs |= 0 + smaller_mantissa_shifted_grs = pyrtl.concat(smaller_mantissa_shifted, grs) + larger_mantissa_extended = pyrtl.concat( + larger_operand_mantissa, pyrtl.Const(0, bitwidth=3) + ) + + sum_exponent, sum_mantissa, sum_grs, sum_carry = AddSubHelper._add_operands( + larger_operand_exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, + ) + + sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = ( + AddSubHelper._sub_operands( + num_mant_bits, + larger_operand_exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, + ) + ) + + # WireVectors for the raw addition or subtraction result, before handling + # special cases + raw_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + raw_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + if rounding_mode == RoundingMode.RNE: + raw_result_grs = pyrtl.WireVector(bitwidth=3) + + with pyrtl.conditional_assignment: + with smaller_operand_sign == larger_operand_sign: # add + raw_result_exponent |= sum_exponent + raw_result_mantissa |= sum_mantissa + if rounding_mode == RoundingMode.RNE: + raw_result_grs |= sum_grs + with pyrtl.otherwise: # sub + raw_result_exponent |= sub_exponent + raw_result_mantissa |= sub_mantissa + if rounding_mode == RoundingMode.RNE: + raw_result_grs |= sub_grs + + if rounding_mode == RoundingMode.RNE: + ( + raw_result_rounded_exponent, + raw_result_rounded_mantissa, + rounding_exponent_incremented, + ) = AddSubHelper._round( + num_mant_bits, + num_exp_bits, + raw_result_exponent, + raw_result_mantissa, + raw_result_grs, + ) + + smaller_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_smaller) + larger_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_larger) + smaller_operand_inf = FloatUtils.is_inf(fp_type_props, operand_smaller) + larger_operand_inf = FloatUtils.is_inf(fp_type_props, operand_larger) + smaller_operand_zero = FloatUtils.is_zero(fp_type_props, operand_smaller) + larger_operand_zero = FloatUtils.is_zero(fp_type_props, operand_larger) + + # WireVectors for the final result after handling special cases + final_result_sign = pyrtl.WireVector(bitwidth=1) + final_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + final_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + + # handle special cases + with pyrtl.conditional_assignment: + # if either operand is NaN or both operands are infinity of opposite signs, + # the result is NaN + with ( + smaller_operand_nan + | larger_operand_nan + | ( + smaller_operand_inf + & larger_operand_inf + & 
(larger_operand_sign != smaller_operand_sign) + ) + ): + final_result_sign |= larger_operand_sign + FloatUtils.make_output_NaN( + fp_type_props, final_result_exponent, final_result_mantissa + ) + # infinities + with smaller_operand_inf: + final_result_sign |= larger_operand_sign + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + with larger_operand_inf: + final_result_sign |= larger_operand_sign + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + # +num + -num = +0 + with ( + (smaller_operand_mantissa == larger_operand_mantissa) + & (smaller_operand_exponent == larger_operand_exponent) + & (larger_operand_sign != smaller_operand_sign) + ): + final_result_sign |= 0 + FloatUtils.make_output_zero( + final_result_exponent, final_result_mantissa + ) + with smaller_operand_zero: + final_result_sign |= larger_operand_sign + final_result_mantissa |= larger_operand_mantissa + final_result_exponent |= larger_operand_exponent + with larger_operand_zero: + final_result_sign |= smaller_operand_sign + final_result_mantissa |= smaller_operand_mantissa + final_result_exponent |= smaller_operand_exponent + # overflow and underflow + initial_larger_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2) + if rounding_mode == RoundingMode.RNE: + larger_exponent_max_value = ( + initial_larger_exponent_max_value + - sum_carry + - rounding_exponent_incremented + ) + else: + larger_exponent_max_value = ( + initial_larger_exponent_max_value - sum_carry + ) + initial_larger_exponent_min_value = pyrtl.Const(1) + if rounding_mode == RoundingMode.RNE: + larger_exponent_min_value = ( + initial_larger_exponent_min_value + + num_leading_zeros + - rounding_exponent_incremented + ) + else: + larger_exponent_min_value = ( + initial_larger_exponent_min_value + num_leading_zeros + ) + with (smaller_operand_sign == larger_operand_sign) & ( + larger_operand_exponent > larger_exponent_max_value + ): # detect overflow on addition + final_result_sign |= larger_operand_sign + if rounding_mode == RoundingMode.RNE: + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + else: + FloatUtils.make_output_largest_finite_number( + fp_type_props, final_result_exponent, final_result_mantissa + ) + with (smaller_operand_sign != larger_operand_sign) & ( + larger_operand_exponent < larger_exponent_min_value + ): # detect underflow on subtraction + final_result_sign |= larger_operand_sign + FloatUtils.make_output_zero( + final_result_exponent, final_result_mantissa + ) + with pyrtl.otherwise: + final_result_sign |= larger_operand_sign + if rounding_mode == RoundingMode.RNE: + final_result_exponent |= raw_result_rounded_exponent + final_result_mantissa |= raw_result_rounded_mantissa + else: + final_result_exponent |= raw_result_exponent + final_result_mantissa |= raw_result_mantissa + + return pyrtl.concat( + final_result_sign, final_result_exponent, final_result_mantissa + ) + + @staticmethod + def sub( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + num_exp_bits = config.fp_type_properties.num_exponent_bits + num_mant_bits = config.fp_type_properties.num_mantissa_bits + operand_b_negated = operand_b ^ pyrtl.concat( + pyrtl.Const(1, bitwidth=1), + pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits), + ) + return AddSubHelper.add(config, operand_a, operand_b_negated) + + @staticmethod + def _add_operands( + larger_operand_exponent: 
pyrtl.WireVector, + smaller_mantissa_shifted_grs: pyrtl.WireVector, + larger_mantissa_extended: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: + sum_mantissa_grs = pyrtl.WireVector() + sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs + sum_carry = sum_mantissa_grs[-1] + sum_mantissa = pyrtl.select( + sum_carry, sum_mantissa_grs[4:], sum_mantissa_grs[3:-1] + ) + sum_grs = pyrtl.select( + sum_carry, + pyrtl.concat(sum_mantissa_grs[2:4], sum_mantissa_grs[:2] != 0), + sum_mantissa_grs[:3], + ) + sum_exponent = pyrtl.select( + sum_carry, larger_operand_exponent + 1, larger_operand_exponent + ) + return sum_exponent, sum_mantissa, sum_grs, sum_carry + + @staticmethod + def _sub_operands( + num_mant_bits: int, + larger_operand_exponent: pyrtl.WireVector, + smaller_mantissa_shifted_grs: pyrtl.WireVector, + larger_mantissa_extended: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: + def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int): + out = pyrtl.WireVector( + bitwidth=pyrtl.infer_val_and_bitwidth(length - 1).bitwidth + ) + with pyrtl.conditional_assignment: + for i in range(wire.bitwidth - 1, wire.bitwidth - length - 1, -1): + with wire[i]: + out |= wire.bitwidth - i - 1 + return out + + sub_mantissa_grs = pyrtl.WireVector(bitwidth=num_mant_bits + 4) + sub_mantissa_grs <<= larger_mantissa_extended - smaller_mantissa_shifted_grs + num_leading_zeros = leading_zero_priority_encoder( + sub_mantissa_grs, num_mant_bits + 1 + ) + sub_mantissa_grs_shifted = pyrtl.shift_left_logical( + sub_mantissa_grs, num_leading_zeros + ) + sub_mantissa = sub_mantissa_grs_shifted[3:] + sub_grs = sub_mantissa_grs_shifted[:3] + sub_exponent = larger_operand_exponent - num_leading_zeros + return sub_exponent, sub_mantissa, sub_grs, num_leading_zeros + + @staticmethod + def _round( + num_mant_bits: int, + num_exp_bits: int, + raw_result_exponent: pyrtl.WireVector, + raw_result_mantissa: pyrtl.WireVector, + raw_result_grs: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: + last = raw_result_mantissa[0] + guard = raw_result_grs[2] + round = raw_result_grs[1] + sticky = raw_result_grs[0] + round_up = guard & (last | round | sticky) + raw_result_rounded_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + raw_result_rounded_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1) + with pyrtl.conditional_assignment: + with round_up: + with raw_result_mantissa == (1 << num_mant_bits) - 1: + raw_result_rounded_mantissa |= 0 + raw_result_rounded_exponent |= raw_result_exponent + 1 + rounding_exponent_incremented |= 1 + with pyrtl.otherwise: + raw_result_rounded_mantissa |= raw_result_mantissa + 1 + raw_result_rounded_exponent |= raw_result_exponent + rounding_exponent_incremented |= 0 + with pyrtl.otherwise: + raw_result_rounded_mantissa |= raw_result_mantissa + raw_result_rounded_exponent |= raw_result_exponent + rounding_exponent_incremented |= 0 + return ( + raw_result_rounded_exponent, + raw_result_rounded_mantissa, + rounding_exponent_incremented, + ) diff --git a/pyrtl/rtllib/pyrtlfloat/_float_utills.py b/pyrtl/rtllib/pyrtlfloat/_float_utills.py new file mode 100644 index 00000000..0ae58329 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_float_utills.py @@ -0,0 +1,104 @@ +import pyrtl + +from ._types import FPTypeProperties + + +class FloatUtils: + @staticmethod + def get_sign(fp_prop: 
FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return wire[fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits] + + @staticmethod + def get_exponent( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return wire[ + fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits + + fp_prop.num_exponent_bits + ] + + @staticmethod + def get_mantissa( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return wire[: fp_prop.num_mantissa_bits] + + @staticmethod + def is_zero(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) == 0) & ( + FloatUtils.get_exponent(fp_prop, wire) == 0 + ) + + @staticmethod + def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) == 0) & ( + FloatUtils.get_exponent(fp_prop, wire) + == (1 << fp_prop.num_exponent_bits) - 1 + ) + + @staticmethod + def is_denormalized( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) != 0) & ( + FloatUtils.get_exponent(fp_prop, wire) == 0 + ) + + @staticmethod + def is_NaN(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) != 0) & ( + FloatUtils.get_exponent(fp_prop, wire) + == (1 << fp_prop.num_exponent_bits) - 1 + ) + + @staticmethod + def make_denormals_zero( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + out = pyrtl.WireVector( + bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1 + ) + with pyrtl.conditional_assignment: + with FloatUtils.get_exponent(fp_prop, wire) == 0: + out |= pyrtl.concat( + FloatUtils.get_sign(fp_prop, wire), + FloatUtils.get_exponent(fp_prop, wire), + pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits), + ) + with pyrtl.otherwise: + out |= wire + return out + + @staticmethod + def make_output_inf( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 1 + mantissa |= 0 + + @staticmethod + def make_output_NaN( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 1 + mantissa |= 1 << (fp_prop.num_mantissa_bits - 1) + + @staticmethod + def make_output_zero( + exponent: pyrtl.WireVector, mantissa: pyrtl.WireVector + ) -> None: + exponent |= 0 + mantissa |= 0 + + @staticmethod + def make_output_largest_finite_number( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 2 + mantissa |= (1 << fp_prop.num_mantissa_bits) - 1 diff --git a/pyrtl/rtllib/pyrtlfloat/_multiplication.py b/pyrtl/rtllib/pyrtlfloat/_multiplication.py new file mode 100644 index 00000000..a0de7d25 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_multiplication.py @@ -0,0 +1,161 @@ +import pyrtl + +from ._float_utills import FloatUtils +from ._types import PyrtlFloatConfig, RoundingMode + + +class MultiplicationHelper: + @staticmethod + def multiply( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + fp_type_props = config.fp_type_properties + rounding_mode = config.rounding_mode + num_exp_bits = fp_type_props.num_exponent_bits + num_mant_bits = fp_type_props.num_mantissa_bits + + operand_a_daz = 
FloatUtils.make_denormals_zero(fp_type_props, operand_a) + operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) + a_sign = FloatUtils.get_sign(fp_type_props, operand_a_daz) + b_sign = FloatUtils.get_sign(fp_type_props, operand_b_daz) + a_exponent = FloatUtils.get_exponent(fp_type_props, operand_a_daz) + b_exponent = FloatUtils.get_exponent(fp_type_props, operand_b_daz) + + exponent_bias = 2 ** (fp_type_props.num_exponent_bits - 1) - 1 + + result_sign = a_sign ^ b_sign + operand_exponent_sums = a_exponent + b_exponent + product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias) + + a_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_a_daz) + ) + b_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_b_daz) + ) + product_mantissa = a_mantissa * b_mantissa + + normalized_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) + normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + + need_to_normalize = product_mantissa[-1] + + if rounding_mode == RoundingMode.RNE: + guard = pyrtl.WireVector(bitwidth=1) + sticky = pyrtl.WireVector(bitwidth=1) + last = pyrtl.WireVector(bitwidth=1) + + with pyrtl.conditional_assignment: + with need_to_normalize: + normalized_product_mantissa |= product_mantissa[-num_mant_bits - 1 :] + normalized_product_exponent |= product_exponent + 1 + if rounding_mode == RoundingMode.RNE: + guard |= product_mantissa[-num_mant_bits - 2] + sticky |= product_mantissa[: -num_mant_bits - 2] != 0 + last |= product_mantissa[-num_mant_bits - 1] + with pyrtl.otherwise: + normalized_product_mantissa |= product_mantissa[-num_mant_bits - 2 : -1] + normalized_product_exponent |= product_exponent + if rounding_mode == RoundingMode.RNE: + guard |= product_mantissa[-num_mant_bits - 3] + sticky |= product_mantissa[: -num_mant_bits - 3] != 0 + last |= product_mantissa[-num_mant_bits - 2] + + if rounding_mode == RoundingMode.RNE: + rounded_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + rounded_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) + exponent_incremented = pyrtl.WireVector(bitwidth=1) + with pyrtl.conditional_assignment: + with guard & (last | sticky): + with normalized_product_mantissa == (1 << num_mant_bits) - 1: + rounded_product_mantissa |= 0 + rounded_product_exponent |= normalized_product_exponent + 1 + exponent_incremented |= 1 + with pyrtl.otherwise: + rounded_product_mantissa |= normalized_product_mantissa + 1 + rounded_product_exponent |= normalized_product_exponent + exponent_incremented |= 0 + with pyrtl.otherwise: + rounded_product_mantissa |= normalized_product_mantissa + rounded_product_exponent |= normalized_product_exponent + exponent_incremented |= 0 + + result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + + operand_a_nan = FloatUtils.is_NaN(fp_type_props, operand_a_daz) + operand_b_nan = FloatUtils.is_NaN(fp_type_props, operand_b_daz) + operand_a_inf = FloatUtils.is_inf(fp_type_props, operand_a_daz) + operand_b_inf = FloatUtils.is_inf(fp_type_props, operand_b_daz) + operand_a_zero = FloatUtils.is_zero(fp_type_props, operand_a_daz) + operand_b_zero = FloatUtils.is_zero(fp_type_props, operand_b_daz) + operand_a_denormalized = FloatUtils.is_denormalized( + fp_type_props, operand_a_daz + ) + operand_b_denormalized = FloatUtils.is_denormalized( + fp_type_props, operand_b_daz + ) + + # Overflow and underflow checks (only for normal cases) + 
sum_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2 + exponent_bias) + sum_exponent_min_value = pyrtl.Const(1 + exponent_bias) + if rounding_mode == RoundingMode.RNE: + exponent_max_value = ( + sum_exponent_max_value - need_to_normalize - exponent_incremented + ) + exponent_min_value = ( + sum_exponent_min_value - need_to_normalize - exponent_incremented + ) + else: + exponent_max_value = sum_exponent_max_value - need_to_normalize + exponent_min_value = sum_exponent_min_value - need_to_normalize + + if rounding_mode == RoundingMode.RNE: + raw_result_exponent = rounded_product_exponent[0:num_exp_bits] + raw_result_mantissa = rounded_product_mantissa + else: + raw_result_exponent = normalized_product_exponent[0:num_exp_bits] + raw_result_mantissa = normalized_product_mantissa + + with pyrtl.conditional_assignment: + # nan + with ( + operand_a_nan + | operand_b_nan + | (operand_a_inf & operand_b_zero) + | (operand_a_zero & operand_b_inf) + ): + FloatUtils.make_output_NaN( + fp_type_props, result_exponent, result_mantissa + ) + # infinity + with operand_a_inf | operand_b_inf: + FloatUtils.make_output_inf( + fp_type_props, result_exponent, result_mantissa + ) + # overflow + with operand_exponent_sums > exponent_max_value: + if rounding_mode == RoundingMode.RNE: + FloatUtils.make_output_inf( + fp_type_props, result_exponent, result_mantissa + ) + else: + FloatUtils.make_output_largest_finite_number( + fp_type_props, result_exponent, result_mantissa + ) + # zero or underflow + with ( + operand_a_zero + | operand_b_zero + | (operand_exponent_sums < exponent_min_value) + | operand_a_denormalized + | operand_b_denormalized + ): + FloatUtils.make_output_zero(result_exponent, result_mantissa) + with pyrtl.otherwise: + result_exponent |= raw_result_exponent + result_mantissa |= raw_result_mantissa + + return pyrtl.concat(result_sign, result_exponent, result_mantissa) diff --git a/pyrtl/rtllib/pyrtlfloat/_types.py b/pyrtl/rtllib/pyrtlfloat/_types.py new file mode 100644 index 00000000..15a1c811 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_types.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from enum import Enum + + +class RoundingMode(Enum): + RTZ = 1 # round towards zero (truncate) + RNE = 2 # round to nearest, ties to even (default mode) + + +@dataclass(frozen=True) +class FPTypeProperties: + num_exponent_bits: int + num_mantissa_bits: int + + +class FloatingPointType(Enum): + BFLOAT16 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=7) + FLOAT16 = FPTypeProperties(num_exponent_bits=5, num_mantissa_bits=10) + FLOAT32 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=23) + FLOAT64 = FPTypeProperties(num_exponent_bits=11, num_mantissa_bits=52) + + +@dataclass(frozen=True) +class PyrtlFloatConfig: + fp_type_properties: FPTypeProperties + rounding_mode: RoundingMode + + +class PyrtlFloatException(Exception): + pass diff --git a/pyrtl/rtllib/pyrtlfloat/floatoperations.py b/pyrtl/rtllib/pyrtlfloat/floatoperations.py new file mode 100644 index 00000000..e4d10769 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/floatoperations.py @@ -0,0 +1,33 @@ +import pyrtl + +from ._add_sub import AddSubHelper +from ._multiplication import MultiplicationHelper +from ._types import PyrtlFloatConfig, RoundingMode + + +class FloatOperations: + default_rounding_mode = RoundingMode.RNE + + @staticmethod + def multiply( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return MultiplicationHelper.multiply(config, operand_a, operand_b) 
+ + @staticmethod + def add( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return AddSubHelper.add(config, operand_a, operand_b) + + @staticmethod + def sub( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return AddSubHelper.sub(config, operand_a, operand_b) diff --git a/pyrtl/rtllib/pyrtlfloat/floatwirevector.py b/pyrtl/rtllib/pyrtlfloat/floatwirevector.py new file mode 100644 index 00000000..0004adb7 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/floatwirevector.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import pyrtl + +from ._types import FloatingPointType, PyrtlFloatConfig, PyrtlFloatException +from .floatoperations import FloatOperations + + +class Float16WireVector(pyrtl.WireVector): + def __init__(self): + super().__init__() + self.bitwidth = 16 + + def __ilshift__(self, other): + if isinstance(other, (pyrtl.WireVector, Float16WireVector)): + super().__ilshift__(other) + else: + msg = ( + "FloatWireVector16 can only be driven by a FloatWireVector16 " + "or a PyRTL WireVector." + ) + raise PyrtlFloatException(msg) + return self + + def _get_config(self) -> PyrtlFloatConfig | None: + return PyrtlFloatConfig( + FloatingPointType.FLOAT16.value, FloatOperations.default_rounding_mode + ) + + def __add__(self, other: Float16WireVector) -> Float16WireVector: + ret = Float16WireVector() + ret <<= FloatOperations.add(self._get_config(), self, other) + return ret + + def __sub__(self, other: Float16WireVector) -> Float16WireVector: + ret = Float16WireVector() + ret <<= FloatOperations.sub(self._get_config(), self, other) + return ret + + def __mul__(self, other: Float16WireVector) -> Float16WireVector: + ret = Float16WireVector() + ret <<= FloatOperations.multiply(self._get_config(), self, other) + return ret + + +# will create BFloat16WireVector, Float32WireVector, and Float64WireVector the same way diff --git a/tests/rtllib/pyrtlfloat/test_add_sub.py b/tests/rtllib/pyrtlfloat/test_add_sub.py new file mode 100644 index 00000000..f20177a9 --- /dev/null +++ b/tests/rtllib/pyrtlfloat/test_add_sub.py @@ -0,0 +1,30 @@ +import unittest + +import pyrtl +from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode + + +class TestMultiplication(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + a = pyrtl.Input(bitwidth=16, name="a") + b = pyrtl.Input(bitwidth=16, name="b") + a_floatwv = Float16WireVector() + a_floatwv <<= a + b_floatwv = Float16WireVector() + b_floatwv <<= b + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_add = pyrtl.Output(name="result_add") + result_add <<= a_floatwv + b_floatwv + result_sub = pyrtl.Output(name="result_sub") + result_sub <<= a_floatwv - b_floatwv + self.sim = pyrtl.Simulation() + + def test_multiplication_simple(self): + self.sim.step({"a": 0b0100001000000000, "b": 0b0100010100000000}) + self.assertEqual(self.sim.inspect("result_add"), 0b0100100000000000) + self.assertEqual(self.sim.inspect("result_sub"), 0b1100000000000000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rtllib/pyrtlfloat/test_multiplication.py b/tests/rtllib/pyrtlfloat/test_multiplication.py new file mode 100644 index 00000000..812b0ea5 --- /dev/null +++ b/tests/rtllib/pyrtlfloat/test_multiplication.py @@ -0,0 +1,31 @@ +import unittest + +import pyrtl +from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode + + +class 
TestMultiplication(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + a = pyrtl.Input(bitwidth=16, name="a") + b = pyrtl.Input(bitwidth=16, name="b") + a_floatwv = Float16WireVector() + a_floatwv <<= a + b_floatwv = Float16WireVector() + b_floatwv <<= b + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= a_floatwv * b_floatwv + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= a_floatwv * b_floatwv + self.sim = pyrtl.Simulation() + + def test_multiplication_simple(self): + self.sim.step({"a": 0b0011111000000000, "b": 0b0011110000000001}) + self.assertEqual(self.sim.inspect("result_rne"), 0b0011111000000010) + self.assertEqual(self.sim.inspect("result_rtz"), 0b0011111000000001) + + +if __name__ == "__main__": + unittest.main() From 15ad9965cd7dcecd2d43ebdb9d60e6fd2f830046 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Thu, 30 Oct 2025 15:49:44 -0700 Subject: [PATCH 2/3] Remove FloatWireVector --- pyrtl/rtllib/pyrtlfloat/__init__.py | 14 ++++-- pyrtl/rtllib/pyrtlfloat/floatoperations.py | 48 ++++++++++++++++++- pyrtl/rtllib/pyrtlfloat/floatwirevector.py | 46 ------------------ tests/rtllib/pyrtlfloat/test_add_sub.py | 10 ++-- .../rtllib/pyrtlfloat/test_multiplication.py | 10 ++-- 5 files changed, 63 insertions(+), 65 deletions(-) delete mode 100644 pyrtl/rtllib/pyrtlfloat/floatwirevector.py diff --git a/pyrtl/rtllib/pyrtlfloat/__init__.py b/pyrtl/rtllib/pyrtlfloat/__init__.py index df407d93..d9b64710 100644 --- a/pyrtl/rtllib/pyrtlfloat/__init__.py +++ b/pyrtl/rtllib/pyrtlfloat/__init__.py @@ -1,6 +1,11 @@ from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode -from .floatoperations import FloatOperations -from .floatwirevector import Float16WireVector +from .floatoperations import ( + BFloat16Operations, + Float16Operations, + Float32Operations, + Float64Operations, + FloatOperations, +) __all__ = [ "FloatingPointType", @@ -8,5 +13,8 @@ "PyrtlFloatConfig", "RoundingMode", "FloatOperations", - "Float16WireVector", + "BFloat16Operations", + "Float16Operations", + "Float32Operations", + "Float64Operations", ] diff --git a/pyrtl/rtllib/pyrtlfloat/floatoperations.py b/pyrtl/rtllib/pyrtlfloat/floatoperations.py index e4d10769..ef081b0a 100644 --- a/pyrtl/rtllib/pyrtlfloat/floatoperations.py +++ b/pyrtl/rtllib/pyrtlfloat/floatoperations.py @@ -2,14 +2,14 @@ from ._add_sub import AddSubHelper from ._multiplication import MultiplicationHelper -from ._types import PyrtlFloatConfig, RoundingMode +from ._types import FloatingPointType, PyrtlFloatConfig, RoundingMode class FloatOperations: default_rounding_mode = RoundingMode.RNE @staticmethod - def multiply( + def mul( config: PyrtlFloatConfig, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector, @@ -31,3 +31,47 @@ def sub( operand_b: pyrtl.WireVector, ) -> pyrtl.WireVector: return AddSubHelper.sub(config, operand_a, operand_b) + + +class _BaseTypedFloatOperations: + _fp_type: FloatingPointType = None + + @classmethod + def mul( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.mul(cls._get_config(), operand_a, operand_b) + + @classmethod + def add( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.add(cls._get_config(), operand_a, operand_b) + + @classmethod + def sub( + cls, operand_a: pyrtl.WireVector, operand_b: 
pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.sub(cls._get_config(), operand_a, operand_b) + + @classmethod + def _get_config(cls) -> PyrtlFloatConfig: + return PyrtlFloatConfig( + cls._fp_type.value, FloatOperations.default_rounding_mode + ) + + +class BFloat16Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.BFLOAT16 + + +class Float16Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT16 + + +class Float32Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT32 + + +class Float64Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT64 diff --git a/pyrtl/rtllib/pyrtlfloat/floatwirevector.py b/pyrtl/rtllib/pyrtlfloat/floatwirevector.py deleted file mode 100644 index 0004adb7..00000000 --- a/pyrtl/rtllib/pyrtlfloat/floatwirevector.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -import pyrtl - -from ._types import FloatingPointType, PyrtlFloatConfig, PyrtlFloatException -from .floatoperations import FloatOperations - - -class Float16WireVector(pyrtl.WireVector): - def __init__(self): - super().__init__() - self.bitwidth = 16 - - def __ilshift__(self, other): - if isinstance(other, (pyrtl.WireVector, Float16WireVector)): - super().__ilshift__(other) - else: - msg = ( - "FloatWireVector16 can only be driven by a FloatWireVector16 " - "or a PyRTL WireVector." - ) - raise PyrtlFloatException(msg) - return self - - def _get_config(self) -> PyrtlFloatConfig | None: - return PyrtlFloatConfig( - FloatingPointType.FLOAT16.value, FloatOperations.default_rounding_mode - ) - - def __add__(self, other: Float16WireVector) -> Float16WireVector: - ret = Float16WireVector() - ret <<= FloatOperations.add(self._get_config(), self, other) - return ret - - def __sub__(self, other: Float16WireVector) -> Float16WireVector: - ret = Float16WireVector() - ret <<= FloatOperations.sub(self._get_config(), self, other) - return ret - - def __mul__(self, other: Float16WireVector) -> Float16WireVector: - ret = Float16WireVector() - ret <<= FloatOperations.multiply(self._get_config(), self, other) - return ret - - -# will create BFloat16WireVector, Float32WireVector, and Float64WireVector the same way diff --git a/tests/rtllib/pyrtlfloat/test_add_sub.py b/tests/rtllib/pyrtlfloat/test_add_sub.py index f20177a9..3f006da2 100644 --- a/tests/rtllib/pyrtlfloat/test_add_sub.py +++ b/tests/rtllib/pyrtlfloat/test_add_sub.py @@ -1,7 +1,7 @@ import unittest import pyrtl -from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode +from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode class TestMultiplication(unittest.TestCase): @@ -9,15 +9,11 @@ def setUp(self): pyrtl.reset_working_block() a = pyrtl.Input(bitwidth=16, name="a") b = pyrtl.Input(bitwidth=16, name="b") - a_floatwv = Float16WireVector() - a_floatwv <<= a - b_floatwv = Float16WireVector() - b_floatwv <<= b FloatOperations.default_rounding_mode = RoundingMode.RNE result_add = pyrtl.Output(name="result_add") - result_add <<= a_floatwv + b_floatwv + result_add <<= Float16Operations.add(a, b) result_sub = pyrtl.Output(name="result_sub") - result_sub <<= a_floatwv - b_floatwv + result_sub <<= Float16Operations.sub(a, b) self.sim = pyrtl.Simulation() def test_multiplication_simple(self): diff --git a/tests/rtllib/pyrtlfloat/test_multiplication.py b/tests/rtllib/pyrtlfloat/test_multiplication.py index 812b0ea5..439efabb 100644 --- a/tests/rtllib/pyrtlfloat/test_multiplication.py +++ 
b/tests/rtllib/pyrtlfloat/test_multiplication.py @@ -1,7 +1,7 @@ import unittest import pyrtl -from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode +from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode class TestMultiplication(unittest.TestCase): @@ -9,16 +9,12 @@ def setUp(self): pyrtl.reset_working_block() a = pyrtl.Input(bitwidth=16, name="a") b = pyrtl.Input(bitwidth=16, name="b") - a_floatwv = Float16WireVector() - a_floatwv <<= a - b_floatwv = Float16WireVector() - b_floatwv <<= b FloatOperations.default_rounding_mode = RoundingMode.RNE result_rne = pyrtl.Output(name="result_rne") - result_rne <<= a_floatwv * b_floatwv + result_rne <<= Float16Operations.mul(a, b) FloatOperations.default_rounding_mode = RoundingMode.RTZ result_rtz = pyrtl.Output(name="result_rtz") - result_rtz <<= a_floatwv * b_floatwv + result_rtz <<= Float16Operations.mul(a, b) self.sim = pyrtl.Simulation() def test_multiplication_simple(self): From e601040bf22447a385844304fc59c1fd23b1dadc Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Tue, 20 Jan 2026 20:14:39 -0800 Subject: [PATCH 3/3] Address comments --- pyrtl/rtllib/pyrtlfloat/_add_sub.py | 725 ++++++++++-------- pyrtl/rtllib/pyrtlfloat/_float_utills.py | 267 ++++--- pyrtl/rtllib/pyrtlfloat/_multiplication.py | 332 ++++---- pyrtl/rtllib/pyrtlfloat/_types.py | 43 +- pyrtl/rtllib/pyrtlfloat/floatoperations.py | 113 ++- tests/rtllib/pyrtlfloat/test_add_sub.py | 521 ++++++++++++- .../rtllib/pyrtlfloat/test_multiplication.py | 284 ++++++- 7 files changed, 1697 insertions(+), 588 deletions(-) diff --git a/pyrtl/rtllib/pyrtlfloat/_add_sub.py b/pyrtl/rtllib/pyrtlfloat/_add_sub.py index 419d647f..21ca90ab 100644 --- a/pyrtl/rtllib/pyrtlfloat/_add_sub.py +++ b/pyrtl/rtllib/pyrtlfloat/_add_sub.py @@ -1,342 +1,447 @@ import pyrtl -from ._float_utills import FloatUtils +from ._float_utills import ( + get_exponent, + get_mantissa, + get_sign, + is_inf, + is_nan, + is_zero, + make_denormals_zero, + make_inf, + make_largest_finite_number, + make_nan, + make_zero, +) from ._types import PyrtlFloatConfig, RoundingMode -class AddSubHelper: - @staticmethod - def add( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, - ) -> pyrtl.WireVector: - fp_type_props = config.fp_type_properties - rounding_mode = config.rounding_mode - num_exp_bits = fp_type_props.num_exponent_bits - num_mant_bits = fp_type_props.num_mantissa_bits - total_bits = num_exp_bits + num_mant_bits + 1 +def add( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, +) -> pyrtl.WireVector: + """ + Performs floating point addition of two WireVectors. - operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a) - operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) + :param config: Configuration for the floating point type and rounding mode. + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the addition as a WireVector. 
+ """ + fp_type_props = config.fp_type_properties + rounding_mode = config.rounding_mode + num_exp_bits = fp_type_props.num_exponent_bits + num_mant_bits = fp_type_props.num_mantissa_bits + total_bits = num_exp_bits + num_mant_bits + 1 - # operand_smaller is the operand with the smaller absolute value and - # operand_larger is the operand with the larger absolute value - operand_smaller = pyrtl.WireVector(bitwidth=total_bits) - operand_larger = pyrtl.WireVector(bitwidth=total_bits) + # Denormalized numbers are not supported, so we flush them to zero. + operand_a_daz = make_denormals_zero(fp_type_props, operand_a) + operand_b_daz = make_denormals_zero(fp_type_props, operand_b) - with pyrtl.conditional_assignment: - exponent_and_mantissa_len = num_mant_bits + num_exp_bits - with ( - operand_a_daz[:exponent_and_mantissa_len] - < operand_b_daz[:exponent_and_mantissa_len] - ): - operand_smaller |= operand_a_daz - operand_larger |= operand_b_daz - with pyrtl.otherwise: - operand_smaller |= operand_b_daz - operand_larger |= operand_a_daz + # operand_smaller is the operand with the smaller absolute value and + # operand_larger is the operand with the larger absolute value. + operand_smaller = pyrtl.WireVector(bitwidth=total_bits) + operand_larger = pyrtl.WireVector(bitwidth=total_bits) - smaller_operand_sign = FloatUtils.get_sign(fp_type_props, operand_smaller) - larger_operand_sign = FloatUtils.get_sign(fp_type_props, operand_larger) - smaller_operand_exponent = FloatUtils.get_exponent( - fp_type_props, operand_smaller - ) - larger_operand_exponent = FloatUtils.get_exponent(fp_type_props, operand_larger) - smaller_operand_mantissa = pyrtl.concat( - pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_smaller) - ) - larger_operand_mantissa = pyrtl.concat( - pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_larger) - ) + # Determine which operand is smaller/larger, then assign operand_smaller and + # operand_larger accordingly. + with pyrtl.conditional_assignment: + exponent_and_mantissa_len = num_mant_bits + num_exp_bits + with ( + operand_a_daz[:exponent_and_mantissa_len] + < operand_b_daz[:exponent_and_mantissa_len] + ): + operand_smaller |= operand_a_daz + operand_larger |= operand_b_daz + with pyrtl.otherwise: + operand_smaller |= operand_b_daz + operand_larger |= operand_a_daz - exponent_diff = larger_operand_exponent - smaller_operand_exponent - smaller_mantissa_shifted = pyrtl.shift_right_logical( - smaller_operand_mantissa, exponent_diff - ) - grs = pyrtl.WireVector(bitwidth=3) # guard, round, sticky bits for rounding - with pyrtl.conditional_assignment: - with exponent_diff >= 2: - guard_and_round = pyrtl.shift_right_logical( - smaller_operand_mantissa, exponent_diff - 2 - )[:2] - mask = ( - pyrtl.shift_left_logical( - pyrtl.Const(1, bitwidth=num_mant_bits), exponent_diff - 2 - ) - - 1 - ) - sticky = (smaller_operand_mantissa & mask) != 0 - grs |= pyrtl.concat(guard_and_round, sticky) - with exponent_diff == 1: - grs |= pyrtl.concat( - smaller_operand_mantissa[0], pyrtl.Const(0, bitwidth=2) - ) - with pyrtl.otherwise: - grs |= 0 - smaller_mantissa_shifted_grs = pyrtl.concat(smaller_mantissa_shifted, grs) - larger_mantissa_extended = pyrtl.concat( - larger_operand_mantissa, pyrtl.Const(0, bitwidth=3) - ) + # Extract the sign, exponent, and mantissa of both operands. 
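+    # (The hidden leading 1 of each normalized significand is made explicit via the
+    # concat with Const(1) below, so the mantissa arithmetic operates on the full
+    # 1.fraction value.)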
+ smaller_operand_sign = get_sign(fp_type_props, operand_smaller) + larger_operand_sign = get_sign(fp_type_props, operand_larger) + smaller_operand_exponent = get_exponent(fp_type_props, operand_smaller) + larger_operand_exponent = get_exponent(fp_type_props, operand_larger) + smaller_operand_mantissa = pyrtl.concat( + pyrtl.Const(1), get_mantissa(fp_type_props, operand_smaller) + ) + larger_operand_mantissa = pyrtl.concat( + pyrtl.Const(1), get_mantissa(fp_type_props, operand_larger) + ) - sum_exponent, sum_mantissa, sum_grs, sum_carry = AddSubHelper._add_operands( - larger_operand_exponent, - smaller_mantissa_shifted_grs, - larger_mantissa_extended, - ) + # Align mantissas by shifting the smaller one to match the larger's exponent. + smaller_mantissa_shift_amount = larger_operand_exponent - smaller_operand_exponent + smaller_mantissa_shifted = pyrtl.shift_right_logical( + smaller_operand_mantissa, smaller_mantissa_shift_amount + ) - sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = ( - AddSubHelper._sub_operands( - num_mant_bits, - larger_operand_exponent, - smaller_mantissa_shifted_grs, - larger_mantissa_extended, + # RNE rounding uses the guard, round, and sticky bits. + # When shifting the smaller mantissa to the right, some bits are shifted out. + # The first bit shifted out becomes the guard bit, the second becomes the round bit, + # and any remaining bits are ORed together to form the sticky bit. + # https://drilian.com/posts/2023.01.10-floating-point-numbers-and-rounding/ + grs = pyrtl.WireVector(bitwidth=3) # guard, round, sticky bits + with pyrtl.conditional_assignment: + # If the smaller mantissa is shifted by 2 or more, the first two bits + # shifted out are the guard and round bits, and the sticky bit is + # the OR of all remaining bits. + with smaller_mantissa_shift_amount >= 2: + guard_and_round = pyrtl.shift_right_logical( + smaller_operand_mantissa, smaller_mantissa_shift_amount - 2 + )[:2] + # Mask with the least significant (shift_amount - 2) bits set to 1 + mask = ( + pyrtl.shift_left_logical( + pyrtl.Const(1, bitwidth=num_mant_bits), + smaller_mantissa_shift_amount - 2, + ) + - 1 ) + sticky = (smaller_operand_mantissa & mask) != 0 + grs |= pyrtl.concat(guard_and_round, sticky) + # If the smaller mantissa is shifted by 1, the first bit shifted out + # is the guard bit, the round bit and sticky bit are both 0. + with smaller_mantissa_shift_amount == 1: + grs |= pyrtl.concat(smaller_operand_mantissa[0], pyrtl.Const(0, bitwidth=2)) + # If not shifted, guard, round, and sticky bits are all 0. + with pyrtl.otherwise: + grs |= 0 + + # Concatenate the shifted smaller mantissa with the guard, round, and sticky bits. + smaller_mantissa_shifted_grs = pyrtl.concat(smaller_mantissa_shifted, grs) + + # Extend the larger mantissa by concatenating three zeros so it aligns with the + # smaller mantissa, which was extended with GRS bits. + larger_mantissa_extended = pyrtl.concat( + larger_operand_mantissa, pyrtl.Const(0, bitwidth=3) + ) + + # Perform addition of operands. + sum_exponent, sum_mantissa, sum_grs, sum_carry = _add_operands( + larger_operand_exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, + ) + + # Perform subtraction of operands. + sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = _sub_operands( + num_mant_bits, + larger_operand_exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, + ) + + # Exponent and mantissa for raw addition or subtraction result, before + # rounding and handling special cases. 
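+    # The GRS bits are only tracked for RNE; RTZ simply truncates, so no rounding
+    # state is needed in that mode.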
+ raw_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + raw_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + if rounding_mode == RoundingMode.RNE: + raw_result_grs = pyrtl.WireVector(bitwidth=3) + + # Determine whether we need to add or subtract the operands. + with pyrtl.conditional_assignment: + # If the operands have the same sign, we perform addition. + # For example, (+a) + (+b) or (-a) + (-b). + with smaller_operand_sign == larger_operand_sign: + raw_result_exponent |= sum_exponent + raw_result_mantissa |= sum_mantissa + if rounding_mode == RoundingMode.RNE: + raw_result_grs |= sum_grs + # If the operands have different signs, we perform subtraction. + # For example, (+a) + (-b) or (-a) + (+b). + with pyrtl.otherwise: + raw_result_exponent |= sub_exponent + raw_result_mantissa |= sub_mantissa + if rounding_mode == RoundingMode.RNE: + raw_result_grs |= sub_grs + + # Round the result if using RNE rounding mode. + if rounding_mode == RoundingMode.RNE: + ( + raw_result_rounded_exponent, + raw_result_rounded_mantissa, + rounding_exponent_incremented, + ) = _round( + num_mant_bits, + num_exp_bits, + raw_result_exponent, + raw_result_mantissa, + raw_result_grs, ) - # WireVectors for the raw addition or subtraction result, before handling - # special cases - raw_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) - raw_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) - if rounding_mode == RoundingMode.RNE: - raw_result_grs = pyrtl.WireVector(bitwidth=3) + # Check whether the operands are special cases: NaN, infinity, or zero. + smaller_operand_nan = is_nan(fp_type_props, operand_smaller) + larger_operand_nan = is_nan(fp_type_props, operand_larger) + smaller_operand_inf = is_inf(fp_type_props, operand_smaller) + larger_operand_inf = is_inf(fp_type_props, operand_larger) + smaller_operand_zero = is_zero(fp_type_props, operand_smaller) + larger_operand_zero = is_zero(fp_type_props, operand_larger) - with pyrtl.conditional_assignment: - with smaller_operand_sign == larger_operand_sign: # add - raw_result_exponent |= sum_exponent - raw_result_mantissa |= sum_mantissa - if rounding_mode == RoundingMode.RNE: - raw_result_grs |= sum_grs - with pyrtl.otherwise: # sub - raw_result_exponent |= sub_exponent - raw_result_mantissa |= sub_mantissa - if rounding_mode == RoundingMode.RNE: - raw_result_grs |= sub_grs + # WireVectors for the final result after handling special cases. + final_result_sign = pyrtl.WireVector(bitwidth=1) + final_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + final_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) - if rounding_mode == RoundingMode.RNE: - ( - raw_result_rounded_exponent, - raw_result_rounded_mantissa, - rounding_exponent_incremented, - ) = AddSubHelper._round( - num_mant_bits, - num_exp_bits, - raw_result_exponent, - raw_result_mantissa, - raw_result_grs, + # Handle special cases. + with pyrtl.conditional_assignment: + # If either operand is NaN, or if both operands are infinities with + # opposite signs, the result is NaN. 
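+        # (For example, (+inf) + (-inf) has no meaningful value, so IEEE 754 treats
+        # it as an invalid operation and the result is NaN.)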
+ with ( + smaller_operand_nan + | larger_operand_nan + | ( + smaller_operand_inf + & larger_operand_inf + & (larger_operand_sign != smaller_operand_sign) ) + ): + final_result_sign |= larger_operand_sign + make_nan(fp_type_props, final_result_exponent, final_result_mantissa) - smaller_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_smaller) - larger_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_larger) - smaller_operand_inf = FloatUtils.is_inf(fp_type_props, operand_smaller) - larger_operand_inf = FloatUtils.is_inf(fp_type_props, operand_larger) - smaller_operand_zero = FloatUtils.is_zero(fp_type_props, operand_smaller) - larger_operand_zero = FloatUtils.is_zero(fp_type_props, operand_larger) + # If either operand is infinity, result is infinity with that sign. + with smaller_operand_inf: + final_result_sign |= larger_operand_sign + make_inf(fp_type_props, final_result_exponent, final_result_mantissa) + with larger_operand_inf: + final_result_sign |= larger_operand_sign + make_inf(fp_type_props, final_result_exponent, final_result_mantissa) - # WireVectors for the final result after handling special cases - final_result_sign = pyrtl.WireVector(bitwidth=1) - final_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) - final_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + # If operands are equal in magnitude but opposite in sign, the result is +0. + with ( + (smaller_operand_mantissa == larger_operand_mantissa) + & (smaller_operand_exponent == larger_operand_exponent) + & (larger_operand_sign != smaller_operand_sign) + ): + final_result_sign |= 0 + make_zero(final_result_exponent, final_result_mantissa) - # handle special cases - with pyrtl.conditional_assignment: - # if either operand is NaN or both operands are infinity of opposite signs, - # the result is NaN - with ( - smaller_operand_nan - | larger_operand_nan - | ( - smaller_operand_inf - & larger_operand_inf - & (larger_operand_sign != smaller_operand_sign) - ) - ): - final_result_sign |= larger_operand_sign - FloatUtils.make_output_NaN( - fp_type_props, final_result_exponent, final_result_mantissa - ) - # infinities - with smaller_operand_inf: - final_result_sign |= larger_operand_sign - FloatUtils.make_output_inf( - fp_type_props, final_result_exponent, final_result_mantissa - ) - with larger_operand_inf: - final_result_sign |= larger_operand_sign - FloatUtils.make_output_inf( - fp_type_props, final_result_exponent, final_result_mantissa - ) - # +num + -num = +0 - with ( - (smaller_operand_mantissa == larger_operand_mantissa) - & (smaller_operand_exponent == larger_operand_exponent) - & (larger_operand_sign != smaller_operand_sign) - ): - final_result_sign |= 0 - FloatUtils.make_output_zero( - final_result_exponent, final_result_mantissa - ) - with smaller_operand_zero: - final_result_sign |= larger_operand_sign - final_result_mantissa |= larger_operand_mantissa - final_result_exponent |= larger_operand_exponent - with larger_operand_zero: - final_result_sign |= smaller_operand_sign - final_result_mantissa |= smaller_operand_mantissa - final_result_exponent |= smaller_operand_exponent - # overflow and underflow - initial_larger_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2) + # If either operand is zero, the result is the other operand. 
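+        # Because magnitudes were compared above, a lone zero operand always ends up
+        # in operand_smaller; the larger_operand_zero branch is only reached when
+        # both operands are zero.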
+ with smaller_operand_zero: + final_result_sign |= larger_operand_sign + final_result_mantissa |= larger_operand_mantissa + final_result_exponent |= larger_operand_exponent + with larger_operand_zero: + final_result_sign |= smaller_operand_sign + final_result_mantissa |= smaller_operand_mantissa + final_result_exponent |= smaller_operand_exponent + + # Check for overflow on addition. + # We check for overflow by calculating the max value of the larger + # operand's exponent. This value can vary depending on the operands. + # If there was a carry out from the addition, the result exponent is + # incremented by 1. Additionally, if rounding causes the exponent to + # increment, we need to account for that as well. Therefore, we + # subtract these increments from the absolute maximum exponent, which + # is one less than the all-1s exponent (reserved for infinity/NaN). + initial_larger_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2) + if rounding_mode == RoundingMode.RNE: + larger_exponent_max_value = ( + initial_larger_exponent_max_value + - sum_carry + - rounding_exponent_incremented + ) + else: + larger_exponent_max_value = initial_larger_exponent_max_value - sum_carry + # Checks if an addition was performed and the result overflowed. + with (smaller_operand_sign == larger_operand_sign) & ( + larger_operand_exponent > larger_exponent_max_value + ): + final_result_sign |= larger_operand_sign if rounding_mode == RoundingMode.RNE: - larger_exponent_max_value = ( - initial_larger_exponent_max_value - - sum_carry - - rounding_exponent_incremented - ) + make_inf(fp_type_props, final_result_exponent, final_result_mantissa) else: - larger_exponent_max_value = ( - initial_larger_exponent_max_value - sum_carry + make_largest_finite_number( + fp_type_props, final_result_exponent, final_result_mantissa ) - initial_larger_exponent_min_value = pyrtl.Const(1) + + # Check for underflow on subtraction. + # We check for underflow by computing the min value of the larger + # operand's exponent. As with overflow, this value can vary depending + # on the operands. We subtract the number of leading zeros from the + # larger exponent to obtain the subtraction exponent. Additionally, + # if rounding causes the exponent to increment, we need to account + # for that. Therefore, we add the number of leading zeros and + # subtract the rounding increment from the absolute minimum exponent, + # which is one greater than the all-0s exponent (reserved for + # zero and denormals). + initial_larger_exponent_min_value = pyrtl.Const(1) + if rounding_mode == RoundingMode.RNE: + larger_exponent_min_value = ( + initial_larger_exponent_min_value + + num_leading_zeros + - rounding_exponent_incremented + ) + else: + larger_exponent_min_value = ( + initial_larger_exponent_min_value + num_leading_zeros + ) + # Checks if a subtraction was performed and the result underflowed. 
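+        # Denormal results are not produced: anything below the smallest normal
+        # magnitude is flushed to zero (FTZ), mirroring the DAZ treatment of the
+        # inputs.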
+ with (smaller_operand_sign != larger_operand_sign) & ( + larger_operand_exponent < larger_exponent_min_value + ): + final_result_sign |= larger_operand_sign + make_zero(final_result_exponent, final_result_mantissa) + with pyrtl.otherwise: + final_result_sign |= larger_operand_sign if rounding_mode == RoundingMode.RNE: - larger_exponent_min_value = ( - initial_larger_exponent_min_value - + num_leading_zeros - - rounding_exponent_incremented - ) + final_result_exponent |= raw_result_rounded_exponent + final_result_mantissa |= raw_result_rounded_mantissa else: - larger_exponent_min_value = ( - initial_larger_exponent_min_value + num_leading_zeros - ) - with (smaller_operand_sign == larger_operand_sign) & ( - larger_operand_exponent > larger_exponent_max_value - ): # detect overflow on addition - final_result_sign |= larger_operand_sign - if rounding_mode == RoundingMode.RNE: - FloatUtils.make_output_inf( - fp_type_props, final_result_exponent, final_result_mantissa - ) - else: - FloatUtils.make_output_largest_finite_number( - fp_type_props, final_result_exponent, final_result_mantissa - ) - with (smaller_operand_sign != larger_operand_sign) & ( - larger_operand_exponent < larger_exponent_min_value - ): # detect underflow on subtraction - final_result_sign |= larger_operand_sign - FloatUtils.make_output_zero( - final_result_exponent, final_result_mantissa - ) - with pyrtl.otherwise: - final_result_sign |= larger_operand_sign - if rounding_mode == RoundingMode.RNE: - final_result_exponent |= raw_result_rounded_exponent - final_result_mantissa |= raw_result_rounded_mantissa - else: - final_result_exponent |= raw_result_exponent - final_result_mantissa |= raw_result_mantissa - - return pyrtl.concat( - final_result_sign, final_result_exponent, final_result_mantissa - ) + final_result_exponent |= raw_result_exponent + final_result_mantissa |= raw_result_mantissa - @staticmethod - def sub( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, - ) -> pyrtl.WireVector: - num_exp_bits = config.fp_type_properties.num_exponent_bits - num_mant_bits = config.fp_type_properties.num_mantissa_bits - operand_b_negated = operand_b ^ pyrtl.concat( - pyrtl.Const(1, bitwidth=1), - pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits), - ) - return AddSubHelper.add(config, operand_a, operand_b_negated) - - @staticmethod - def _add_operands( - larger_operand_exponent: pyrtl.WireVector, - smaller_mantissa_shifted_grs: pyrtl.WireVector, - larger_mantissa_extended: pyrtl.WireVector, - ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: - sum_mantissa_grs = pyrtl.WireVector() - sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs - sum_carry = sum_mantissa_grs[-1] - sum_mantissa = pyrtl.select( - sum_carry, sum_mantissa_grs[4:], sum_mantissa_grs[3:-1] - ) - sum_grs = pyrtl.select( - sum_carry, - pyrtl.concat(sum_mantissa_grs[2:4], sum_mantissa_grs[:2] != 0), - sum_mantissa_grs[:3], - ) - sum_exponent = pyrtl.select( - sum_carry, larger_operand_exponent + 1, larger_operand_exponent - ) - return sum_exponent, sum_mantissa, sum_grs, sum_carry - - @staticmethod - def _sub_operands( - num_mant_bits: int, - larger_operand_exponent: pyrtl.WireVector, - smaller_mantissa_shifted_grs: pyrtl.WireVector, - larger_mantissa_extended: pyrtl.WireVector, - ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: - def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int): - out = pyrtl.WireVector( - 
bitwidth=pyrtl.infer_val_and_bitwidth(length - 1).bitwidth - ) - with pyrtl.conditional_assignment: - for i in range(wire.bitwidth - 1, wire.bitwidth - length - 1, -1): - with wire[i]: - out |= wire.bitwidth - i - 1 - return out - - sub_mantissa_grs = pyrtl.WireVector(bitwidth=num_mant_bits + 4) - sub_mantissa_grs <<= larger_mantissa_extended - smaller_mantissa_shifted_grs - num_leading_zeros = leading_zero_priority_encoder( - sub_mantissa_grs, num_mant_bits + 1 - ) - sub_mantissa_grs_shifted = pyrtl.shift_left_logical( - sub_mantissa_grs, num_leading_zeros + return pyrtl.concat(final_result_sign, final_result_exponent, final_result_mantissa) + + +def sub( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, +) -> pyrtl.WireVector: + """ + Performs floating point subtraction of two WireVectors. + + :param config: Configuration for the floating point type and rounding mode. + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the subtraction as a WireVector. + """ + num_exp_bits = config.fp_type_properties.num_exponent_bits + num_mant_bits = config.fp_type_properties.num_mantissa_bits + operand_b_negated = operand_b ^ pyrtl.concat( + pyrtl.Const(1, bitwidth=1), + pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits), + ) + return add(config, operand_a, operand_b_negated) + + +def _add_operands( + larger_operand_exponent: pyrtl.WireVector, + smaller_mantissa_shifted_grs: pyrtl.WireVector, + larger_mantissa_extended: pyrtl.WireVector, +) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: + """ + Helper function for performing addition of two floating point mantissas. + + :param larger_operand_exponent: Exponent of the larger operand. + :param smaller_mantissa_shifted_grs: Mantissa of the smaller operand + shifted to align with the larger operand and concatenated with GRS. + :param larger_mantissa_extended: Larger mantissa with three zeros. + :return: Tuple of (exponent, mantissa, GRS bits, carry bit). + """ + sum_mantissa_grs = pyrtl.WireVector() + sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs + sum_carry = sum_mantissa_grs[-1] + # Pick the correct bits for the mantissa and GRS based on carry out. + sum_mantissa = pyrtl.select(sum_carry, sum_mantissa_grs[4:], sum_mantissa_grs[3:-1]) + sum_grs = pyrtl.select( + sum_carry, + pyrtl.concat(sum_mantissa_grs[2:4], sum_mantissa_grs[:2] != 0), + sum_mantissa_grs[:3], + ) + # Increment the exponent if there was a carry out. + sum_exponent = pyrtl.select( + sum_carry, larger_operand_exponent + 1, larger_operand_exponent + ) + return sum_exponent, sum_mantissa, sum_grs, sum_carry + + +def _sub_operands( + num_mant_bits: int, + larger_operand_exponent: pyrtl.WireVector, + smaller_mantissa_shifted_grs: pyrtl.WireVector, + larger_mantissa_extended: pyrtl.WireVector, +) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: + """ + Helper function for performing subtraction of two floating point mantissas. + + :param num_mant_bits: Number of mantissa bits. + :param larger_operand_exponent: Exponent of the larger operand. + :param smaller_mantissa_shifted_grs: Mantissa of the smaller operand + shifted to align with the larger operand and concatenated with GRS. + :param larger_mantissa_extended: Larger mantissa with three zeros. + :return: Tuple of (exponent, mantissa, GRS bits, num leading zeros). 
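+
+    Note: the returned GRS (guard, round, sticky) bits keep the information
+    shifted out of the mantissa during alignment and normalization so that
+    the caller can still round the result correctly.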
+ """ + + # Priority encoder that counts the number of leading zeros in a WireVector. + def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int): + out = pyrtl.WireVector( + bitwidth=pyrtl.infer_val_and_bitwidth(length - 1).bitwidth ) - sub_mantissa = sub_mantissa_grs_shifted[3:] - sub_grs = sub_mantissa_grs_shifted[:3] - sub_exponent = larger_operand_exponent - num_leading_zeros - return sub_exponent, sub_mantissa, sub_grs, num_leading_zeros - - @staticmethod - def _round( - num_mant_bits: int, - num_exp_bits: int, - raw_result_exponent: pyrtl.WireVector, - raw_result_mantissa: pyrtl.WireVector, - raw_result_grs: pyrtl.WireVector, - ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: - last = raw_result_mantissa[0] - guard = raw_result_grs[2] - round = raw_result_grs[1] - sticky = raw_result_grs[0] - round_up = guard & (last | round | sticky) - raw_result_rounded_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) - raw_result_rounded_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) - rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1) with pyrtl.conditional_assignment: - with round_up: - with raw_result_mantissa == (1 << num_mant_bits) - 1: - raw_result_rounded_mantissa |= 0 - raw_result_rounded_exponent |= raw_result_exponent + 1 - rounding_exponent_incremented |= 1 - with pyrtl.otherwise: - raw_result_rounded_mantissa |= raw_result_mantissa + 1 - raw_result_rounded_exponent |= raw_result_exponent - rounding_exponent_incremented |= 0 + for i in range(wire.bitwidth - 1, wire.bitwidth - length - 1, -1): + with wire[i]: + out |= wire.bitwidth - i - 1 + return out + + sub_mantissa_grs = pyrtl.WireVector(bitwidth=num_mant_bits + 4) + sub_mantissa_grs <<= larger_mantissa_extended - smaller_mantissa_shifted_grs + # Normalize result by shifting left until leading 1 is in position. + num_leading_zeros = leading_zero_priority_encoder( + sub_mantissa_grs, num_mant_bits + 1 + ) + sub_mantissa_grs_shifted = pyrtl.shift_left_logical( + sub_mantissa_grs, num_leading_zeros + ) + sub_mantissa = sub_mantissa_grs_shifted[3:] + sub_grs = sub_mantissa_grs_shifted[:3] + # Adjust the exponent by subtracting the number of leading zeros. + sub_exponent = larger_operand_exponent - num_leading_zeros + return sub_exponent, sub_mantissa, sub_grs, num_leading_zeros + + +def _round( + num_mant_bits: int, + num_exp_bits: int, + raw_result_exponent: pyrtl.WireVector, + raw_result_mantissa: pyrtl.WireVector, + raw_result_grs: pyrtl.WireVector, +) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: + """ + Round the floating point result using round to nearest, ties to even (RNE). + + Uses the GRS bits to determine if the result needs to be rounded up. + + :param num_mant_bits: Number of mantissa bits. + :param num_exp_bits: Number of exponent bits. + :param raw_result_exponent: Exponent of the raw result before rounding. + :param raw_result_mantissa: Mantissa of the raw result before rounding. + :param raw_result_grs: GRS bits of the raw result before rounding. + :return: Tuple of (rounded exponent, rounded mantissa, exponent + incremented flag). + """ + last = raw_result_mantissa[0] + guard = raw_result_grs[2] + round = raw_result_grs[1] + sticky = raw_result_grs[0] + # If guard bit is not set, number is closer to smaller value: no round up. + # If guard bit is set and round or sticky is set, round up. + # If guard bit is set but round and sticky are not set, value is exactly + # halfway. 
Following round-to-nearest ties-to-even, round up if last bit + # of mantissa is 1 (to make it even); otherwise do not round up. + round_up = guard & (last | round | sticky) + raw_result_rounded_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + raw_result_rounded_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + # Whether exponent was incremented due to rounding (for overflow check). + rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1) + with pyrtl.conditional_assignment: + with round_up: + # If rounding causes a mantissa overflow, we need to increment the exponent. + with raw_result_mantissa == (1 << num_mant_bits) - 1: + raw_result_rounded_mantissa |= 0 + raw_result_rounded_exponent |= raw_result_exponent + 1 + rounding_exponent_incremented |= 1 with pyrtl.otherwise: - raw_result_rounded_mantissa |= raw_result_mantissa + raw_result_rounded_mantissa |= raw_result_mantissa + 1 raw_result_rounded_exponent |= raw_result_exponent rounding_exponent_incremented |= 0 - return ( - raw_result_rounded_exponent, - raw_result_rounded_mantissa, - rounding_exponent_incremented, - ) + with pyrtl.otherwise: + raw_result_rounded_mantissa |= raw_result_mantissa + raw_result_rounded_exponent |= raw_result_exponent + rounding_exponent_incremented |= 0 + return ( + raw_result_rounded_exponent, + raw_result_rounded_mantissa, + rounding_exponent_incremented, + ) diff --git a/pyrtl/rtllib/pyrtlfloat/_float_utills.py b/pyrtl/rtllib/pyrtlfloat/_float_utills.py index 0ae58329..2f3bda1d 100644 --- a/pyrtl/rtllib/pyrtlfloat/_float_utills.py +++ b/pyrtl/rtllib/pyrtlfloat/_float_utills.py @@ -3,102 +3,171 @@ from ._types import FPTypeProperties -class FloatUtils: - @staticmethod - def get_sign(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: - return wire[fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits] - - @staticmethod - def get_exponent( - fp_prop: FPTypeProperties, wire: pyrtl.WireVector - ) -> pyrtl.WireVector: - return wire[ - fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits - + fp_prop.num_exponent_bits - ] - - @staticmethod - def get_mantissa( - fp_prop: FPTypeProperties, wire: pyrtl.WireVector - ) -> pyrtl.WireVector: - return wire[: fp_prop.num_mantissa_bits] - - @staticmethod - def is_zero(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: - return (FloatUtils.get_mantissa(fp_prop, wire) == 0) & ( - FloatUtils.get_exponent(fp_prop, wire) == 0 - ) - - @staticmethod - def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: - return (FloatUtils.get_mantissa(fp_prop, wire) == 0) & ( - FloatUtils.get_exponent(fp_prop, wire) - == (1 << fp_prop.num_exponent_bits) - 1 - ) - - @staticmethod - def is_denormalized( - fp_prop: FPTypeProperties, wire: pyrtl.WireVector - ) -> pyrtl.WireVector: - return (FloatUtils.get_mantissa(fp_prop, wire) != 0) & ( - FloatUtils.get_exponent(fp_prop, wire) == 0 - ) - - @staticmethod - def is_NaN(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: - return (FloatUtils.get_mantissa(fp_prop, wire) != 0) & ( - FloatUtils.get_exponent(fp_prop, wire) - == (1 << fp_prop.num_exponent_bits) - 1 - ) - - @staticmethod - def make_denormals_zero( - fp_prop: FPTypeProperties, wire: pyrtl.WireVector - ) -> pyrtl.WireVector: - out = pyrtl.WireVector( - bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1 - ) - with pyrtl.conditional_assignment: - with FloatUtils.get_exponent(fp_prop, wire) == 0: - out |= pyrtl.concat( - FloatUtils.get_sign(fp_prop, wire), - 
FloatUtils.get_exponent(fp_prop, wire), - pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits), - ) - with pyrtl.otherwise: - out |= wire - return out - - @staticmethod - def make_output_inf( - fp_prop: FPTypeProperties, - exponent: pyrtl.WireVector, - mantissa: pyrtl.WireVector, - ) -> None: - exponent |= (1 << fp_prop.num_exponent_bits) - 1 - mantissa |= 0 - - @staticmethod - def make_output_NaN( - fp_prop: FPTypeProperties, - exponent: pyrtl.WireVector, - mantissa: pyrtl.WireVector, - ) -> None: - exponent |= (1 << fp_prop.num_exponent_bits) - 1 - mantissa |= 1 << (fp_prop.num_mantissa_bits - 1) - - @staticmethod - def make_output_zero( - exponent: pyrtl.WireVector, mantissa: pyrtl.WireVector - ) -> None: - exponent |= 0 - mantissa |= 0 - - @staticmethod - def make_output_largest_finite_number( - fp_prop: FPTypeProperties, - exponent: pyrtl.WireVector, - mantissa: pyrtl.WireVector, - ) -> None: - exponent |= (1 << fp_prop.num_exponent_bits) - 2 - mantissa |= (1 << fp_prop.num_mantissa_bits) - 1 +def get_sign(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + """ + Returns the sign bit of floating point number. + + :param fp_prop: Floating point type properties. + :param wire: WireVector holding the floating point number. + :return: WireVector holding the sign bit. + """ + return wire[fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits] + + +def get_exponent(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + """ + Returns the exponent bits of floating point number. + + :param fp_prop: Floating point type properties. + :param wire: WireVector holding the floating point number. + :return: WireVector holding the exponent bits. + """ + return wire[ + fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits + + fp_prop.num_exponent_bits + ] + + +def get_mantissa(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + """ + Returns the mantissa bits of floating point number. + + :param fp_prop: Floating point type properties. + :param wire: WireVector holding the floating point number. + :return: WireVector holding the mantissa bits. + """ + return wire[: fp_prop.num_mantissa_bits] + + +def is_zero(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + """ + Returns whether the floating point number is zero. + + :param fp_prop: Floating point type properties. + :param wire: WireVector holding the floating point number. + :return: 1-bit WireVector indicating whether the number is zero. + """ + return (get_mantissa(fp_prop, wire) == 0) & (get_exponent(fp_prop, wire) == 0) + + +def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + """ + Returns whether the floating point number is infinity. + + :param fp_prop: Floating point type properties. + :param wire: WireVector holding the floating point number. + :return: 1-bit WireVector indicating whether the number is infinity. + """ + return (get_mantissa(fp_prop, wire) == 0) & ( + get_exponent(fp_prop, wire) == (1 << fp_prop.num_exponent_bits) - 1 + ) + + +def is_denormalized( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector +) -> pyrtl.WireVector: + """ + Returns whether the floating point number is denormalized. + + :param fp_prop: Floating point type properties. + :param wire: WireVector holding the floating point number. + :return: 1-bit WireVector indicating whether the number is denormalized. 
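+
+    Note: this library flushes denormalized operands to zero before
+    operating on them (see make_denormals_zero).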
+ """ + return (get_mantissa(fp_prop, wire) != 0) & (get_exponent(fp_prop, wire) == 0) + + +def is_nan(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + """ + Returns whether the floating point number is NaN. + + :param fp_prop: Floating point type properties. + :param wire: WireVector holding the floating point number. + :return: 1-bit WireVector indicating whether the number is NaN. + """ + return (get_mantissa(fp_prop, wire) != 0) & ( + get_exponent(fp_prop, wire) == (1 << fp_prop.num_exponent_bits) - 1 + ) + + +def make_denormals_zero( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector +) -> pyrtl.WireVector: + """ + Returns zero if denormalized, else original number. + + :param fp_prop: Floating point type properties. + :param wire: WireVector holding the floating point number. + :return: WireVector holding the resulting floating point number. + """ + out = pyrtl.WireVector( + bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1 + ) + with pyrtl.conditional_assignment: + with get_exponent(fp_prop, wire) == 0: + out |= pyrtl.concat( + get_sign(fp_prop, wire), + get_exponent(fp_prop, wire), + pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits), + ) + with pyrtl.otherwise: + out |= wire + return out + + +def make_inf( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, +) -> None: + """ + Sets the exponent and mantissa to represent infinity. + + :param fp_prop: Floating point type properties. + :param exponent: WireVector to set the exponent bits. + :param mantissa: WireVector to set the mantissa bits. + """ + exponent |= (1 << fp_prop.num_exponent_bits) - 1 + mantissa |= 0 + + +def make_nan( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, +) -> None: + """ + Sets the exponent and mantissa to represent NaN. + + :param fp_prop: Floating point type properties. + :param exponent: WireVector to set the exponent bits. + :param mantissa: WireVector to set the mantissa bits. + """ + exponent |= (1 << fp_prop.num_exponent_bits) - 1 + mantissa |= 1 << (fp_prop.num_mantissa_bits - 1) + + +def make_zero(exponent: pyrtl.WireVector, mantissa: pyrtl.WireVector) -> None: + """ + Sets the exponent and mantissa to represent zero. + + :param exponent: WireVector to set the exponent bits. + :param mantissa: WireVector to set the mantissa bits. + """ + exponent |= 0 + mantissa |= 0 + + +def make_largest_finite_number( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, +) -> None: + """ + Sets the exponent and mantissa to represent the largest finite number. + + :param fp_prop: Floating point type properties. + :param exponent: WireVector to set the exponent bits. + :param mantissa: WireVector to set the mantissa bits. 
+ """ + exponent |= (1 << fp_prop.num_exponent_bits) - 2 + mantissa |= (1 << fp_prop.num_mantissa_bits) - 1 diff --git a/pyrtl/rtllib/pyrtlfloat/_multiplication.py b/pyrtl/rtllib/pyrtlfloat/_multiplication.py index a0de7d25..b7344423 100644 --- a/pyrtl/rtllib/pyrtlfloat/_multiplication.py +++ b/pyrtl/rtllib/pyrtlfloat/_multiplication.py @@ -1,161 +1,197 @@ import pyrtl -from ._float_utills import FloatUtils +from ._float_utills import ( + get_exponent, + get_mantissa, + get_sign, + is_denormalized, + is_inf, + is_nan, + is_zero, + make_denormals_zero, + make_inf, + make_largest_finite_number, + make_nan, + make_zero, +) from ._types import PyrtlFloatConfig, RoundingMode -class MultiplicationHelper: - @staticmethod - def multiply( - config: PyrtlFloatConfig, - operand_a: pyrtl.WireVector, - operand_b: pyrtl.WireVector, - ) -> pyrtl.WireVector: - fp_type_props = config.fp_type_properties - rounding_mode = config.rounding_mode - num_exp_bits = fp_type_props.num_exponent_bits - num_mant_bits = fp_type_props.num_mantissa_bits - - operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a) - operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) - a_sign = FloatUtils.get_sign(fp_type_props, operand_a_daz) - b_sign = FloatUtils.get_sign(fp_type_props, operand_b_daz) - a_exponent = FloatUtils.get_exponent(fp_type_props, operand_a_daz) - b_exponent = FloatUtils.get_exponent(fp_type_props, operand_b_daz) - - exponent_bias = 2 ** (fp_type_props.num_exponent_bits - 1) - 1 - - result_sign = a_sign ^ b_sign - operand_exponent_sums = a_exponent + b_exponent - product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias) - - a_mantissa = pyrtl.concat( - pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_a_daz) - ) - b_mantissa = pyrtl.concat( - pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_b_daz) - ) - product_mantissa = a_mantissa * b_mantissa - - normalized_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) - normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) - - need_to_normalize = product_mantissa[-1] - - if rounding_mode == RoundingMode.RNE: - guard = pyrtl.WireVector(bitwidth=1) - sticky = pyrtl.WireVector(bitwidth=1) - last = pyrtl.WireVector(bitwidth=1) - +def mul( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, +) -> pyrtl.WireVector: + """ + Performs floating point multiplication of two WireVectors. + + :param config: Configuration for the floating point type and rounding mode. + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the multiplication as a WireVector. + """ + fp_type_props = config.fp_type_properties + rounding_mode = config.rounding_mode + num_exp_bits = fp_type_props.num_exponent_bits + num_mant_bits = fp_type_props.num_mantissa_bits + + # Denormalized numbers are not supported, so we flush them to zero. + operands = (operand_a, operand_b) + operands_daz = tuple(make_denormals_zero(fp_type_props, op) for op in operands) + + # Extract the sign and exponent of both operands. 
+ signs = tuple(get_sign(fp_type_props, op) for op in operands_daz) + exponents = tuple(get_exponent(fp_type_props, op) for op in operands_daz) + + result_sign = signs[0] ^ signs[1] + + # IEEE-754 floating point numbers have a bias: + # https://en.wikipedia.org/wiki/Exponent_bias + # real_exponent = stored_exponent - bias, so stored_exponent = real + bias + # Therefore, stored_exponent_product = real_exponent_product + bias + # = (real_exponent_a + real_exponent_b) + bias + # = (stored_exponent_a - bias + stored_exponent_b - bias) + bias + # = stored_exponent_a + stored_exponent_b - bias + operand_exponent_sums = exponents[0] + exponents[1] + exponent_bias = 2 ** (fp_type_props.num_exponent_bits - 1) - 1 + product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias) + + # Extract the mantissa of both operands and add the implicit leading 1. + mantissas = tuple( + pyrtl.concat(pyrtl.Const(1), get_mantissa(fp_type_props, op)) + for op in operands_daz + ) + product_mantissa = mantissas[0] * mantissas[1] + + normalized_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) + normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + + # We need to normalize (shift right) if the leading bit is 1. + # https://numeral-systems.com/ieee-754-multiply/ + need_to_normalize = product_mantissa[-1] + + if rounding_mode == RoundingMode.RNE: + guard = pyrtl.WireVector(bitwidth=1) + sticky = pyrtl.WireVector(bitwidth=1) + last = pyrtl.WireVector(bitwidth=1) # Last bit of the mantissa before rounding. + + # Assign the normalized mantissa, exponent, guard, sticky, and last bits + # based on whether normalization is needed. + with pyrtl.conditional_assignment: + with need_to_normalize: + normalized_product_mantissa |= product_mantissa[-num_mant_bits - 1 :] + normalized_product_exponent |= product_exponent + 1 + if rounding_mode == RoundingMode.RNE: + guard |= product_mantissa[-num_mant_bits - 2] + sticky |= product_mantissa[: -num_mant_bits - 2] != 0 + last |= product_mantissa[-num_mant_bits - 1] + with pyrtl.otherwise: + normalized_product_mantissa |= product_mantissa[-num_mant_bits - 2 : -1] + normalized_product_exponent |= product_exponent + if rounding_mode == RoundingMode.RNE: + guard |= product_mantissa[-num_mant_bits - 3] + sticky |= product_mantissa[: -num_mant_bits - 3] != 0 + last |= product_mantissa[-num_mant_bits - 2] + + if rounding_mode == RoundingMode.RNE: + rounded_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + rounded_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) + # Whether exponent was incremented due to rounding (for overflow check). + exponent_incremented = pyrtl.WireVector(bitwidth=1) + # If guard bit is not set, number is closer to smaller value: no round. + # If guard and sticky are set, round up. + # If guard is set but sticky is not, value is exactly halfway. + # Following round-to-nearest ties-to-even, round up if last bit is 1. 
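+    # For example, last=1, guard=1, sticky=0 is an exact halfway case, so
+    # RNE rounds up to make the last mantissa bit even, while last=0,
+    # guard=1, sticky=0 truncates instead.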
+ round_up = guard & (last | sticky) with pyrtl.conditional_assignment: - with need_to_normalize: - normalized_product_mantissa |= product_mantissa[-num_mant_bits - 1 :] - normalized_product_exponent |= product_exponent + 1 - if rounding_mode == RoundingMode.RNE: - guard |= product_mantissa[-num_mant_bits - 2] - sticky |= product_mantissa[: -num_mant_bits - 2] != 0 - last |= product_mantissa[-num_mant_bits - 1] - with pyrtl.otherwise: - normalized_product_mantissa |= product_mantissa[-num_mant_bits - 2 : -1] - normalized_product_exponent |= product_exponent - if rounding_mode == RoundingMode.RNE: - guard |= product_mantissa[-num_mant_bits - 3] - sticky |= product_mantissa[: -num_mant_bits - 3] != 0 - last |= product_mantissa[-num_mant_bits - 2] - - if rounding_mode == RoundingMode.RNE: - rounded_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) - rounded_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) - exponent_incremented = pyrtl.WireVector(bitwidth=1) - with pyrtl.conditional_assignment: - with guard & (last | sticky): - with normalized_product_mantissa == (1 << num_mant_bits) - 1: - rounded_product_mantissa |= 0 - rounded_product_exponent |= normalized_product_exponent + 1 - exponent_incremented |= 1 - with pyrtl.otherwise: - rounded_product_mantissa |= normalized_product_mantissa + 1 - rounded_product_exponent |= normalized_product_exponent - exponent_incremented |= 0 + with round_up: + with normalized_product_mantissa == (1 << num_mant_bits) - 1: + rounded_product_mantissa |= 0 + rounded_product_exponent |= normalized_product_exponent + 1 + exponent_incremented |= 1 with pyrtl.otherwise: - rounded_product_mantissa |= normalized_product_mantissa + rounded_product_mantissa |= normalized_product_mantissa + 1 rounded_product_exponent |= normalized_product_exponent exponent_incremented |= 0 - - result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) - result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) - - operand_a_nan = FloatUtils.is_NaN(fp_type_props, operand_a_daz) - operand_b_nan = FloatUtils.is_NaN(fp_type_props, operand_b_daz) - operand_a_inf = FloatUtils.is_inf(fp_type_props, operand_a_daz) - operand_b_inf = FloatUtils.is_inf(fp_type_props, operand_b_daz) - operand_a_zero = FloatUtils.is_zero(fp_type_props, operand_a_daz) - operand_b_zero = FloatUtils.is_zero(fp_type_props, operand_b_daz) - operand_a_denormalized = FloatUtils.is_denormalized( - fp_type_props, operand_a_daz + with pyrtl.otherwise: + rounded_product_mantissa |= normalized_product_mantissa + rounded_product_exponent |= normalized_product_exponent + exponent_incremented |= 0 + + result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + + # Check whether operands are special: NaN, infinity, zero, or denormalized. + operand_nans = tuple(is_nan(fp_type_props, op) for op in operands_daz) + operand_infs = tuple(is_inf(fp_type_props, op) for op in operands_daz) + operand_zeros = tuple(is_zero(fp_type_props, op) for op in operands_daz) + operand_denorms = tuple(is_denormalized(fp_type_props, op) for op in operands_daz) + + # We check for overflow and underflow by computing max and min exponent + # values of the sum of operands before rounding and normalization. + # These values depend on the operands. If the result requires + # normalization, the exponent is incremented by 1. Additionally, rounding + # may further increase the exponent. 
Therefore, we subtract these + # potential increments from the absolute maximum exponent, which is one + # less than the all-1s exponent (reserved for inf/NaN) plus bias. + # Similarly, we subtract these increments from the absolute minimum + # exponent, which is 1 plus the exponent bias. + sum_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2 + exponent_bias) + sum_exponent_min_value = pyrtl.Const(1 + exponent_bias) + if rounding_mode == RoundingMode.RNE: + exponent_max_value = ( + sum_exponent_max_value - need_to_normalize - exponent_incremented ) - operand_b_denormalized = FloatUtils.is_denormalized( - fp_type_props, operand_b_daz + exponent_min_value = ( + sum_exponent_min_value - need_to_normalize - exponent_incremented ) - - # Overflow and underflow checks (only for normal cases) - sum_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2 + exponent_bias) - sum_exponent_min_value = pyrtl.Const(1 + exponent_bias) - if rounding_mode == RoundingMode.RNE: - exponent_max_value = ( - sum_exponent_max_value - need_to_normalize - exponent_incremented - ) - exponent_min_value = ( - sum_exponent_min_value - need_to_normalize - exponent_incremented - ) - else: - exponent_max_value = sum_exponent_max_value - need_to_normalize - exponent_min_value = sum_exponent_min_value - need_to_normalize - - if rounding_mode == RoundingMode.RNE: - raw_result_exponent = rounded_product_exponent[0:num_exp_bits] - raw_result_mantissa = rounded_product_mantissa - else: - raw_result_exponent = normalized_product_exponent[0:num_exp_bits] - raw_result_mantissa = normalized_product_mantissa - - with pyrtl.conditional_assignment: - # nan - with ( - operand_a_nan - | operand_b_nan - | (operand_a_inf & operand_b_zero) - | (operand_a_zero & operand_b_inf) - ): - FloatUtils.make_output_NaN( + else: + exponent_max_value = sum_exponent_max_value - need_to_normalize + exponent_min_value = sum_exponent_min_value - need_to_normalize + + # Assign the raw result's exponent and mantissa depending on whether RNE rounding + # is used. The calculated exponent WireVector has an extra bit due to the carry-out + # from addition, so we take only the lower num_exp_bits to remove this extra bit. + if rounding_mode == RoundingMode.RNE: + raw_result_exponent = rounded_product_exponent[:num_exp_bits] + raw_result_mantissa = rounded_product_mantissa + else: + raw_result_exponent = normalized_product_exponent[:num_exp_bits] + raw_result_mantissa = normalized_product_mantissa + + with pyrtl.conditional_assignment: + # If either operand is NaN, or if one operand is infinity and the other is + # zero, the result is NaN. + with ( + operand_nans[0] + | operand_nans[1] + | (operand_infs[0] & operand_zeros[1]) + | (operand_zeros[0] & operand_infs[1]) + ): + make_nan(fp_type_props, result_exponent, result_mantissa) + # If either operand is infinity, the result is infinity. + with operand_infs[0] | operand_infs[1]: + make_inf(fp_type_props, result_exponent, result_mantissa) + # Detect overflow. 
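+        # For example, with Float16 (5 exponent bits, bias 15) the biased
+        # exponent sum may be at most 30 + 15 = 45 before subtracting the
+        # normalization and rounding adjustments above.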
+ with operand_exponent_sums > exponent_max_value: + if rounding_mode == RoundingMode.RNE: + make_inf(fp_type_props, result_exponent, result_mantissa) + else: + make_largest_finite_number( fp_type_props, result_exponent, result_mantissa ) - # infinity - with operand_a_inf | operand_b_inf: - FloatUtils.make_output_inf( - fp_type_props, result_exponent, result_mantissa - ) - # overflow - with operand_exponent_sums > exponent_max_value: - if rounding_mode == RoundingMode.RNE: - FloatUtils.make_output_inf( - fp_type_props, result_exponent, result_mantissa - ) - else: - FloatUtils.make_output_largest_finite_number( - fp_type_props, result_exponent, result_mantissa - ) - # zero or underflow - with ( - operand_a_zero - | operand_b_zero - | (operand_exponent_sums < exponent_min_value) - | operand_a_denormalized - | operand_b_denormalized - ): - FloatUtils.make_output_zero(result_exponent, result_mantissa) - with pyrtl.otherwise: - result_exponent |= raw_result_exponent - result_mantissa |= raw_result_mantissa - - return pyrtl.concat(result_sign, result_exponent, result_mantissa) + # If either operand is zero, if underflow occurred, or if either operand is + # denormalized, the result is zero. + with ( + operand_zeros[0] + | operand_zeros[1] + | (operand_exponent_sums < exponent_min_value) + | operand_denorms[0] + | operand_denorms[1] + ): + make_zero(result_exponent, result_mantissa) + with pyrtl.otherwise: + result_exponent |= raw_result_exponent + result_mantissa |= raw_result_mantissa + + return pyrtl.concat(result_sign, result_exponent, result_mantissa) diff --git a/pyrtl/rtllib/pyrtlfloat/_types.py b/pyrtl/rtllib/pyrtlfloat/_types.py index 15a1c811..a18df565 100644 --- a/pyrtl/rtllib/pyrtlfloat/_types.py +++ b/pyrtl/rtllib/pyrtlfloat/_types.py @@ -3,17 +3,43 @@ class RoundingMode(Enum): - RTZ = 1 # round towards zero (truncate) - RNE = 2 # round to nearest, ties to even (default mode) + """ + Enum representing different rounding modes. + + Attributes: + RTZ (int): Round towards zero (truncate). + RNE (int): Round to nearest, ties to even (default mode). + """ + + RTZ = 1 + RNE = 2 @dataclass(frozen=True) class FPTypeProperties: + """ + Data class representing properties of a floating-point type. + + Attributes: + num_exponent_bits (int): Number of bits used for the exponent. + num_mantissa_bits (int): Number of bits used for the mantissa. + """ + num_exponent_bits: int num_mantissa_bits: int class FloatingPointType(Enum): + """ + Enum representing different floating-point types. + + Attributes: + BFLOAT16 (FPTypeProperties): BFloat16 type properties. + FLOAT16 (FPTypeProperties): Float16 type properties. + FLOAT32 (FPTypeProperties): Float32 type properties. + FLOAT64 (FPTypeProperties): Float64 type properties. + """ + BFLOAT16 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=7) FLOAT16 = FPTypeProperties(num_exponent_bits=5, num_mantissa_bits=10) FLOAT32 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=23) @@ -22,9 +48,14 @@ class FloatingPointType(Enum): @dataclass(frozen=True) class PyrtlFloatConfig: - fp_type_properties: FPTypeProperties - rounding_mode: RoundingMode + """ + Data class representing the configuration for PyrtlFloat operations (floating point + type properties and rounding mode). + Attributes: + fp_type_properties (FPTypeProperties): Properties of the floating-point type. + rounding_mode (RoundingMode): Rounding mode to be used. 
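+
+    Example (Float16 with round-to-nearest-even):
+
+        config = PyrtlFloatConfig(
+            fp_type_properties=FloatingPointType.FLOAT16.value,
+            rounding_mode=RoundingMode.RNE,
+        )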
+ """ -class PyrtlFloatException(Exception): - pass + fp_type_properties: FPTypeProperties + rounding_mode: RoundingMode diff --git a/pyrtl/rtllib/pyrtlfloat/floatoperations.py b/pyrtl/rtllib/pyrtlfloat/floatoperations.py index ef081b0a..c3bdf98b 100644 --- a/pyrtl/rtllib/pyrtlfloat/floatoperations.py +++ b/pyrtl/rtllib/pyrtlfloat/floatoperations.py @@ -1,11 +1,38 @@ import pyrtl -from ._add_sub import AddSubHelper -from ._multiplication import MultiplicationHelper +from ._add_sub import add, sub +from ._multiplication import mul from ._types import FloatingPointType, PyrtlFloatConfig, RoundingMode +def _validate_operand_bitwidths( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, +) -> None: + """Validate that operand bitwidths match the floating point config.""" + fp_props = config.fp_type_properties + expected_bitwidth = fp_props.num_exponent_bits + fp_props.num_mantissa_bits + 1 + if operand_a.bitwidth != expected_bitwidth: + msg = ( + f"operand_a bitwidth {operand_a.bitwidth} does not match expected " + f"bitwidth {expected_bitwidth} for floating point type" + ) + raise pyrtl.PyrtlError(msg) + if operand_b.bitwidth != expected_bitwidth: + msg = ( + f"operand_b bitwidth {operand_b.bitwidth} does not match expected " + f"bitwidth {expected_bitwidth} for floating point type" + ) + raise pyrtl.PyrtlError(msg) + + class FloatOperations: + """ + The rounding mode used for typed floating-point operations. + To change it, set this variable to the desired RoundingMode value. + """ + default_rounding_mode = RoundingMode.RNE @staticmethod @@ -14,7 +41,19 @@ def mul( operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector, ) -> pyrtl.WireVector: - return MultiplicationHelper.multiply(config, operand_a, operand_b) + """ + Performs floating point multiplication of two WireVectors. The bitwidth of + the operands must be num_exponent_bits + num_mantissa_bits + 1, where + num_exponent_bits and num_mantissa_bits are defined in the config. + + :param config: Configuration for the floating point type and rounding mode. + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the multiplication as a WireVector. + :raises PyrtlError: If operand bitwidths don't match config. + """ + _validate_operand_bitwidths(config, operand_a, operand_b) + return mul(config, operand_a, operand_b) @staticmethod def add( @@ -22,7 +61,19 @@ def add( operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector, ) -> pyrtl.WireVector: - return AddSubHelper.add(config, operand_a, operand_b) + """ + Performs floating point addition of two WireVectors. The bitwidth of + the operands must be num_exponent_bits + num_mantissa_bits + 1, where + num_exponent_bits and num_mantissa_bits are defined in the config. + + :param config: Configuration for the floating point type and rounding mode. + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the addition as a WireVector. + :raises PyrtlError: If operand bitwidths don't match config. 
+ """ + _validate_operand_bitwidths(config, operand_a, operand_b) + return add(config, operand_a, operand_b) @staticmethod def sub( @@ -30,7 +81,19 @@ def sub( operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector, ) -> pyrtl.WireVector: - return AddSubHelper.sub(config, operand_a, operand_b) + """ + Performs floating point subtraction of two WireVectors. The bitwidth of + the operands must be num_exponent_bits + num_mantissa_bits + 1, where + num_exponent_bits and num_mantissa_bits are defined in the config. + + :param config: Configuration for the floating point type and rounding mode. + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the subtraction as a WireVector. + :raises PyrtlError: If operand bitwidths don't match config. + """ + _validate_operand_bitwidths(config, operand_a, operand_b) + return sub(config, operand_a, operand_b) class _BaseTypedFloatOperations: @@ -40,18 +103,42 @@ class _BaseTypedFloatOperations: def mul( cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector ) -> pyrtl.WireVector: + """ + Performs floating point multiplication of two WireVectors. The bitwidth of + the operands must match the bitwidth of the floating point type of this class. + + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the multiplication as a WireVector. + """ return FloatOperations.mul(cls._get_config(), operand_a, operand_b) @classmethod def add( cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector ) -> pyrtl.WireVector: + """ + Performs floating point addition of two WireVectors. The bitwidth of + the operands must match the bitwidth of the floating point type of this class. + + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the addition as a WireVector. + """ return FloatOperations.add(cls._get_config(), operand_a, operand_b) @classmethod def sub( cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector ) -> pyrtl.WireVector: + """ + Performs floating point subtraction of two WireVectors. The bitwidth of + the operands must match the bitwidth of the floating point type of this class. + + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the subtraction as a WireVector. + """ return FloatOperations.sub(cls._get_config(), operand_a, operand_b) @classmethod @@ -62,16 +149,32 @@ def _get_config(cls) -> PyrtlFloatConfig: class BFloat16Operations(_BaseTypedFloatOperations): + """ + Operations for BFloat16 floating point type. + """ + _fp_type = FloatingPointType.BFLOAT16 class Float16Operations(_BaseTypedFloatOperations): + """ + Operations for Float16 floating point type. + """ + _fp_type = FloatingPointType.FLOAT16 class Float32Operations(_BaseTypedFloatOperations): + """ + Operations for Float32 floating point type. + """ + _fp_type = FloatingPointType.FLOAT32 class Float64Operations(_BaseTypedFloatOperations): + """ + Operations for Float64 floating point type. 
+ """ + _fp_type = FloatingPointType.FLOAT64 diff --git a/tests/rtllib/pyrtlfloat/test_add_sub.py b/tests/rtllib/pyrtlfloat/test_add_sub.py index 3f006da2..831b80bf 100644 --- a/tests/rtllib/pyrtlfloat/test_add_sub.py +++ b/tests/rtllib/pyrtlfloat/test_add_sub.py @@ -3,23 +3,522 @@ import pyrtl from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode +# IEEE 754 Float16 special values +FLOAT16_POS_ZERO = 0x0000 +FLOAT16_NEG_ZERO = 0x8000 +FLOAT16_POS_INF = 0x7C00 +FLOAT16_NEG_INF = 0xFC00 +FLOAT16_NAN = 0x7E00 # Quiet NaN with mantissa bit set +FLOAT16_ONE = 0x3C00 # 1.0 +FLOAT16_NEG_ONE = 0xBC00 # -1.0 +FLOAT16_TWO = 0x4000 # 2.0 +FLOAT16_THREE = 0x4200 # 3.0 +FLOAT16_HALF = 0x3800 # 0.5 +FLOAT16_QUARTER = 0x3400 # 0.25 +FLOAT16_ONE_POINT_FIVE = 0x3E00 # 1.5 +FLOAT16_ONE_POINT_TWOFIVE = 0x3D00 # 1.25 +FLOAT16_LARGEST_NORMAL = 0x7BFF # Largest normal number (~65504) +FLOAT16_SMALLEST_NORMAL = 0x0400 # Smallest normal number (2^-14) +FLOAT16_DENORMAL = 0x0001 # Smallest denormal + + +def float16_parts(sign, exp, mant): + """Construct Float16 from sign, exponent, and mantissa.""" + return (sign << 15) | (exp << 10) | mant + + +def decode_float16(bits): + """Decode Float16 bits to sign, exponent, and mantissa.""" + return (bits >> 15) & 1, (bits >> 10) & 0x1F, bits & 0x3FF + + +def is_nan(bits): + """Check if Float16 bits represent NaN.""" + _, exp, mant = decode_float16(bits) + return exp == 0x1F and mant != 0 + + +class TestAdditionNormalCases(unittest.TestCase): + """Tests for normal Float16 addition operations.""" + + def setUp(self): + pyrtl.reset_working_block() + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.add(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.add(self.a, self.b) + self.sim = pyrtl.Simulation() + + def test_add_one_plus_two(self): + """Test 1.0 + 2.0 = 3.0""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_TWO}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_THREE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_THREE) + + def test_add_one_plus_half(self): + """Test 1.0 + 0.5 = 1.5 (no rounding, GRS=000)""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_HALF}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE_POINT_FIVE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE_POINT_FIVE) + + def test_add_one_plus_quarter(self): + """Test 1.0 + 0.25 = 1.25 (no rounding, shift by 2)""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_QUARTER}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE_POINT_TWOFIVE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE_POINT_TWOFIVE) + + def test_add_half_plus_half(self): + """Test 0.5 + 0.5 = 1.0""" + self.sim.step({"a": FLOAT16_HALF, "b": FLOAT16_HALF}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE) + + def test_add_with_carry(self): + """Test 1.5 + 1.5 = 3.0 (carry propagates to exponent)""" + self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_ONE_POINT_FIVE}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_THREE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_THREE) + + def test_add_opposite_signs_equal_magnitude(self): + """Test 1.0 + 
(-1.0) = 0""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_ONE}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_ZERO) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_ZERO) + + +class TestAdditionRounding(unittest.TestCase): + """Tests for rounding in addition (RNE and RTZ).""" + + def setUp(self): + pyrtl.reset_working_block() + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.add(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.add(self.a, self.b) + self.sim = pyrtl.Simulation() + + def test_rounding_g1_r0_s0_lsb0_tie_truncates(self): + """Test G=1, R=0, S=0, LSB=0: tie, RNE truncates to even. + + a = 1.0 (exp=15, mant=0) + b = 0.5 * (1 + 1/1024) = exp=14, mant=1 + Shift b by 1: G=1 (bit 0 of original), R=0, S=0 + Sum mantissa LSB = 0, so RNE truncates. + Both RNE and RTZ produce same result. + """ + a = float16_parts(0, 15, 0) # 1.0 + b = float16_parts(0, 14, 1) # 0.5 * (1 + 1/1024) + expected = float16_parts(0, 15, 512) # 1.5 + self.sim.step({"a": a, "b": b}) + self.assertEqual(self.sim.inspect("result_rne"), expected) + self.assertEqual(self.sim.inspect("result_rtz"), expected) + + def test_rounding_g1_r0_s0_lsb1_tie_rounds_up(self): + """Test G=1, R=0, S=0, LSB=1: tie, RNE rounds up to even. + + a = 1.0 * (1 + 1/1024) = exp=15, mant=1 + b = 0.5 * (1 + 1/1024) = exp=14, mant=1 + Shift b by 1: G=1, R=0, S=0 + Sum mantissa = 1.1000000001, LSB = 1 + RNE: round up to make LSB even -> mant = 514 + RTZ: truncate -> mant = 513 + """ + a = float16_parts(0, 15, 1) # 1.0 * (1 + 1/1024) + b = float16_parts(0, 14, 1) # 0.5 * (1 + 1/1024) + expected_rne = float16_parts(0, 15, 514) + expected_rtz = float16_parts(0, 15, 513) + self.sim.step({"a": a, "b": b}) + self.assertEqual(self.sim.inspect("result_rne"), expected_rne) + self.assertEqual(self.sim.inspect("result_rtz"), expected_rtz) + + def test_rounding_g1_r1_s0_rounds_up(self): + """Test G=1, R=1, S=0: greater than half ULP, RNE rounds up. + + a = 1.0 (exp=15, mant=0) + b = 0.25 * (1 + 3/1024) = exp=13, mant=3 + Shift b by 2: G=1 (bit 1), R=1 (bit 0), S=0 + RNE: round up + RTZ: truncate + """ + a = float16_parts(0, 15, 0) # 1.0 + b = float16_parts(0, 13, 3) # 0.25 * (1 + 3/1024) + expected_rne = float16_parts(0, 15, 257) + expected_rtz = float16_parts(0, 15, 256) + self.sim.step({"a": a, "b": b}) + self.assertEqual(self.sim.inspect("result_rne"), expected_rne) + self.assertEqual(self.sim.inspect("result_rtz"), expected_rtz) + + def test_rounding_g1_r0_s1_rounds_up(self): + """Test G=1, R=0, S=1: greater than half ULP, RNE rounds up. + + a = 1.0 (exp=15, mant=0) + b = 0.125 * (1 + 5/1024) = exp=12, mant=5 (binary: 101) + Shift b by 3: G=1 (bit 2), R=0 (bit 1), S=1 (bit 0) + RNE: round up + RTZ: truncate + """ + a = float16_parts(0, 15, 0) # 1.0 + b = float16_parts(0, 12, 5) # 0.125 * (1 + 5/1024) + expected_rne = float16_parts(0, 15, 129) + expected_rtz = float16_parts(0, 15, 128) + self.sim.step({"a": a, "b": b}) + self.assertEqual(self.sim.inspect("result_rne"), expected_rne) + self.assertEqual(self.sim.inspect("result_rtz"), expected_rtz) + + def test_rounding_g0_r1_s1_truncates(self): + """Test G=0, R=1, S=1: less than half ULP, RNE truncates. 
+ + a = 1.0 (exp=15, mant=0) + b = 0.125 * (1 + 3/1024) = exp=12, mant=3 (binary: 011) + Shift b by 3: G=0 (bit 2), R=1 (bit 1), S=1 (bit 0) + Both RNE and RTZ truncate. + """ + a = float16_parts(0, 15, 0) # 1.0 + b = float16_parts(0, 12, 3) # 0.125 * (1 + 3/1024) + expected = float16_parts(0, 15, 128) + self.sim.step({"a": a, "b": b}) + self.assertEqual(self.sim.inspect("result_rne"), expected) + self.assertEqual(self.sim.inspect("result_rtz"), expected) + + +class TestAdditionRoundingWithCarry(unittest.TestCase): + """Tests for rounding in addition with carry (overflow into exponent). + + When the sum mantissa overflows (carry out), we shift right and + increment the exponent. This creates new GRS bits. + """ + + def setUp(self): + pyrtl.reset_working_block() + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.add(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.add(self.a, self.b) + self.sim = pyrtl.Simulation() + + def test_carry_g1_r0_s0_lsb1_tie_rounds_up(self): + """Test carry with G=1, R=0, S=0, LSB=1: tie rounds up. + + a = 1.1111111110 (exp=15, mant=1022) + b = 1.0000000001 (exp=15, mant=1) + Sum = 10.1111111111 -> normalize to 1.01111111111 + After normalization: G=1 (shifted out bit), R=0, S=0 + Result mantissa = 0111111111 = 511, LSB = 1 + RNE: tie, LSB=1 -> round up to 512 + RTZ: truncate to 511 + """ + a = float16_parts(0, 15, 1022) + b = float16_parts(0, 15, 1) + expected_rne = float16_parts(0, 16, 512) # Rounded up + expected_rtz = float16_parts(0, 16, 511) # Truncated + self.sim.step({"a": a, "b": b}) + self.assertEqual(self.sim.inspect("result_rne"), expected_rne) + self.assertEqual(self.sim.inspect("result_rtz"), expected_rtz) + + def test_carry_g1_r0_s0_lsb0_tie_truncates(self): + """Test carry with G=1, R=0, S=0, LSB=0: tie truncates. + + a = 1.1111111100 (exp=15, mant=1020) + b = 1.0000000001 (exp=15, mant=1) + After normalization and carry handling: + Result mantissa = 510, LSB = 0 + Both RNE and RTZ truncate. 
+ """ + a = float16_parts(0, 15, 1020) + b = float16_parts(0, 15, 1) + expected = float16_parts(0, 16, 510) + self.sim.step({"a": a, "b": b}) + self.assertEqual(self.sim.inspect("result_rne"), expected) + self.assertEqual(self.sim.inspect("result_rtz"), expected) + + +class TestAdditionEdgeCases(unittest.TestCase): + """Tests for edge cases in Float16 addition.""" + + def setUp(self): + pyrtl.reset_working_block() + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.add(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.add(self.a, self.b) + self.sim = pyrtl.Simulation() + + def test_add_zero_to_number(self): + """Test x + 0 = x""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_ZERO}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE) + + def test_add_negative_zero_to_number(self): + """Test x + (-0) = x""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_ZERO}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE) + + def test_add_infinity_to_number(self): + """Test x + inf = inf""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_INF}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_INF) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_INF) + + def test_add_negative_infinity_to_number(self): + """Test x + (-inf) = -inf""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_INF}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_NEG_INF) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_NEG_INF) + + def test_add_infinity_minus_infinity_is_nan(self): + """Test inf + (-inf) = NaN""" + self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_NEG_INF}) + self.assertTrue(is_nan(self.sim.inspect("result_rne"))) + self.assertTrue(is_nan(self.sim.inspect("result_rtz"))) + + def test_add_nan_propagates(self): + """Test x + NaN = NaN""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NAN}) + self.assertTrue(is_nan(self.sim.inspect("result_rne"))) + self.assertTrue(is_nan(self.sim.inspect("result_rtz"))) + + def test_add_denormal_flushed_to_zero(self): + """Test that denormal inputs are flushed to zero.""" + self.sim.step({"a": FLOAT16_POS_ZERO, "b": FLOAT16_DENORMAL}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_ZERO) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_ZERO) + + +class TestAdditionOverflow(unittest.TestCase): + """Tests for overflow handling in Float16 addition.""" -class TestMultiplication(unittest.TestCase): def setUp(self): pyrtl.reset_working_block() - a = pyrtl.Input(bitwidth=16, name="a") - b = pyrtl.Input(bitwidth=16, name="b") + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") FloatOperations.default_rounding_mode = RoundingMode.RNE - result_add = pyrtl.Output(name="result_add") - result_add <<= Float16Operations.add(a, b) - result_sub = pyrtl.Output(name="result_sub") - result_sub <<= Float16Operations.sub(a, b) + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.add(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + 
result_rtz <<= Float16Operations.add(self.a, self.b) self.sim = pyrtl.Simulation() - def test_multiplication_simple(self): - self.sim.step({"a": 0b0100001000000000, "b": 0b0100010100000000}) - self.assertEqual(self.sim.inspect("result_add"), 0b0100100000000000) - self.assertEqual(self.sim.inspect("result_sub"), 0b1100000000000000) + def test_overflow_rne_produces_infinity(self): + """Test that overflow produces infinity with RNE.""" + self.sim.step({"a": FLOAT16_LARGEST_NORMAL, "b": FLOAT16_LARGEST_NORMAL}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_INF) + + def test_overflow_rtz_produces_largest_finite(self): + """Test that overflow produces largest finite with RTZ.""" + self.sim.step({"a": FLOAT16_LARGEST_NORMAL, "b": FLOAT16_LARGEST_NORMAL}) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_LARGEST_NORMAL) + + +class TestSubtractionNormalCases(unittest.TestCase): + """Tests for normal Float16 subtraction operations.""" + + def setUp(self): + pyrtl.reset_working_block() + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.sub(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.sub(self.a, self.b) + self.sim = pyrtl.Simulation() + + def test_sub_three_minus_one(self): + """Test 3.0 - 1.0 = 2.0""" + self.sim.step({"a": FLOAT16_THREE, "b": FLOAT16_ONE}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_TWO) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_TWO) + + def test_sub_one_point_five_minus_half(self): + """Test 1.5 - 0.5 = 1.0""" + self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_HALF}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE) + + def test_sub_equal_numbers(self): + """Test 1.0 - 1.0 = 0.0""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_ONE}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_ZERO) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_ZERO) + + def test_sub_from_zero(self): + """Test 0 - 1.0 = -1.0""" + self.sim.step({"a": FLOAT16_POS_ZERO, "b": FLOAT16_ONE}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_NEG_ONE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_NEG_ONE) + + def test_sub_two_minus_half(self): + """Test 2.0 - 0.5 = 1.5""" + self.sim.step({"a": FLOAT16_TWO, "b": FLOAT16_HALF}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE_POINT_FIVE) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE_POINT_FIVE) + + def test_sub_double_negative(self): + """Test x - (-y) = x + y: 1.0 - (-1.0) = 2.0""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_ONE}) + self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_TWO) + self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_TWO) + + +class TestSubtractionRounding(unittest.TestCase): + """Tests for rounding in Float16 subtraction. + + Subtraction is implemented as a - b = a + (-b), so GRS bits come from + shifting the smaller operand right and from subtraction normalization. 
+    """
+
+    def setUp(self):
+        pyrtl.reset_working_block()
+        self.a = pyrtl.Input(bitwidth=16, name="a")
+        self.b = pyrtl.Input(bitwidth=16, name="b")
+        FloatOperations.default_rounding_mode = RoundingMode.RNE
+        result_rne = pyrtl.Output(name="result_rne")
+        result_rne <<= Float16Operations.sub(self.a, self.b)
+        FloatOperations.default_rounding_mode = RoundingMode.RTZ
+        result_rtz = pyrtl.Output(name="result_rtz")
+        result_rtz <<= Float16Operations.sub(self.a, self.b)
+        self.sim = pyrtl.Simulation()
+
+    def test_sub_exact_no_rounding(self):
+        """Test 1.5 - 0.25 = 1.25 (no rounding needed, exact result).
+
+        a = 1.5 (exp=15, mant=512)
+        b = 0.25 (exp=13, mant=0)
+        Result fits exactly in mantissa, GRS=000.
+        Both RNE and RTZ produce same result.
+        """
+        self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_QUARTER})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE_POINT_TWOFIVE)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE_POINT_TWOFIVE)
+
+    def test_sub_rne_rounds_up(self):
+        """Test subtraction where RNE rounds up and RTZ truncates.
+
+        a = 1.0 * (1 + 512/1024) = exp=15, mant=512 (1.5)
+        b = 0.25 * (1 + 2/1024) = exp=13, mant=2
+        Shift b by 2 -> LSB=1, G=1, R=0, S=0
+        RNE: rounds up to 256
+        RTZ: truncates to 255
+        """
+        a = float16_parts(0, 15, 512)
+        b = float16_parts(0, 13, 2)
+        expected_rne = float16_parts(0, 15, 256)
+        expected_rtz = float16_parts(0, 15, 255)
+        self.sim.step({"a": a, "b": b})
+        self.assertEqual(self.sim.inspect("result_rne"), expected_rne)
+        self.assertEqual(self.sim.inspect("result_rtz"), expected_rtz)
+
+
+class TestSubtractionEdgeCases(unittest.TestCase):
+    """Tests for edge cases in Float16 subtraction."""
+
+    def setUp(self):
+        pyrtl.reset_working_block()
+        self.a = pyrtl.Input(bitwidth=16, name="a")
+        self.b = pyrtl.Input(bitwidth=16, name="b")
+        FloatOperations.default_rounding_mode = RoundingMode.RNE
+        result_rne = pyrtl.Output(name="result_rne")
+        result_rne <<= Float16Operations.sub(self.a, self.b)
+        FloatOperations.default_rounding_mode = RoundingMode.RTZ
+        result_rtz = pyrtl.Output(name="result_rtz")
+        result_rtz <<= Float16Operations.sub(self.a, self.b)
+        self.sim = pyrtl.Simulation()
+
+    def test_sub_zero_from_number(self):
+        """Test x - 0 = x"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_ZERO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE)
+
+    def test_sub_pos_zero_minus_pos_zero(self):
+        """Test +0 - +0 = +0"""
+        self.sim.step({"a": FLOAT16_POS_ZERO, "b": FLOAT16_POS_ZERO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_ZERO)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_ZERO)
+
+    def test_sub_neg_zero_minus_neg_zero(self):
+        """Test (-0) - (-0) = +0"""
+        self.sim.step({"a": FLOAT16_NEG_ZERO, "b": FLOAT16_NEG_ZERO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_ZERO)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_ZERO)
+
+    def test_sub_infinity_from_number(self):
+        """Test x - inf = -inf"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_INF})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_NEG_INF)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_NEG_INF)
+
+    def test_sub_infinity_from_infinity_is_nan(self):
+        """Test inf - inf = NaN"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_POS_INF})
+        self.assertTrue(is_nan(self.sim.inspect("result_rne")))
+        self.assertTrue(is_nan(self.sim.inspect("result_rtz")))
+
+    def test_sub_neg_infinity_from_pos_infinity(self):
+        """Test inf - (-inf) = inf"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_NEG_INF})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_INF)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_INF)
+
+    def test_sub_nan_propagates(self):
+        """Test x - NaN = NaN"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NAN})
+        self.assertTrue(is_nan(self.sim.inspect("result_rne")))
+        self.assertTrue(is_nan(self.sim.inspect("result_rtz")))
+
+    def test_sub_denormal_flushed_to_zero(self):
+        """Test that denormal operands are flushed to zero."""
+        self.sim.step({"a": FLOAT16_DENORMAL, "b": FLOAT16_POS_ZERO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_ZERO)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_ZERO)
+
+
+class TestSubtractionOverflow(unittest.TestCase):
+    """Tests for overflow handling in Float16 subtraction."""
+
+    def setUp(self):
+        pyrtl.reset_working_block()
+        self.a = pyrtl.Input(bitwidth=16, name="a")
+        self.b = pyrtl.Input(bitwidth=16, name="b")
+        FloatOperations.default_rounding_mode = RoundingMode.RNE
+        result_rne = pyrtl.Output(name="result_rne")
+        result_rne <<= Float16Operations.sub(self.a, self.b)
+        FloatOperations.default_rounding_mode = RoundingMode.RTZ
+        result_rtz = pyrtl.Output(name="result_rtz")
+        result_rtz <<= Float16Operations.sub(self.a, self.b)
+        self.sim = pyrtl.Simulation()
+
+    def test_overflow_by_subtracting_negative(self):
+        """Test overflow when subtracting large negative: large - (-large).
+
+        RNE: overflow produces infinity
+        RTZ: overflow produces largest finite
+        """
+        a = FLOAT16_LARGEST_NORMAL  # Large positive
+        b = FLOAT16_LARGEST_NORMAL | 0x8000  # Same magnitude, negative
+        self.sim.step({"a": a, "b": b})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_INF)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_LARGEST_NORMAL)
 if __name__ == "__main__":
diff --git a/tests/rtllib/pyrtlfloat/test_multiplication.py b/tests/rtllib/pyrtlfloat/test_multiplication.py
index 439efabb..60128865 100644
--- a/tests/rtllib/pyrtlfloat/test_multiplication.py
+++ b/tests/rtllib/pyrtlfloat/test_multiplication.py
@@ -3,24 +3,290 @@ import pyrtl
 from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode
+# IEEE 754 Float16 special values
+FLOAT16_POS_ZERO = 0x0000
+FLOAT16_NEG_ZERO = 0x8000
+FLOAT16_POS_INF = 0x7C00
+FLOAT16_NEG_INF = 0xFC00
+FLOAT16_NAN = 0x7E00  # Quiet NaN with mantissa bit set
+FLOAT16_ONE = 0x3C00  # 1.0
+FLOAT16_NEG_ONE = 0xBC00  # -1.0
+FLOAT16_TWO = 0x4000  # 2.0
+FLOAT16_NEG_TWO = 0xC000  # -2.0
+FLOAT16_THREE = 0x4200  # 3.0
+FLOAT16_HALF = 0x3800  # 0.5
+FLOAT16_ONE_POINT_FIVE = 0x3E00  # 1.5
+FLOAT16_LARGEST_NORMAL = 0x7BFF  # Largest normal number (~65504)
+FLOAT16_DENORMAL = 0x0001  # Smallest denormal
+
+
+def float16_parts(sign, exp, mant):
+    """Construct Float16 from sign, exponent, and mantissa."""
+    return (sign << 15) | (exp << 10) | mant
+
+
+def decode_float16(bits):
+    """Decode Float16 bits to sign, exponent, and mantissa."""
+    return (bits >> 15) & 1, (bits >> 10) & 0x1F, bits & 0x3FF
+
+
+def is_nan(bits):
+    """Check if Float16 bits represent NaN."""
+    _, exp, mant = decode_float16(bits)
+    return exp == 0x1F and mant != 0
+
+
+class TestMultiplicationNormalCases(unittest.TestCase):
+    """Tests for normal Float16 multiplication operations."""
-class TestMultiplication(unittest.TestCase):
     def setUp(self):
         pyrtl.reset_working_block()
-        a = pyrtl.Input(bitwidth=16, name="a")
-        b = pyrtl.Input(bitwidth=16, name="b")
+        self.a = pyrtl.Input(bitwidth=16, name="a")
+        self.b = pyrtl.Input(bitwidth=16, name="b")
         FloatOperations.default_rounding_mode = RoundingMode.RNE
         result_rne = pyrtl.Output(name="result_rne")
-        result_rne <<= Float16Operations.mul(a, b)
+        result_rne <<= Float16Operations.mul(self.a, self.b)
         FloatOperations.default_rounding_mode = RoundingMode.RTZ
         result_rtz = pyrtl.Output(name="result_rtz")
-        result_rtz <<= Float16Operations.mul(a, b)
+        result_rtz <<= Float16Operations.mul(self.a, self.b)
         self.sim = pyrtl.Simulation()
-    def test_multiplication_simple(self):
-        self.sim.step({"a": 0b0011111000000000, "b": 0b0011110000000001})
-        self.assertEqual(self.sim.inspect("result_rne"), 0b0011111000000010)
-        self.assertEqual(self.sim.inspect("result_rtz"), 0b0011111000000001)
+    def test_mul_half_times_two(self):
+        """Test 0.5 * 2.0 = 1.0"""
+        self.sim.step({"a": FLOAT16_HALF, "b": FLOAT16_TWO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_ONE)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_ONE)
+
+    def test_mul_one_point_five_times_two(self):
+        """Test 1.5 * 2.0 = 3.0"""
+        self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_TWO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_THREE)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_THREE)
+
+    def test_mul_opposite_signs(self):
+        """Test -1.0 * 2.0 = -2.0"""
+        self.sim.step({"a": FLOAT16_NEG_ONE, "b": FLOAT16_TWO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_NEG_TWO)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_NEG_TWO)
+
+    def test_mul_both_negative(self):
+        """Test -1.0 * -2.0 = 2.0"""
+        self.sim.step({"a": FLOAT16_NEG_ONE, "b": FLOAT16_NEG_TWO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_TWO)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_TWO)
+
+    def test_mul_one_point_five_times_one_point_five(self):
+        """Test 1.5 * 1.5 = 2.25"""
+        self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_ONE_POINT_FIVE})
+        expected = 0x4080  # 2.25 in float16
+        self.assertEqual(self.sim.inspect("result_rne"), expected)
+        self.assertEqual(self.sim.inspect("result_rtz"), expected)
+
+
+class TestMultiplicationRounding(unittest.TestCase):
+    """Tests for rounding in multiplication (RNE and RTZ)."""
+
+    def setUp(self):
+        pyrtl.reset_working_block()
+        self.a = pyrtl.Input(bitwidth=16, name="a")
+        self.b = pyrtl.Input(bitwidth=16, name="b")
+        FloatOperations.default_rounding_mode = RoundingMode.RNE
+        result_rne = pyrtl.Output(name="result_rne")
+        result_rne <<= Float16Operations.mul(self.a, self.b)
+        FloatOperations.default_rounding_mode = RoundingMode.RTZ
+        result_rtz = pyrtl.Output(name="result_rtz")
+        result_rtz <<= Float16Operations.mul(self.a, self.b)
+        self.sim = pyrtl.Simulation()
+
+    def test_rounding_g0_s0_truncates(self):
+        """Test Guard=0, Sticky=0: RNE truncates (exact result, no rounding needed).
+
+        a = 1.0 (exp=15, mant=0)
+        b = 1.0 (exp=15, mant=0)
+        Product mantissa has Guard=0, Sticky=0.
+        Both RNE and RTZ produce the same result.
+        """
+        a = float16_parts(0, 15, 0)  # 1.0
+        b = float16_parts(0, 15, 0)  # 1.0
+        expected = float16_parts(0, 15, 0)  # 1.0
+        self.sim.step({"a": a, "b": b})
+        self.assertEqual(self.sim.inspect("result_rne"), expected)
+        self.assertEqual(self.sim.inspect("result_rtz"), expected)
+
+    def test_rounding_g0_s1_truncates(self):
+        """Test Guard=0, Sticky=1: RNE truncates (less than half ULP).
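+        (Worked product: (1 + 1/1024)**2 = 1 + 2/1024 + 2**-20, so the kept
+        mantissa is 2 and only a sticky bit is set below the guard bit.)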
+
+        a = 1.0 * (1 + 1/1024) = exp=15, mant=1
+        b = 1.0 * (1 + 1/1024) = exp=15, mant=1
+        Product has Guard=0, Sticky=1, Last=0.
+        Both RNE and RTZ truncate.
+        """
+        a = float16_parts(0, 15, 1)  # 1.0 * (1 + 1/1024)
+        b = float16_parts(0, 15, 1)  # 1.0 * (1 + 1/1024)
+        expected = float16_parts(0, 15, 2)
+        self.sim.step({"a": a, "b": b})
+        self.assertEqual(self.sim.inspect("result_rne"), expected)
+        self.assertEqual(self.sim.inspect("result_rtz"), expected)
+
+    def test_rounding_g1_l0_s0_tie_truncates(self):
+        """Test Guard=1, Last=0, Sticky=0: tie, RNE truncates (LSB already even).
+
+        a = 1.0 * (1 + 2/1024) = exp=15, mant=2
+        b = 1.0 * (1 + 256/1024) = exp=15, mant=256
+        Product has Guard=1, Sticky=0, Last=0.
+        RNE: LSB is 0 (even), so truncate.
+        Both RNE and RTZ truncate.
+        """
+        a = float16_parts(0, 15, 2)  # 1.0 * (1 + 2/1024)
+        b = float16_parts(0, 15, 256)  # 1.0 * (1 + 256/1024)
+        expected = float16_parts(0, 15, 258)
+        self.sim.step({"a": a, "b": b})
+        self.assertEqual(self.sim.inspect("result_rne"), expected)
+        self.assertEqual(self.sim.inspect("result_rtz"), expected)
+
+    def test_rounding_g1_l1_s0_tie_rounds_up(self):
+        """Test Guard=1, Last=1, Sticky=0: tie, RNE rounds up (make LSB even).
+
+        a = 1.0 * (1 + 1/1024) = exp=15, mant=1
+        b = 1.0 * (1 + 512/1024) = exp=15, mant=512 (1.5)
+        Product has Guard=1, Sticky=0, Last=1.
+        RNE: LSB is 1 (odd), so round up to make it even.
+        RTZ: truncates
+        """
+        a = float16_parts(0, 15, 1)  # 1.0 * (1 + 1/1024)
+        b = float16_parts(0, 15, 512)  # 1.0 * (1 + 512/1024)
+        expected_rne = float16_parts(0, 15, 514)
+        expected_rtz = float16_parts(0, 15, 513)
+        self.sim.step({"a": a, "b": b})
+        self.assertEqual(self.sim.inspect("result_rne"), expected_rne)
+        self.assertEqual(self.sim.inspect("result_rtz"), expected_rtz)
+
+    def test_rounding_g1_l0_s1_rounds_up(self):
+        """Test Guard=1, Last=0, Sticky=1: greater than half ULP, RNE rounds up.
+
+        a = 1.0 * (1 + 1/1024) = exp=15, mant=1
+        b = 1.0 * (1 + 513/1024) = exp=15, mant=513
+        Product has Guard=1, Sticky=1, Last=0.
+        Greater than half ULP, so round up.
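+        (Worked product: 1025 * 1537 = 1575425 = 2**20 + 514*2**10 + 513,
+        so the kept mantissa is 514 with Guard=1 and Sticky=1.)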
+        RNE: rounds up
+        RTZ: truncates
+        """
+        a = float16_parts(0, 15, 1)  # 1.0 * (1 + 1/1024)
+        b = float16_parts(0, 15, 513)  # 1.0 * (1 + 513/1024)
+        expected_rne = float16_parts(0, 15, 515)
+        expected_rtz = float16_parts(0, 15, 514)
+        self.sim.step({"a": a, "b": b})
+        self.assertEqual(self.sim.inspect("result_rne"), expected_rne)
+        self.assertEqual(self.sim.inspect("result_rtz"), expected_rtz)
+
+
+class TestMultiplicationEdgeCases(unittest.TestCase):
+    """Tests for edge cases in Float16 multiplication."""
+
+    def setUp(self):
+        pyrtl.reset_working_block()
+        self.a = pyrtl.Input(bitwidth=16, name="a")
+        self.b = pyrtl.Input(bitwidth=16, name="b")
+        FloatOperations.default_rounding_mode = RoundingMode.RNE
+        result_rne = pyrtl.Output(name="result_rne")
+        result_rne <<= Float16Operations.mul(self.a, self.b)
+        FloatOperations.default_rounding_mode = RoundingMode.RTZ
+        result_rtz = pyrtl.Output(name="result_rtz")
+        result_rtz <<= Float16Operations.mul(self.a, self.b)
+        self.sim = pyrtl.Simulation()
+
+    def test_mul_by_zero(self):
+        """Test x * 0 = 0"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_ZERO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_ZERO)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_ZERO)
+
+    def test_mul_by_negative_zero(self):
+        """Test x * (-0) = -0"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_ZERO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_NEG_ZERO)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_NEG_ZERO)
+
+    def test_mul_infinity_by_number(self):
+        """Test inf * x = inf"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_TWO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_INF)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_INF)
+
+    def test_mul_neg_infinity_by_number(self):
+        """Test -inf * x = -inf"""
+        self.sim.step({"a": FLOAT16_NEG_INF, "b": FLOAT16_TWO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_NEG_INF)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_NEG_INF)
+
+    def test_mul_infinity_by_zero_is_nan(self):
+        """Test inf * 0 = NaN"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_POS_ZERO})
+        self.assertTrue(is_nan(self.sim.inspect("result_rne")))
+        self.assertTrue(is_nan(self.sim.inspect("result_rtz")))
+
+    def test_mul_infinity_by_infinity(self):
+        """Test inf * inf = inf"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_POS_INF})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_INF)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_INF)
+
+    def test_mul_pos_infinity_by_neg_infinity(self):
+        """Test inf * (-inf) = -inf"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_NEG_INF})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_NEG_INF)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_NEG_INF)
+
+    def test_mul_nan_propagates(self):
+        """Test x * NaN = NaN"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NAN})
+        self.assertTrue(is_nan(self.sim.inspect("result_rne")))
+        self.assertTrue(is_nan(self.sim.inspect("result_rtz")))
+
+    def test_mul_denormal_flushed_to_zero(self):
+        """Test that denormal operands are flushed to zero."""
+        self.sim.step({"a": FLOAT16_DENORMAL, "b": FLOAT16_ONE})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_ZERO)
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_POS_ZERO)
+
+
+class TestMultiplicationOverflow(unittest.TestCase):
+    """Tests for overflow handling in Float16 multiplication."""
+
+    def setUp(self):
+        pyrtl.reset_working_block()
+        self.a = pyrtl.Input(bitwidth=16, name="a")
+        self.b = pyrtl.Input(bitwidth=16, name="b")
+        FloatOperations.default_rounding_mode = RoundingMode.RNE
+        result_rne = pyrtl.Output(name="result_rne")
+        result_rne <<= Float16Operations.mul(self.a, self.b)
+        FloatOperations.default_rounding_mode = RoundingMode.RTZ
+        result_rtz = pyrtl.Output(name="result_rtz")
+        result_rtz <<= Float16Operations.mul(self.a, self.b)
+        self.sim = pyrtl.Simulation()
+
+    def test_overflow_rne_produces_infinity(self):
+        """Test that overflow produces infinity with RNE."""
+        self.sim.step({"a": FLOAT16_LARGEST_NORMAL, "b": FLOAT16_TWO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_POS_INF)
+
+    def test_overflow_rtz_produces_largest_finite(self):
+        """Test that overflow produces largest finite with RTZ."""
+        self.sim.step({"a": FLOAT16_LARGEST_NORMAL, "b": FLOAT16_TWO})
+        self.assertEqual(self.sim.inspect("result_rtz"), FLOAT16_LARGEST_NORMAL)
+
+    def test_negative_overflow_rne_produces_neg_infinity(self):
+        """Test that negative overflow produces -infinity with RNE."""
+        neg_largest = FLOAT16_LARGEST_NORMAL | 0x8000
+        self.sim.step({"a": neg_largest, "b": FLOAT16_TWO})
+        self.assertEqual(self.sim.inspect("result_rne"), FLOAT16_NEG_INF)
+
+    def test_negative_overflow_rtz_produces_neg_largest_finite(self):
+        """Test that negative overflow produces -largest finite with RTZ."""
+        neg_largest = FLOAT16_LARGEST_NORMAL | 0x8000
+        self.sim.step({"a": neg_largest, "b": FLOAT16_TWO})
+        expected = FLOAT16_LARGEST_NORMAL | 0x8000
+        self.assertEqual(self.sim.inspect("result_rtz"), expected)
 if __name__ == "__main__":