Skip to content

Commit 646fafd

Browse files
immrsdbidzyyys
andauthored
Fixed point sqrt (#286)
* Add fixed-point helper for sqrt calculation * Implement sqrt for SD29x9 type * Implement sqrt for UD30x9 * Add tests for fixed-point sqrt impl * Add sqrt and remaining functions to README * Format files * Use sqrt_floor implementation from math/core package * Address review comments * Move sqrt_floor function from fp_helpers to common module * Fix tests to use assert_eq instead of deprecated expect * Remove redundant zero check in ud30x9::sqrt * Format files * Address review comments --------- Co-authored-by: Daniel Bigos <daniel.bigos@openzeppelin.com>
1 parent 6ec50a6 commit 646fafd

8 files changed

Lines changed: 310 additions & 6 deletions

File tree

math/fixed_point/Move.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,8 @@
22
name = "openzeppelin_fp_math"
33
edition = "2024"
44

5+
[dependencies]
6+
openzeppelin_math = { local = "../core" }
7+
58
[addresses]
69
openzeppelin_fp_math = "0x0"

math/fixed_point/README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Fixed-point decimal types with 9 decimals (10^9), matching Sui coin precision.
99

1010
## Operations
1111

12-
- Arithmetic: `add`, `sub`, `mul`, `mul_trunc`, `mul_away`, `div`, `div_trunc`, `div_away`, `pow`, `unchecked_add`, `unchecked_sub`, `mod`
12+
- Arithmetic: `add`, `sub`, `mul`, `mul_trunc`, `mul_away`, `div`, `div_trunc`, `div_away`, `pow`, `unchecked_add`, `unchecked_sub`, `mod`, `sqrt`
1313
- Comparison: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `is_zero`
1414
- `UD30x9` also exposes bitwise helpers: `and`, `and2`, `or`, `xor`, `not`, `lshift`, `rshift`, `unchecked_lshift`, `unchecked_rshift`
1515

@@ -34,6 +34,12 @@ Rule of thumb:
3434
The core `wrap` / `unwrap` APIs are **raw casts**. They preserve the
3535
underlying fixed-point representation and do not multiply or divide by `10^9`.
3636

37+
- `u128 -> UD30x9`: `into_UD30x9`
38+
- `UD30x9 -> SD29x9`: `into_SD29x9`, `try_into_SD29x9`
39+
- `SD29x9 -> UD30x9`: `into_UD30x9`, `try_into_UD30x9`
40+
- Constructors: `zero`, `one`, `max`, `wrap`
41+
- `SD29x9` only: `min`, `from_bits`
42+
3743
```rust
3844
use openzeppelin_fp_math::{sd29x9, ud30x9};
3945

math/fixed_point/sources/sd29x9/sd29x9.move

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public use fun openzeppelin_fp_math::sd29x9_base::mul_away as SD29x9.mul_away;
4444
public use fun openzeppelin_fp_math::sd29x9_base::mul_trunc as SD29x9.mul_trunc;
4545
public use fun openzeppelin_fp_math::sd29x9_base::negate as SD29x9.negate;
4646
public use fun openzeppelin_fp_math::sd29x9_base::pow as SD29x9.pow;
47+
public use fun openzeppelin_fp_math::sd29x9_base::sqrt as SD29x9.sqrt;
4748
public use fun openzeppelin_fp_math::sd29x9_base::rem as SD29x9.rem;
4849
public use fun openzeppelin_fp_math::sd29x9_base::sub as SD29x9.sub;
4950
public use fun openzeppelin_fp_math::sd29x9_base::unchecked_add as SD29x9.unchecked_add;

math/fixed_point/sources/sd29x9/sd29x9_base.move

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ module openzeppelin_fp_math::sd29x9_base;
66
use openzeppelin_fp_math::common;
77
use openzeppelin_fp_math::sd29x9::{SD29x9, from_bits, zero, min, one, two_complement, wrap};
88
use openzeppelin_fp_math::ud30x9::{Self, UD30x9};
9+
use openzeppelin_math::rounding;
10+
use openzeppelin_math::u256;
911

1012
// === Errors ===
1113

@@ -21,15 +23,17 @@ const ECannotBeConvertedToUD30x9: vector<u8> = "Value cannot be converted to UD3
2123
#[error(code = 2)]
2224
const EDivideByZero: vector<u8> = "Divisor must be non-zero";
2325

26+
/// Cannot compute square root of a negative value
27+
#[error(code = 3)]
28+
const ENegativeSqrt: vector<u8> = "Cannot compute square root of a negative value";
29+
2430
// === Structs ===
2531

2632
public struct Components has copy, drop {
2733
neg: bool,
2834
mag: u256,
2935
}
3036

31-
// === Public Functions ===
32-
3337
// === Conversion ===
3438

3539
/// Converts a `SD29x9` value to a `UD30x9` value.
@@ -64,6 +68,8 @@ public fun try_into_UD30x9(x: SD29x9): Option<UD30x9> {
6468
}
6569
}
6670

71+
// === Public Functions ===
72+
6773
/// Returns the absolute value of a `SD29x9`.
6874
///
6975
/// #### Parameters
@@ -469,6 +475,29 @@ public fun pow(x: SD29x9, exp: u8): SD29x9 {
469475
result.wrap_components()
470476
}
471477

478+
/// Computes the square root of a `SD29x9` value.
479+
///
480+
/// The result is the largest `SD29x9` value `r` such that `r * r <= x`. In other words, the
481+
/// result is truncated (rounded down) to the nearest representable `SD29x9` value.
482+
///
483+
/// #### Parameters
484+
/// - `x`: Input value.
485+
///
486+
/// #### Returns
487+
/// - The non-negative square root of `x`, rounded down to the nearest representable `SD29x9`
488+
/// value.
489+
///
490+
/// #### Aborts
491+
/// - Aborts if `x` is negative.
492+
public fun sqrt(x: SD29x9): SD29x9 {
493+
let Components { neg, mag } = decompose(x.unwrap());
494+
assert!(!neg, ENegativeSqrt);
495+
// Multiply by SCALE to preserve 9 decimal places of precision through the square root:
496+
// sqrt(mag / SCALE) = sqrt(mag * SCALE) / SCALE
497+
let result = u256::sqrt(mag * common::scale_u256!(), rounding::down());
498+
wrap_components(Components { neg: false, mag: result })
499+
}
500+
472501
/// Returns the arithmetic negation of `x`.
473502
///
474503
/// #### Parameters

math/fixed_point/sources/ud30x9/ud30x9.move

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public use fun openzeppelin_fp_math::ud30x9_base::mul as UD30x9.mul;
3737
public use fun openzeppelin_fp_math::ud30x9_base::mul_away as UD30x9.mul_away;
3838
public use fun openzeppelin_fp_math::ud30x9_base::mul_trunc as UD30x9.mul_trunc;
3939
public use fun openzeppelin_fp_math::ud30x9_base::pow as UD30x9.pow;
40+
public use fun openzeppelin_fp_math::ud30x9_base::sqrt as UD30x9.sqrt;
4041
public use fun openzeppelin_fp_math::ud30x9_base::sub as UD30x9.sub;
4142
public use fun openzeppelin_fp_math::ud30x9_base::unchecked_add as UD30x9.unchecked_add;
4243
public use fun openzeppelin_fp_math::ud30x9_base::unchecked_sub as UD30x9.unchecked_sub;

math/fixed_point/sources/ud30x9/ud30x9_base.move

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ module openzeppelin_fp_math::ud30x9_base;
33

44
use openzeppelin_fp_math::common;
55
use openzeppelin_fp_math::sd29x9::{Self, SD29x9};
6-
use openzeppelin_fp_math::ud30x9::{Self, UD30x9, wrap, one};
6+
use openzeppelin_fp_math::ud30x9::{UD30x9, wrap, zero, one};
7+
use openzeppelin_math::rounding;
8+
use openzeppelin_math::u256;
79

810
// === Errors ===
911

@@ -236,7 +238,7 @@ public fun lshift(x: UD30x9, bits: u8): UD30x9 {
236238
/// - Otherwise, the result of shifting the `x`'s raw bits left by `bits`.
237239
public fun unchecked_lshift(x: UD30x9, bits: u8): UD30x9 {
238240
if (bits >= 128) {
239-
return ud30x9::zero()
241+
return zero()
240242
};
241243
wrap(x.unwrap() << bits)
242244
}
@@ -454,6 +456,24 @@ public fun pow(x: UD30x9, exp: u8): UD30x9 {
454456
wrap_u256(result)
455457
}
456458

459+
/// Computes the square root of a `UD30x9` value.
460+
///
461+
/// The result is the largest `UD30x9` value `r` such that `r * r <= x`. In other words, the
462+
/// result is truncated (rounded down) to the nearest representable `UD30x9` value.
463+
///
464+
/// #### Parameters
465+
/// - `x`: Input value.
466+
///
467+
/// #### Returns
468+
/// - The square root of `x`, rounded down to the nearest representable `UD30x9` value.
469+
public fun sqrt(x: UD30x9): UD30x9 {
470+
let raw = x.unwrap() as u256;
471+
// Multiply by SCALE to preserve 9 decimal places of precision through the square root:
472+
// sqrt(raw / SCALE) = sqrt(raw * SCALE) / SCALE
473+
let result = u256::sqrt(raw * common::scale_u256!(), rounding::down());
474+
wrap(result as u128)
475+
}
476+
457477
/// Checks whether two `UD30x9` values are not equal.
458478
///
459479
/// #### Parameters
@@ -520,7 +540,7 @@ public fun rshift(x: UD30x9, bits: u8): UD30x9 {
520540
/// - Otherwise, the result of shifting the `x`'s raw bits right by `bits`.
521541
public fun unchecked_rshift(x: UD30x9, bits: u8): UD30x9 {
522542
if (bits >= 128) {
523-
return ud30x9::zero()
543+
return zero()
524544
};
525545
wrap(x.unwrap() >> bits)
526546
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#[test_only]
2+
module openzeppelin_fp_math::sd29x9_sqrt_tests;
3+
4+
use openzeppelin_fp_math::sd29x9;
5+
use openzeppelin_fp_math::sd29x9_base;
6+
use openzeppelin_fp_math::sd29x9_test_helpers::{pos, neg};
7+
use std::unit_test::assert_eq;
8+
9+
const SCALE: u128 = 1_000_000_000;
10+
const MAX_POSITIVE_VALUE: u128 = 0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF;
11+
12+
// ==== Tests ====
13+
14+
#[test]
15+
fun sqrt_of_zero_is_zero() {
16+
assert_eq!(sd29x9::zero().sqrt(), sd29x9::zero());
17+
}
18+
19+
#[test]
20+
fun sqrt_of_positive_one() {
21+
assert_eq!(sd29x9::one().sqrt(), sd29x9::one());
22+
}
23+
24+
#[test]
25+
fun sqrt_of_positive_perfect_squares() {
26+
// sqrt(+4.0) = +2.0
27+
assert_eq!(pos(4 * SCALE).sqrt(), pos(2 * SCALE));
28+
// sqrt(+9.0) = +3.0
29+
assert_eq!(pos(9 * SCALE).sqrt(), pos(3 * SCALE));
30+
// sqrt(+25.0) = +5.0
31+
assert_eq!(pos(25 * SCALE).sqrt(), pos(5 * SCALE));
32+
// sqrt(+100.0) = +10.0
33+
assert_eq!(pos(100 * SCALE).sqrt(), pos(10 * SCALE));
34+
// sqrt(+10000.0) = +100.0
35+
assert_eq!(pos(10_000 * SCALE).sqrt(), pos(100 * SCALE));
36+
}
37+
38+
#[test]
39+
fun sqrt_of_positive_fractional_squares() {
40+
// sqrt(+0.25) = +0.5
41+
assert_eq!(pos(250_000_000).sqrt(), pos(500_000_000));
42+
// sqrt(+0.01) = +0.1
43+
assert_eq!(pos(10_000_000).sqrt(), pos(100_000_000));
44+
// sqrt(+2.25) = +1.5
45+
assert_eq!(pos(2_250_000_000).sqrt(), pos(1_500_000_000));
46+
}
47+
48+
#[test]
49+
fun sqrt_truncates_irrational_results() {
50+
// sqrt(+2.0) = +1.414213562 (truncated)
51+
assert_eq!(pos(2 * SCALE).sqrt(), pos(1_414_213_562));
52+
// sqrt(+3.0) = +1.732050807 (truncated)
53+
assert_eq!(pos(3 * SCALE).sqrt(), pos(1_732_050_807));
54+
// sqrt(+5.0) = +2.236067977 (truncated)
55+
assert_eq!(pos(5 * SCALE).sqrt(), pos(2_236_067_977));
56+
}
57+
58+
#[test]
59+
fun sqrt_of_max_positive() {
60+
// sqrt(sd29x9::max()) should not abort and satisfy the floor property
61+
let result = sd29x9::max().sqrt();
62+
let r = result.unwrap() as u256;
63+
let max_scaled = (sd29x9::max().unwrap() as u256) * (SCALE as u256);
64+
assert!(r * r <= max_scaled);
65+
assert!((r + 1) * (r + 1) > max_scaled);
66+
}
67+
68+
#[random_test]
69+
fun sqrt_result_is_always_non_negative(raw: u128) {
70+
let raw = raw % (MAX_POSITIVE_VALUE + 1);
71+
let result = sd29x9::wrap(raw, false).sqrt();
72+
// Result is non-negative: raw bits should not have sign bit set
73+
assert!(result.unwrap() <= MAX_POSITIVE_VALUE);
74+
}
75+
76+
#[random_test]
77+
fun sqrt_floor_invariant(raw: u128) {
78+
let raw = raw % (MAX_POSITIVE_VALUE + 1);
79+
let result = sd29x9::wrap(raw, false).sqrt();
80+
// Floor property: r^2 <= x * SCALE < (r + 1)^2
81+
let r = result.unwrap() as u256;
82+
let scaled = (raw as u256) * (SCALE as u256);
83+
assert!(r * r <= scaled);
84+
assert!((r + 1) * (r + 1) > scaled);
85+
}
86+
87+
#[test]
88+
fun sqrt_squared_roundtrip_for_perfect_squares() {
89+
let values = vector[
90+
pos(4 * SCALE),
91+
pos(9 * SCALE),
92+
pos(25 * SCALE),
93+
pos(250_000_000), // 0.25
94+
];
95+
values.destroy!(|x| {
96+
let root = x.sqrt();
97+
assert_eq!(root.mul(root), x);
98+
});
99+
}
100+
101+
#[test, expected_failure(abort_code = sd29x9_base::ENegativeSqrt)]
102+
fun sqrt_of_negative_aborts() {
103+
neg(SCALE).sqrt();
104+
}
105+
106+
#[test, expected_failure(abort_code = sd29x9_base::ENegativeSqrt)]
107+
fun sqrt_of_small_negative_aborts() {
108+
neg(1).sqrt();
109+
}
110+
111+
#[test, expected_failure(abort_code = sd29x9_base::ENegativeSqrt)]
112+
fun sqrt_of_min_value_aborts() {
113+
sd29x9::min().sqrt();
114+
}

0 commit comments

Comments
 (0)