diff --git a/src/descriptor/mod.rs b/src/descriptor/mod.rs index e6e9e0b7f..9b83a227d 100644 --- a/src/descriptor/mod.rs +++ b/src/descriptor/mod.rs @@ -1671,7 +1671,7 @@ mod tests { #[test] fn tr_roundtrip_key() { let script = Tr::::from_str("tr()").unwrap().to_string(); - assert_eq!(script, format!("tr()#x4ml3kxd")) + assert_eq!(script, "tr()#x4ml3kxd".to_string()) } #[test] diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 6bfdfd055..84644915f 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -34,7 +34,9 @@ pub use self::error::{ParseNumError, ParseThresholdError, ParseTreeError}; use crate::blanket_traits::StaticDebugAndDisplay; use crate::descriptor::checksum::verify_checksum; use crate::prelude::*; -use crate::{AbsLockTime, Error, ParseError, RelLockTime, Threshold, MAX_RECURSION_DEPTH}; +use crate::{ + AbsLockTime, Error, ParseError, PositiveF64, RelLockTime, Threshold, MAX_RECURSION_DEPTH, +}; /// Allowed characters are descriptor strings. pub const INPUT_CHARSET: &str = "0123456789()[],'/*abcdefgh@:$%{}IJKLMNOPQRSTUVWXYZ&+-.;<=>?!^_|~ijklmnopqrstuvwxyzABCDEFGH`#\"\\ "; @@ -679,7 +681,7 @@ impl<'a> Tree<'a> { } /// Parse a string as a u32, forbidding zero. -pub fn parse_num_nonzero(s: &str, context: &'static str) -> Result { +fn parse_num_nonzero(s: &str, context: &'static str) -> Result { if s == "0" { return Err(ParseNumError::IllegalZero { context }); } @@ -691,6 +693,12 @@ pub fn parse_num_nonzero(s: &str, context: &'static str) -> Result Result { + let parsed = parse_num_nonzero(s, context)?; + PositiveF64::try_from(parsed).map_err(|_| ParseNumError::IllegalZero { context }) +} + /// Parse a string as a u32, for timelocks or thresholds pub fn parse_num(s: &str) -> Result { if s == "0" { diff --git a/src/lib.rs b/src/lib.rs index 2d9beaf78..ad6e5f7bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -140,6 +140,7 @@ pub use crate::miniscript::satisfy::{Preimage32, Satisfier}; pub use crate::miniscript::{hash256, Miniscript}; use crate::prelude::*; pub use crate::primitives::absolute_locktime::{AbsLockTime, AbsLockTimeError}; +pub use crate::primitives::positive_f64::PositiveF64; pub use crate::primitives::relative_locktime::{RelLockTime, RelLockTimeError}; pub use crate::primitives::threshold::{Threshold, ThresholdError}; pub use crate::validation::{Error as ValidationError, ValidationParams}; diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 4ffadb804..74359fa35 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -5,7 +5,7 @@ //! Optimizing compiler from concrete policies to Miniscript //! -use core::{cmp, f64, fmt, hash, mem}; +use core::{f64, fmt, mem}; #[cfg(feature = "std")] use std::error; @@ -16,27 +16,12 @@ use crate::miniscript::types::{self, ErrorKind, ExtData, Type}; use crate::miniscript::ScriptContext; use crate::policy::Concrete; use crate::prelude::*; -use crate::{policy, Miniscript, MiniscriptKey, Terminal}; +use crate::{policy, Miniscript, MiniscriptKey, PositiveF64, Terminal, Threshold}; -type PolicyCache = - BTreeMap<(Concrete, OrdF64, Option), BTreeMap>>; - -/// Ordered f64 for comparison. -#[derive(Copy, Clone, PartialEq, Debug)] -pub(crate) struct OrdF64(pub f64); - -impl Eq for OrdF64 {} -// We could derive PartialOrd, but we can't derive Ord, and clippy wants us -// to derive both or neither. Better to be explicit. -impl PartialOrd for OrdF64 { - fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } -} -impl Ord for OrdF64 { - fn cmp(&self, other: &Self) -> cmp::Ordering { - // will panic if given NaN - self.0.partial_cmp(&other.0).unwrap() - } -} +type PolicyCache = BTreeMap< + (Concrete, PositiveF64, Option), + BTreeMap>, +>; /// Detailed error type for compiler. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] @@ -128,11 +113,6 @@ impl From for CompilerError { fn from(e: policy::concrete::PolicyError) -> Self { Self::PolicyError(e) } } -/// Hash required for using OrdF64 as key for hashmap -impl hash::Hash for OrdF64 { - fn hash(&self, state: &mut H) { self.0.to_bits().hash(state); } -} - /// Compilation key: This represents the state of the best possible compilation /// of a given policy(implicitly keyed). #[derive(Copy, Clone, PartialEq, Eq, Debug, PartialOrd, Ord, Hash)] @@ -150,7 +130,7 @@ struct CompilationKey { /// The probability of dissatisfaction of the compilation of the policy. Note /// that all possible compilations of a (sub)policy have the same sat-prob /// and only differ in dissat_prob. - dissat_prob: Option, + dissat_prob: Option, } impl CompilationKey { @@ -163,8 +143,8 @@ impl CompilationKey { } /// Helper to create compilation key from components - fn from_type(ty: Type, expensive_verify: bool, dissat_prob: Option) -> Self { - Self { ty, expensive_verify, dissat_prob: dissat_prob.map(OrdF64) } + fn from_type(ty: Type, expensive_verify: bool, dissat_prob: Option) -> Self { + Self { ty, expensive_verify, dissat_prob } } } @@ -386,14 +366,14 @@ impl CompilerExtData { } } - fn threshold(k: usize, n: usize, mut sub_ck: S) -> Self + fn threshold(thresh: &Threshold, mut sub_ck: S) -> Self where S: FnMut(usize) -> Self, { - let k_over_n = k as f64 / n as f64; + let k_over_n = thresh.k_over_n().value(); let mut sat_cost = 0.0; let mut dissat_cost = 0.0; - for i in 0..n { + for i in 0..thresh.n() { let sub = sub_ck(i); sat_cost += sub.sat_cost; dissat_cost += sub.dissat_cost.unwrap(); @@ -497,7 +477,7 @@ impl CompilerExtData { Self::and_or(atype, btype, ctype) } Terminal::Thresh(ref thresh) => { - Self::threshold(thresh.k(), thresh.n(), |n| get_child(&thresh.data()[n].node, n)) + Self::threshold(thresh, |n| get_child(&thresh.data()[n].node, n)) } } } @@ -516,15 +496,17 @@ impl AstElemExt { /// Compute a 1-dimensional cost, given a probability of satisfaction /// and a probability of dissatisfaction; if `dissat_prob` is `None` /// then it is assumed that dissatisfaction never occurs - fn cost_1d(&self, sat_prob: f64, dissat_prob: Option) -> f64 { - self.ms.ext.pk_cost as f64 - + self.comp_ext_data.sat_cost * sat_prob - + match (dissat_prob, self.comp_ext_data.dissat_cost) { - (Some(prob), Some(cost)) => prob * cost, - (Some(_), None) => f64::INFINITY, - (None, Some(_)) => 0.0, - (None, None) => 0.0, - } + fn cost_1d(&self, sat_prob: PositiveF64, dissat_prob: Option) -> PositiveF64 { + let sat_cost = (self.comp_ext_data.sat_cost > 0.0) + .then(|| PositiveF64::new(self.comp_ext_data.sat_cost) * sat_prob); + let base = PositiveF64::new(self.ms.ext.pk_cost as f64).conditional_add(sat_cost); + let dissat_cost = match (dissat_prob, self.comp_ext_data.dissat_cost) { + (Some(_), Some(0.0)) => None, + (Some(prob), Some(cost)) => Some(prob * PositiveF64::new(cost)), + (Some(_), None) => Some(PositiveF64::INFINITY), + (None, _) => None, + }; + base.conditional_add(dissat_cost) } } @@ -668,8 +650,8 @@ fn all_casts() -> [Cast; 10] { fn insert_elem( map: &mut BTreeMap>, elem: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> bool { // We check before compiling that non-malleable satisfactions exist, and it appears that // there are no cases when malleable satisfactions beat non-malleable ones (and if there @@ -719,8 +701,8 @@ fn insert_elem( fn insert_elem_closure( map: &mut BTreeMap>, astelem_ext: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) { let mut cast_stack: VecDeque> = VecDeque::new(); if insert_elem(map, astelem_ext.clone(), sat_prob, dissat_prob) { @@ -755,8 +737,8 @@ fn insert_best_wrapped( policy: &Concrete, map: &mut BTreeMap>, data: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result<(), CompilerError> { insert_elem_closure(map, data, sat_prob, dissat_prob); @@ -779,17 +761,15 @@ fn insert_best_wrapped( fn best_compilations( policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result>, CompilerError> where Pk: MiniscriptKey, Ctx: ScriptContext, { //Check the cache for hits - let ord_sat_prob = OrdF64(sat_prob); - let ord_dissat_prob = dissat_prob.map(OrdF64); - if let Some(ret) = policy_cache.get(&(policy.clone(), ord_sat_prob, ord_dissat_prob)) { + if let Some(ret) = policy_cache.get(&(policy.clone(), sat_prob, dissat_prob)) { return Ok(ret.clone()); } @@ -807,8 +787,19 @@ where }; } macro_rules! compile_tern { - ($a:expr, $b:expr, $c: expr, $w: expr) => { - compile_tern(policy_cache, policy, &mut ret, $a, $b, $c, $w, sat_prob, dissat_prob)? + ($a:expr, $b:expr, $c: expr, [$ab_prob: expr, $c_prob: expr]) => { + compile_tern( + policy_cache, + policy, + &mut ret, + $a, + $b, + $c, + $ab_prob, + $c_prob, + sat_prob, + dissat_prob, + )? }; } @@ -849,22 +840,21 @@ where let mut q_zero_left = best_compilations(policy_cache, subs[0].as_ref(), sat_prob, None)?; - compile_binary!(&mut left, &mut right, [1.0, 1.0], Terminal::AndB); - compile_binary!(&mut right, &mut left, [1.0, 1.0], Terminal::AndB); - compile_binary!(&mut left, &mut right, [1.0, 1.0], Terminal::AndV); - compile_binary!(&mut right, &mut left, [1.0, 1.0], Terminal::AndV); + let one = PositiveF64::new(1.0); + compile_binary!(&mut left, &mut right, [one, one], Terminal::AndB); + compile_binary!(&mut right, &mut left, [one, one], Terminal::AndB); + compile_binary!(&mut left, &mut right, [one, one], Terminal::AndV); + compile_binary!(&mut right, &mut left, [one, one], Terminal::AndV); let mut zero_comp = BTreeMap::new(); zero_comp.insert( CompilationKey::from_type(Type::FALSE, ExtData::FALSE.has_free_verify, dissat_prob), AstElemExt::terminal(Miniscript::FALSE), ); - compile_tern!(&mut left, &mut q_zero_right, &mut zero_comp, [1.0, 0.0]); - compile_tern!(&mut right, &mut q_zero_left, &mut zero_comp, [1.0, 0.0]); + compile_tern!(&mut left, &mut q_zero_right, &mut zero_comp, [one, None]); + compile_tern!(&mut right, &mut q_zero_left, &mut zero_comp, [one, None]); } Concrete::Or(ref subs) => { - let total = (subs[0].0 + subs[1].0) as f64; - let lw = subs[0].0 as f64 / total; - let rw = subs[1].0 as f64 / total; + let (lw, rw) = PositiveF64::normalized(subs[0].0, subs[1].0); //and-or if let (Concrete::And(x), _) = (subs[0].1.as_ref(), subs[1].1.as_ref()) { @@ -872,7 +862,7 @@ where policy_cache, x[0].as_ref(), lw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), + Some((rw * sat_prob).conditional_add(dissat_prob)), )?; let mut a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?; @@ -880,7 +870,7 @@ where policy_cache, x[1].as_ref(), lw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), + Some((rw * sat_prob).conditional_add(dissat_prob)), )?; let mut b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?; @@ -891,15 +881,15 @@ where dissat_prob, )?; - compile_tern!(&mut a1, &mut b2, &mut c, [lw, rw]); - compile_tern!(&mut b1, &mut a2, &mut c, [lw, rw]); + compile_tern!(&mut a1, &mut b2, &mut c, [lw, Some(rw)]); + compile_tern!(&mut b1, &mut a2, &mut c, [lw, Some(rw)]); }; if let (_, Concrete::And(x)) = (&subs[0].1.as_ref(), subs[1].1.as_ref()) { let mut a1 = best_compilations( policy_cache, x[0].as_ref(), rw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), + Some((lw * sat_prob).conditional_add(dissat_prob)), )?; let mut a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?; @@ -907,7 +897,7 @@ where policy_cache, x[1].as_ref(), rw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), + Some((lw * sat_prob).conditional_add(dissat_prob)), )?; let mut b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?; @@ -918,13 +908,13 @@ where dissat_prob, )?; - compile_tern!(&mut a1, &mut b2, &mut c, [rw, lw]); - compile_tern!(&mut b1, &mut a2, &mut c, [rw, lw]); + compile_tern!(&mut a1, &mut b2, &mut c, [rw, Some(lw)]); + compile_tern!(&mut b1, &mut a2, &mut c, [rw, Some(lw)]); }; - let dissat_probs = |w: f64| -> Vec> { + let dissat_probs = |w: PositiveF64| -> Vec> { vec![ - Some(dissat_prob.unwrap_or(0 as f64) + w * sat_prob), + Some((w * sat_prob).conditional_add(dissat_prob)), Some(w * sat_prob), dissat_prob, None, @@ -973,7 +963,16 @@ where Concrete::Thresh(ref thresh) => { let k = thresh.k(); let n = thresh.n(); - let k_over_n = k as f64 / n as f64; + + let (sat_ratio, dissat_ratio) = if n > k { + let (s, d) = PositiveF64::normalized( + PositiveF64::new(k as f64), + PositiveF64::new((n - k) as f64), + ); + (s, Some(d)) + } else { + (PositiveF64::new(1.0), None) + }; let mut sub_ext_data = Vec::with_capacity(n); @@ -982,13 +981,15 @@ where let mut min_value = (0, f64::INFINITY); for (i, ast) in thresh.iter().enumerate() { - let sp = sat_prob * k_over_n; + let sp = sat_prob * sat_ratio; //Expressions must be dissatisfiable - let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob); + let dp = dissat_ratio + .map(|ratio| (ratio * sat_prob).conditional_add(dissat_prob)) + .or(dissat_prob); let be = best(types::Base::B, policy_cache, ast.as_ref(), sp, dp)?; let bw = best(types::Base::W, policy_cache, ast.as_ref(), sp, dp)?; - let diff = be.cost_1d(sp, dp) - bw.cost_1d(sp, dp); + let diff = be.cost_1d(sp, dp).value() - bw.cost_1d(sp, dp).value(); best_es.push((be.comp_ext_data, be)); best_ws.push((bw.comp_ext_data, bw)); @@ -1023,7 +1024,7 @@ where if let Ok(ms) = Miniscript::from_ast(ast) { let ast_ext = AstElemExt { ms: Arc::new(ms), - comp_ext_data: CompilerExtData::threshold(k, n, |i| sub_ext_data[i]), + comp_ext_data: CompilerExtData::threshold(thresh, |i| sub_ext_data[i]), }; insert_wrap!(ast_ext); } @@ -1065,7 +1066,7 @@ where } } for k in ret.keys() { - debug_assert_eq!(k.dissat_prob, ord_dissat_prob); + debug_assert_eq!(k.dissat_prob, dissat_prob); } if ret.is_empty() { // The only reason we are discarding elements out of compiler is because @@ -1076,7 +1077,7 @@ where // before calling this compile function Err(CompilerError::LimitsExceeded) } else { - policy_cache.insert((policy.clone(), ord_sat_prob, ord_dissat_prob), ret.clone()); + policy_cache.insert((policy.clone(), sat_prob, dissat_prob), ret.clone()); Ok(ret) } } @@ -1091,9 +1092,9 @@ fn compile_binary( ret: &mut BTreeMap>, left_comp: &mut BTreeMap>, right_comp: &mut BTreeMap>, - weights: [f64; 2], - sat_prob: f64, - dissat_prob: Option, + weights: [PositiveF64; 2], + sat_prob: PositiveF64, + dissat_prob: Option, bin_func: F, ) -> Result<(), CompilerError> where @@ -1106,8 +1107,8 @@ where for r in right_comp.values_mut() { let rref = Arc::clone(&r.ms); let ast = bin_func(Arc::clone(&lref), Arc::clone(&rref)); - l.comp_ext_data.branch_prob = Some(weights[0]); - r.comp_ext_data.branch_prob = Some(weights[1]); + l.comp_ext_data.branch_prob = Some(weights[0].value()); + r.comp_ext_data.branch_prob = Some(weights[1].value()); if let Ok(new_ext) = AstElemExt::binary(ast, l, r) { insert_best_wrapped(policy_cache, policy, ret, new_ext, sat_prob, dissat_prob)?; } @@ -1127,10 +1128,13 @@ fn compile_tern( a_comp: &mut BTreeMap>, b_comp: &mut BTreeMap>, c_comp: &mut BTreeMap>, - weights: [f64; 2], - sat_prob: f64, - dissat_prob: Option, + ab_branch_prob: PositiveF64, + c_branch_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result<(), CompilerError> { + let ab_prob = ab_branch_prob.value(); + let c_prob = c_branch_prob.map_or(0.0, |p| p.value()); for a in a_comp.values_mut() { let aref = Arc::clone(&a.ms); for b in b_comp.values_mut() { @@ -1138,9 +1142,9 @@ fn compile_tern( for c in c_comp.values_mut() { let cref = Arc::clone(&c.ms); let ast = Terminal::AndOr(Arc::clone(&aref), Arc::clone(&bref), Arc::clone(&cref)); - a.comp_ext_data.branch_prob = Some(weights[0]); - b.comp_ext_data.branch_prob = Some(weights[0]); - c.comp_ext_data.branch_prob = Some(weights[1]); + a.comp_ext_data.branch_prob = Some(ab_prob); + b.comp_ext_data.branch_prob = Some(ab_prob); + c.comp_ext_data.branch_prob = Some(c_prob); if let Ok(new_ext) = AstElemExt::ternary(ast, a, b, c) { insert_best_wrapped(policy_cache, policy, ret, new_ext, sat_prob, dissat_prob)?; } @@ -1155,7 +1159,7 @@ pub fn best_compilation( policy: &Concrete, ) -> Result, CompilerError> { let mut policy_cache = PolicyCache::::new(); - let x = &*best_t(&mut policy_cache, policy, 1.0, None)?.ms; + let x = &*best_t(&mut policy_cache, policy, PositiveF64::new(1.0), None)?.ms; if !x.ty.mall.signed { Err(CompilerError::TopLevelSigless) } else if !x.ty.mall.non_malleable { @@ -1169,8 +1173,8 @@ pub fn best_compilation( fn best_t( policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result, CompilerError> where Pk: MiniscriptKey, @@ -1178,11 +1182,9 @@ where { best_compilations(policy_cache, policy, sat_prob, dissat_prob)? .into_iter() - .filter(|&(key, _)| { - key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob.map(OrdF64) - }) + .filter(|&(key, _)| key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob) .map(|(_, val)| val) - .min_by_key(|ext| OrdF64(ext.cost_1d(sat_prob, dissat_prob))) + .min_by_key(|ext| ext.cost_1d(sat_prob, dissat_prob)) .ok_or(CompilerError::LimitsExceeded) } @@ -1191,8 +1193,8 @@ fn best( basic_type: types::Base, policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result, CompilerError> where Pk: MiniscriptKey, @@ -1204,10 +1206,10 @@ where key.ty.corr.base == basic_type && key.ty.corr.unit && val.ms.ty.mall.dissat == types::Dissat::Unique - && key.dissat_prob == dissat_prob.map(OrdF64) + && key.dissat_prob == dissat_prob }) .map(|(_, val)| val) - .min_by_key(|ext| OrdF64(ext.cost_1d(sat_prob, dissat_prob))) + .min_by_key(|ext| ext.cost_1d(sat_prob, dissat_prob)) .ok_or(CompilerError::LimitsExceeded) } @@ -1309,18 +1311,23 @@ mod tests { #[test] fn compile_q() { let policy = SPolicy::from_str("or(1@and(pk(A),pk(B)),127@pk(C))").expect("parsing"); - let compilation: TapAstElemExt = best_t(&mut BTreeMap::new(), &policy, 1.0, None).unwrap(); + let compilation: TapAstElemExt = + best_t(&mut BTreeMap::new(), &policy, PositiveF64::new(1.0), None).unwrap(); - assert_eq!(compilation.cost_1d(1.0, None), 87.0 + 67.0390625); + assert_eq!(compilation.cost_1d(PositiveF64::new(1.0), None).value(), 87.0 + 67.0390625); assert_eq!(policy.lift().unwrap().sorted(), compilation.ms.lift().unwrap().sorted()); // compile into taproot context to avoid limit errors let policy = SPolicy::from_str( "and(and(and(or(127@thresh(2,pk(A),pk(B),thresh(2,or(127@pk(A),1@pk(B)),after(100),or(and(pk(C),after(200)),and(pk(D),sha256(66687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925))),pk(E))),1@pk(F)),sha256(66687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925)),or(127@pk(G),1@after(300))),or(127@after(400),pk(H)))" ).expect("parsing"); - let compilation: TapAstElemExt = best_t(&mut BTreeMap::new(), &policy, 1.0, None).unwrap(); + let compilation: TapAstElemExt = + best_t(&mut BTreeMap::new(), &policy, PositiveF64::new(1.0), None).unwrap(); - assert_eq!(compilation.cost_1d(1.0, None), 433.0 + 275.7909749348958); + assert_eq!( + compilation.cost_1d(PositiveF64::new(1.0), None).value(), + 433.0 + 275.7909749348958 + ); assert_eq!(policy.lift().unwrap().sorted(), compilation.ms.lift().unwrap().sorted()); } @@ -1365,14 +1372,14 @@ mod tests { // Liquid policy let policy: BPolicy = Concrete::Or(vec![ ( - 127, + PositiveF64::new(127.0), Arc::new(Concrete::Thresh( Threshold::from_iter(3, key_pol[0..5].iter().map(|p| (p.clone()).into())) .unwrap(), )), ), ( - 1, + PositiveF64::new(1.0), Arc::new(Concrete::And(vec![ Arc::new(Concrete::Older(RelLockTime::from_height(10000).unwrap())), Arc::new(Concrete::Thresh( @@ -1541,8 +1548,8 @@ mod tests { .collect(); let thresh_res: Result = Concrete::Or(vec![ - (1, Arc::new(Concrete::Thresh(Threshold::and_n(keys_a)))), - (1, Arc::new(Concrete::Thresh(Threshold::and_n(keys_b)))), + (PositiveF64::new(1.0), Arc::new(Concrete::Thresh(Threshold::and_n(keys_a)))), + (PositiveF64::new(1.0), Arc::new(Concrete::Thresh(Threshold::and_n(keys_b)))), ]) .compile(); let script_size = thresh_res.clone().map(|m| m.script_size()); @@ -1612,14 +1619,21 @@ mod tests { // Test that we refuse to compile policies with duplicated keys let (keys, _) = pubkeys_and_a_sig(1); let key = Arc::new(Concrete::Key(keys[0])); - let res = - Concrete::Or(vec![(1, Arc::clone(&key)), (1, Arc::clone(&key))]).compile::(); + let res = Concrete::Or(vec![ + (PositiveF64::new(1.0), Arc::clone(&key)), + (PositiveF64::new(1.0), Arc::clone(&key)), + ]) + .compile::(); assert_eq!( res, Err(CompilerError::PolicyError(policy::concrete::PolicyError::DuplicatePubKeys)) ); // Same for legacy - let res = Concrete::Or(vec![(1, key.clone()), (1, key)]).compile::(); + let res = Concrete::Or(vec![ + (PositiveF64::new(1.0), key.clone()), + (PositiveF64::new(1.0), key), + ]) + .compile::(); assert_eq!( res, Err(CompilerError::PolicyError(policy::concrete::PolicyError::DuplicatePubKeys)) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index b7c9045bb..39703154a 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -12,7 +12,7 @@ use bitcoin::absolute; use { crate::descriptor::TapTree, crate::miniscript::ScriptContext, - crate::policy::compiler::{self, CompilerError, OrdF64}, + crate::policy::compiler::{self, CompilerError}, crate::Descriptor, crate::Miniscript, crate::Tap, @@ -28,7 +28,8 @@ use crate::sync::Arc; #[cfg(all(doc, not(feature = "compiler")))] use crate::Descriptor; use crate::{ - AbsLockTime, Error, ForEachKey, FromStrKey, MiniscriptKey, RelLockTime, Threshold, Translator, + AbsLockTime, Error, ForEachKey, FromStrKey, MiniscriptKey, PositiveF64, RelLockTime, Threshold, + Translator, }; /// Maximum `TapLeaf`s allowed in a compiled TapTree @@ -65,7 +66,7 @@ pub enum Policy { And(Vec>), /// A list of sub-policies, one of which must be satisfied, along with /// relative probabilities for each one. - Or(Vec<(usize, Arc)>), + Or(Vec<(PositiveF64, Arc)>), /// A set of descriptors, satisfactions must be provided for `k` of them. Thresh(Threshold, 0>), } @@ -131,9 +132,10 @@ impl<'p, Pk: MiniscriptKey> Iterator for TapleafProbabilityIter<'p, Pk> { match top { Policy::Or(ref subs) => { - let total_sub_prob = subs.iter().map(|prob_sub| prob_sub.0).sum::(); + let total_sub_prob = + subs.iter().map(|prob_sub| prob_sub.0.value()).sum::(); for (sub_prob, sub) in subs.iter().rev() { - let ratio = *sub_prob as f64 / total_sub_prob as f64; + let ratio = sub_prob.value() / total_sub_prob; self.stack.push((top_prob * ratio, sub)); } } @@ -199,7 +201,7 @@ impl Policy { let internal_key = self .tapleaf_probability_iter() .filter_map(|(prob, ref pol)| match pol { - Self::Key(pk) => Some((OrdF64(prob), pk)), + Self::Key(pk) => Some((PositiveF64::new(prob), pk)), _ => None, }) .max_by_key(|(prob, _)| *prob) @@ -244,14 +246,15 @@ impl Policy { match policy { Self::Trivial => None, policy => { - let mut leaf_compilations: Vec<(OrdF64, Miniscript)> = vec![]; + let mut leaf_compilations: Vec<(PositiveF64, Miniscript)> = + vec![]; for (prob, pol) in policy.tapleaf_probability_iter() { // policy corresponding to the key (replaced by unsatisfiable) is skipped if *pol == Self::Unsatisfiable { continue; } let compilation = compiler::best_compilation::(pol)?; - leaf_compilations.push((OrdF64(prob), compilation)); + leaf_compilations.push((PositiveF64::new(prob), compilation)); } if !leaf_compilations.is_empty() { let tap_tree = with_huffman_tree::(leaf_compilations); @@ -303,7 +306,7 @@ impl Policy { if n > max_leaves { return Err(CompilerError::TooManyTapleaves { n, max: max_leaves }); } - let mut leaf_compilations: Vec<(OrdF64, Miniscript)> = vec![]; + let mut leaf_compilations: Vec<(PositiveF64, Miniscript)> = vec![]; for (leaf_idx, (prob, pol)) in leaves.iter().enumerate() { if **pol == Self::Unsatisfiable { continue; @@ -314,7 +317,7 @@ impl Policy { leaf_index: leaf_idx, }); } - leaf_compilations.push((OrdF64(*prob), compilation)); + leaf_compilations.push((PositiveF64::new(*prob), compilation)); } if !leaf_compilations.is_empty() { Some(with_huffman_tree::(leaf_compilations)) @@ -371,7 +374,7 @@ impl Policy { .filter(|x| x.1 != Arc::new(Self::Unsatisfiable)) .map(|(prob, pol)| { ( - OrdF64(prob), + PositiveF64::new(prob), compiler::best_compilation(pol.as_ref()).unwrap(), ) }) @@ -454,9 +457,9 @@ impl Policy { fn enumerate_pol(&self, prob: f64) -> Vec<(f64, Arc)> { match self { Self::Or(subs) => { - let total_odds = subs.iter().fold(0, |acc, x| acc + x.0); + let total_odds = subs.iter().fold(0.0, |acc, x| acc + x.0.value()); subs.iter() - .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) + .map(|(odds, pol)| (prob * odds.value() / total_odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { @@ -479,9 +482,9 @@ impl Policy { fn enumerate_pol_native(&self, prob: f64) -> Vec<(f64, Arc)> { match self { Self::Or(subs) => { - let total_odds = subs.iter().fold(0, |acc, x| acc + x.0); + let total_odds = subs.iter().fold(0.0, |acc, x| acc + x.0.value()); subs.iter() - .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) + .map(|(odds, pol)| (prob * odds.value() / total_odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { @@ -541,16 +544,16 @@ impl Policy { max_leaves: usize, expand_fn: fn(&Self, f64) -> Vec<(f64, Arc)>, ) -> Vec<(f64, Arc)> { - let mut tapleaf_prob_vec = BTreeSet::<(Reverse, Arc)>::new(); + let mut tapleaf_prob_vec = BTreeSet::<(Reverse, Arc)>::new(); // Store probability corresponding to policy in the enumerated tree. This is required since // owing to the current [policy element enumeration algorithm][`Policy::enumerate_pol`], // two passes of the algorithm might result in same sub-policy showing up. Currently, we // merge the nodes by adding up the corresponding probabilities for the same policy. - let mut pol_prob_map = BTreeMap::, OrdF64>::new(); + let mut pol_prob_map = BTreeMap::, PositiveF64>::new(); let arc_self = Arc::new(self); - tapleaf_prob_vec.insert((Reverse(OrdF64(prob)), Arc::clone(&arc_self))); - pol_prob_map.insert(Arc::clone(&arc_self), OrdF64(prob)); + tapleaf_prob_vec.insert((Reverse(PositiveF64::new(prob)), Arc::clone(&arc_self))); + pol_prob_map.insert(Arc::clone(&arc_self), PositiveF64::new(prob)); // Since we know that policy enumeration *must* result in increase in total number of nodes, // we can maintain the length of the ordered set to check if the @@ -565,7 +568,7 @@ impl Policy { // Stopping condition: When NONE of the inputs can be further enumerated. 'outer: loop { //--- FIND a plausible node --- - let mut prob: Reverse = Reverse(OrdF64(0.0)); + let mut prob: Option> = None; let mut curr_policy: Arc = Arc::new(Self::Unsatisfiable); let mut curr_pol_replace_vec: Vec<(f64, Arc)> = vec![]; let mut no_more_enum = false; @@ -574,13 +577,13 @@ impl Policy { // from the ordered set. let mut to_del: Vec<(f64, Arc)> = vec![]; 'inner: for (i, (p, pol)) in tapleaf_prob_vec.iter().enumerate() { - curr_pol_replace_vec = expand_fn(pol, p.0 .0); + curr_pol_replace_vec = expand_fn(pol, p.0.value()); enum_len += curr_pol_replace_vec.len() - 1; // A disjunctive node should have separated this into more nodes assert!(prev_len <= enum_len); if prev_len < enum_len { // Plausible node found - prob = *p; + prob = Some(*p); curr_policy = Arc::clone(pol); break 'inner; } else if i == tapleaf_prob_vec.len() - 1 { @@ -591,14 +594,14 @@ impl Policy { // Either node is enumerable, or we have // Mark all non-enumerable nodes to remove, // if not returning value in the current iteration. - to_del.push((p.0 .0, Arc::clone(pol))); + to_del.push((p.0.value(), Arc::clone(pol))); } } // --- Sanity Checks --- if enum_len > max_leaves || no_more_enum { for (p, pol) in tapleaf_prob_vec.into_iter() { - ret.push((p.0 .0, pol)); + ret.push((p.0.value(), pol)); } break 'outer; } @@ -607,12 +610,13 @@ impl Policy { // with children nodes // Remove current node - assert!(tapleaf_prob_vec.remove(&(prob, curr_policy.clone()))); + assert!(tapleaf_prob_vec + .remove(&(prob.expect("a plausible node was found"), curr_policy.clone()))); pol_prob_map.remove(&curr_policy); // OPTIMIZATION - Move marked nodes into final vector for (p, pol) in to_del { - assert!(tapleaf_prob_vec.remove(&(Reverse(OrdF64(p)), pol.clone()))); + assert!(tapleaf_prob_vec.remove(&(Reverse(PositiveF64::new(p)), pol.clone()))); pol_prob_map.remove(&pol); ret.push((p, pol.clone())); } @@ -622,12 +626,16 @@ impl Policy { match pol_prob_map.get(&policy) { Some(prev_prob) => { assert!(tapleaf_prob_vec.remove(&(Reverse(*prev_prob), policy.clone()))); - tapleaf_prob_vec.insert((Reverse(OrdF64(prev_prob.0 + p)), policy.clone())); - pol_prob_map.insert(policy.clone(), OrdF64(prev_prob.0 + p)); + tapleaf_prob_vec.insert(( + Reverse(PositiveF64::new(prev_prob.value() + p)), + policy.clone(), + )); + pol_prob_map + .insert(policy.clone(), PositiveF64::new(prev_prob.value() + p)); } None => { - tapleaf_prob_vec.insert((Reverse(OrdF64(p)), policy.clone())); - pol_prob_map.insert(policy.clone(), OrdF64(p)); + tapleaf_prob_vec.insert((Reverse(PositiveF64::new(p)), policy.clone())); + pol_prob_map.insert(policy.clone(), PositiveF64::new(p)); } } } @@ -966,7 +974,7 @@ impl expression::FromTree for Policy { .map_err(From::from) .map_err(Error::Parse)?; - let mut stack = Vec::<(usize, _)>::with_capacity(128); + let mut stack = Vec::<(PositiveF64, _)>::with_capacity(128); for node in root.pre_order_iter().rev() { let allow_prob; // Before doing anything else, check if this is the inner value of a terminal. @@ -999,10 +1007,10 @@ impl expression::FromTree for Policy { }; let frag_prob = match frag_prob { - None => 1, - Some(s) => expression::parse_num_nonzero(s, "fragment probability") + None => PositiveF64::new(1.0), + Some(s) => expression::parse_probability(s, "fragment probability") .map_err(From::from) - .map_err(Error::Parse)? as usize, + .map_err(Error::Parse)?, }; let new = @@ -1090,8 +1098,10 @@ fn has_if_fragment(ms: &Miniscript) -> bool { /// Creates a Huffman Tree from compiled [`Miniscript`] nodes. #[cfg(feature = "compiler")] -fn with_huffman_tree(ms: Vec<(OrdF64, Miniscript)>) -> TapTree { - let mut node_weights = BinaryHeap::<(Reverse, TapTree)>::new(); +fn with_huffman_tree( + ms: Vec<(PositiveF64, Miniscript)>, +) -> TapTree { + let mut node_weights = BinaryHeap::<(Reverse, TapTree)>::new(); for (prob, script) in ms { node_weights.push((Reverse(prob), TapTree::leaf(script))); } @@ -1100,9 +1110,9 @@ fn with_huffman_tree(ms: Vec<(OrdF64, Miniscript)>) let (p1, s1) = node_weights.pop().expect("len must at least be two"); let (p2, s2) = node_weights.pop().expect("len must at least be two"); - let p = (p1.0).0 + (p2.0).0; + let p = (p1.0).value() + (p2.0).value(); node_weights.push(( - Reverse(OrdF64(p)), + Reverse(PositiveF64::new(p)), TapTree::combine(s1, s2) .expect("huffman tree cannot produce depth > 128 given sane weights"), )); @@ -1155,7 +1165,7 @@ pub enum TreeChildren<'a, Pk: MiniscriptKey> { /// A conjunction or threshold node's children. And(&'a [Arc>]), /// A disjunction node's children. - Or(&'a [(usize, Arc>)]), + Or(&'a [(PositiveF64, Arc>)]), } impl<'a, Pk: MiniscriptKey> TreeLike for &'a Policy { diff --git a/src/primitives/mod.rs b/src/primitives/mod.rs index 919e6c8ba..d7afefdcb 100644 --- a/src/primitives/mod.rs +++ b/src/primitives/mod.rs @@ -12,5 +12,6 @@ //! should be re-exported at the crate root. pub mod absolute_locktime; +pub mod positive_f64; pub mod relative_locktime; pub mod threshold; diff --git a/src/primitives/positive_f64.rs b/src/primitives/positive_f64.rs new file mode 100644 index 000000000..46acbacca --- /dev/null +++ b/src/primitives/positive_f64.rs @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Positive Floating-Point Numbers +//! +//! A wrapper around [`f64`] used to represent branch probabilities and +//! execution costs within the policy compiler. +//! +//! Values are guaranteed to be positive floats and never NaN (but may be +//! infinite), ensuring that the type can safely implement [`Ord`] and [`Eq`] +//! without panicking. + +use core::{cmp, f64, fmt, hash, ops}; + +/// A positive, possibly-infinite, floating-point number. +#[derive(Copy, Clone, PartialEq, Debug)] +pub struct PositiveF64(pub(super) f64); + +impl PositiveF64 { + /// Positive Infinity, used as a sentinel for impossible compilation branches. + pub const INFINITY: Self = Self(f64::INFINITY); + + /// Creates a new `PositiveF64` from the given value. + /// + /// # Panics + /// + /// Panics if `value` is not a positive-floating point number. + pub fn new(value: f64) -> Self { + assert!(value > 0.0, "PositiveF64 must be positive and not NaN, got {value}"); + Self(value) + } + + /// Returns the `PositiveF64` value. + pub fn value(&self) -> f64 { self.0 } + + /// Normalizes two [`PositiveF64`] values into a valid probability distribution. + #[must_use] + pub fn normalized(a: Self, b: Self) -> (Self, Self) { + let sum = a.0 + b.0; + (Self(a.0 / sum), Self(b.0 / sum)) + } + + /// Adds an optional value to `self`, returning `self` unchanged if the value is + /// `None`. + #[must_use] + pub fn conditional_add(self, other: Option) -> Self { other.map_or(self, |i| self + i) } +} + +impl TryFrom for PositiveF64 { + type Error = NonZeroExpected; + + fn try_from(value: u32) -> Result { + if value == 0 { + Err(NonZeroExpected) + } else { + Ok(Self(value as f64)) + } + } +} + +impl Eq for PositiveF64 {} + +// We could derive PartialOrd, but we can't derive Ord, and clippy wants us +// to derive both or neither. Better to be explicit. +impl PartialOrd for PositiveF64 { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } +} + +impl Ord for PositiveF64 { + fn cmp(&self, other: &Self) -> cmp::Ordering { + // will panic if given NaN + self.0.partial_cmp(&other.0).unwrap() + } +} + +/// Hash required for using OrdF64 as key for hashmap +impl hash::Hash for PositiveF64 { + fn hash(&self, state: &mut H) { self.0.to_bits().hash(state); } +} + +impl fmt::Display for PositiveF64 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.0, f) } +} + +impl ops::Add for PositiveF64 { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { Self(self.0 + rhs.0) } +} + +impl ops::Mul for PositiveF64 { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { Self(self.0 * rhs.0) } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub struct NonZeroExpected; + +impl fmt::Display for NonZeroExpected { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "value must be non-zero") } +} + +#[cfg(feature = "std")] +impl std::error::Error for NonZeroExpected {} diff --git a/src/primitives/threshold.rs b/src/primitives/threshold.rs index a29f89e30..5d552f447 100644 --- a/src/primitives/threshold.rs +++ b/src/primitives/threshold.rs @@ -10,6 +10,7 @@ use core::{cmp, fmt, iter}; #[cfg(any(feature = "std", test))] use std::vec; +use super::positive_f64::PositiveF64; use crate::ToPublicKey; /// Error parsing an absolute locktime. @@ -203,6 +204,9 @@ impl Threshold { /// Accessor for the threshold value. pub const fn k(&self) -> usize { self.k } + /// Returns the threshold ratio (k / n) as a positive real number. + pub fn k_over_n(&self) -> PositiveF64 { PositiveF64(self.k() as f64 / self.n() as f64) } + /// Accessor for the underlying data. pub fn data(&self) -> &[T] { &self.inner }