|
| 1 | +#[test_only] |
| 2 | +module openzeppelin_fp_math::sd29x9_sqrt_tests; |
| 3 | + |
| 4 | +use openzeppelin_fp_math::sd29x9; |
| 5 | +use openzeppelin_fp_math::sd29x9_base; |
| 6 | +use openzeppelin_fp_math::sd29x9_test_helpers::{pos, neg}; |
| 7 | +use std::unit_test::assert_eq; |
| 8 | + |
| 9 | +const SCALE: u128 = 1_000_000_000; |
| 10 | +const MAX_POSITIVE_VALUE: u128 = 0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF; |
| 11 | + |
| 12 | +// ==== Tests ==== |
| 13 | + |
| 14 | +#[test] |
| 15 | +fun sqrt_of_zero_is_zero() { |
| 16 | + assert_eq!(sd29x9::zero().sqrt(), sd29x9::zero()); |
| 17 | +} |
| 18 | + |
| 19 | +#[test] |
| 20 | +fun sqrt_of_positive_one() { |
| 21 | + assert_eq!(sd29x9::one().sqrt(), sd29x9::one()); |
| 22 | +} |
| 23 | + |
| 24 | +#[test] |
| 25 | +fun sqrt_of_positive_perfect_squares() { |
| 26 | + // sqrt(+4.0) = +2.0 |
| 27 | + assert_eq!(pos(4 * SCALE).sqrt(), pos(2 * SCALE)); |
| 28 | + // sqrt(+9.0) = +3.0 |
| 29 | + assert_eq!(pos(9 * SCALE).sqrt(), pos(3 * SCALE)); |
| 30 | + // sqrt(+25.0) = +5.0 |
| 31 | + assert_eq!(pos(25 * SCALE).sqrt(), pos(5 * SCALE)); |
| 32 | + // sqrt(+100.0) = +10.0 |
| 33 | + assert_eq!(pos(100 * SCALE).sqrt(), pos(10 * SCALE)); |
| 34 | + // sqrt(+10000.0) = +100.0 |
| 35 | + assert_eq!(pos(10_000 * SCALE).sqrt(), pos(100 * SCALE)); |
| 36 | +} |
| 37 | + |
| 38 | +#[test] |
| 39 | +fun sqrt_of_positive_fractional_squares() { |
| 40 | + // sqrt(+0.25) = +0.5 |
| 41 | + assert_eq!(pos(250_000_000).sqrt(), pos(500_000_000)); |
| 42 | + // sqrt(+0.01) = +0.1 |
| 43 | + assert_eq!(pos(10_000_000).sqrt(), pos(100_000_000)); |
| 44 | + // sqrt(+2.25) = +1.5 |
| 45 | + assert_eq!(pos(2_250_000_000).sqrt(), pos(1_500_000_000)); |
| 46 | +} |
| 47 | + |
| 48 | +#[test] |
| 49 | +fun sqrt_truncates_irrational_results() { |
| 50 | + // sqrt(+2.0) = +1.414213562 (truncated) |
| 51 | + assert_eq!(pos(2 * SCALE).sqrt(), pos(1_414_213_562)); |
| 52 | + // sqrt(+3.0) = +1.732050807 (truncated) |
| 53 | + assert_eq!(pos(3 * SCALE).sqrt(), pos(1_732_050_807)); |
| 54 | + // sqrt(+5.0) = +2.236067977 (truncated) |
| 55 | + assert_eq!(pos(5 * SCALE).sqrt(), pos(2_236_067_977)); |
| 56 | +} |
| 57 | + |
| 58 | +#[test] |
| 59 | +fun sqrt_of_max_positive() { |
| 60 | + // sqrt(sd29x9::max()) should not abort and satisfy the floor property |
| 61 | + let result = sd29x9::max().sqrt(); |
| 62 | + let r = result.unwrap() as u256; |
| 63 | + let max_scaled = (sd29x9::max().unwrap() as u256) * (SCALE as u256); |
| 64 | + assert!(r * r <= max_scaled); |
| 65 | + assert!((r + 1) * (r + 1) > max_scaled); |
| 66 | +} |
| 67 | + |
| 68 | +#[random_test] |
| 69 | +fun sqrt_result_is_always_non_negative(raw: u128) { |
| 70 | + let raw = raw % (MAX_POSITIVE_VALUE + 1); |
| 71 | + let result = sd29x9::wrap(raw, false).sqrt(); |
| 72 | + // Result is non-negative: raw bits should not have sign bit set |
| 73 | + assert!(result.unwrap() <= MAX_POSITIVE_VALUE); |
| 74 | +} |
| 75 | + |
| 76 | +#[random_test] |
| 77 | +fun sqrt_floor_invariant(raw: u128) { |
| 78 | + let raw = raw % (MAX_POSITIVE_VALUE + 1); |
| 79 | + let result = sd29x9::wrap(raw, false).sqrt(); |
| 80 | + // Floor property: r^2 <= x * SCALE < (r + 1)^2 |
| 81 | + let r = result.unwrap() as u256; |
| 82 | + let scaled = (raw as u256) * (SCALE as u256); |
| 83 | + assert!(r * r <= scaled); |
| 84 | + assert!((r + 1) * (r + 1) > scaled); |
| 85 | +} |
| 86 | + |
| 87 | +#[test] |
| 88 | +fun sqrt_squared_roundtrip_for_perfect_squares() { |
| 89 | + let values = vector[ |
| 90 | + pos(4 * SCALE), |
| 91 | + pos(9 * SCALE), |
| 92 | + pos(25 * SCALE), |
| 93 | + pos(250_000_000), // 0.25 |
| 94 | + ]; |
| 95 | + values.destroy!(|x| { |
| 96 | + let root = x.sqrt(); |
| 97 | + assert_eq!(root.mul(root), x); |
| 98 | + }); |
| 99 | +} |
| 100 | + |
| 101 | +#[test, expected_failure(abort_code = sd29x9_base::ENegativeSqrt)] |
| 102 | +fun sqrt_of_negative_aborts() { |
| 103 | + neg(SCALE).sqrt(); |
| 104 | +} |
| 105 | + |
| 106 | +#[test, expected_failure(abort_code = sd29x9_base::ENegativeSqrt)] |
| 107 | +fun sqrt_of_small_negative_aborts() { |
| 108 | + neg(1).sqrt(); |
| 109 | +} |
| 110 | + |
| 111 | +#[test, expected_failure(abort_code = sd29x9_base::ENegativeSqrt)] |
| 112 | +fun sqrt_of_min_value_aborts() { |
| 113 | + sd29x9::min().sqrt(); |
| 114 | +} |
0 commit comments