diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 6bfdfd055..21a479526 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -27,6 +27,7 @@ mod error; +use core::num::NonZeroU32; use core::ops; use core::str::FromStr; @@ -679,7 +680,7 @@ impl<'a> Tree<'a> { } /// Parse a string as a u32, forbidding zero. -pub fn parse_num_nonzero(s: &str, context: &'static str) -> Result { +pub fn parse_num_nonzero(s: &str, context: &'static str) -> Result { if s == "0" { return Err(ParseNumError::IllegalZero { context }); } @@ -688,7 +689,7 @@ pub fn parse_num_nonzero(s: &str, context: &'static str) -> Result Result { // Special-case 0 since it is the only number which may start with a leading zero. return Ok(0); } - parse_num_nonzero(s, "") + parse_num_nonzero(s, "").map(u32::from) } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 2d9beaf78..ad6e5f7bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -140,6 +140,7 @@ pub use crate::miniscript::satisfy::{Preimage32, Satisfier}; pub use crate::miniscript::{hash256, Miniscript}; use crate::prelude::*; pub use crate::primitives::absolute_locktime::{AbsLockTime, AbsLockTimeError}; +pub use crate::primitives::positive_f64::PositiveF64; pub use crate::primitives::relative_locktime::{RelLockTime, RelLockTimeError}; pub use crate::primitives::threshold::{Threshold, ThresholdError}; pub use crate::validation::{Error as ValidationError, ValidationParams}; diff --git a/src/policy/compiler.rs b/src/policy/compiler.rs index 4ffadb804..ca9338464 100644 --- a/src/policy/compiler.rs +++ b/src/policy/compiler.rs @@ -5,39 +5,25 @@ //! Optimizing compiler from concrete policies to Miniscript //! -use core::{cmp, f64, fmt, hash, mem}; +use core::num::NonZeroU32; +use core::{f64, fmt, mem}; #[cfg(feature = "std")] use std::error; use sync::Arc; use crate::miniscript::context::SigType; -use crate::miniscript::types::{self, ErrorKind, ExtData, Type}; +use crate::miniscript::limits::{MAX_PUBKEYS_IN_CHECKSIGADD, MAX_PUBKEYS_PER_MULTISIG}; +use crate::miniscript::types::{self, ErrorKind, Type}; use crate::miniscript::ScriptContext; use crate::policy::Concrete; use crate::prelude::*; -use crate::{policy, Miniscript, MiniscriptKey, Terminal}; - -type PolicyCache = - BTreeMap<(Concrete, OrdF64, Option), BTreeMap>>; - -/// Ordered f64 for comparison. -#[derive(Copy, Clone, PartialEq, Debug)] -pub(crate) struct OrdF64(pub f64); - -impl Eq for OrdF64 {} -// We could derive PartialOrd, but we can't derive Ord, and clippy wants us -// to derive both or neither. Better to be explicit. -impl PartialOrd for OrdF64 { - fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } -} -impl Ord for OrdF64 { - fn cmp(&self, other: &Self) -> cmp::Ordering { - // will panic if given NaN - self.0.partial_cmp(&other.0).unwrap() - } -} +use crate::{policy, Miniscript, MiniscriptKey, PositiveF64, Terminal}; +type PolicyCache = BTreeMap< + (Concrete, PositiveF64, Option), + BTreeMap>, +>; /// Detailed error type for compiler. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] pub enum CompilerError { @@ -104,6 +90,142 @@ impl fmt::Display for CompilerError { } } +fn best_compilations_or( + ret: &mut BTreeMap>, + policy_cache: &mut PolicyCache, + policy: &Concrete, + subs: &[(NonZeroU32, Arc>)], + sat_prob: PositiveF64, + dissat_prob: Option, +) -> Result<(), CompilerError> { + let total = PositiveF64::from(subs[0].0) + PositiveF64::from(subs[1].0); + let lw = PositiveF64::from(subs[0].0) / total; + let rw = PositiveF64::from(subs[1].0) / total; + + //and-or + let mut insert_ternary = |policy_cache: &mut _, + a: &BTreeMap<_, _>, + b: &BTreeMap<_, _>, + c: &BTreeMap<_, _>, + lw: PositiveF64, + rw: PositiveF64| + -> Result<(), CompilerError> { + for a in a.values() { + for b in b.values() { + for c in c.values() { + if let Ok(new_ext) = AstElemExt::and_or(a, b, c, lw, rw) { + insert_best_wrapped( + policy_cache, + policy, + ret, + new_ext, + sat_prob, + dissat_prob, + )?; + } + } + } + } + Ok(()) + }; + + if let (Concrete::And(x), _) = (subs[0].1.as_ref(), subs[1].1.as_ref()) { + let a1 = best_compilations( + policy_cache, + x[0].as_ref(), + lw * sat_prob, + Some((rw * sat_prob).conditional_add(dissat_prob)), + )?; + let a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?; + + let b1 = best_compilations( + policy_cache, + x[1].as_ref(), + lw * sat_prob, + Some((rw * sat_prob).conditional_add(dissat_prob)), + )?; + let b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?; + + let c = best_compilations(policy_cache, subs[1].1.as_ref(), rw * sat_prob, dissat_prob)?; + + insert_ternary(policy_cache, &a1, &b2, &c, lw, rw)?; + insert_ternary(policy_cache, &b1, &a2, &c, lw, rw)?; + }; + if let (_, Concrete::And(x)) = (&subs[0].1.as_ref(), subs[1].1.as_ref()) { + let a1 = best_compilations( + policy_cache, + x[0].as_ref(), + rw * sat_prob, + Some((lw * sat_prob).conditional_add(dissat_prob)), + )?; + let a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?; + + let b1 = best_compilations( + policy_cache, + x[1].as_ref(), + rw * sat_prob, + Some((lw * sat_prob).conditional_add(dissat_prob)), + )?; + let b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?; + + let c = best_compilations(policy_cache, subs[0].1.as_ref(), lw * sat_prob, dissat_prob)?; + + insert_ternary(policy_cache, &a1, &b2, &c, rw, lw)?; + insert_ternary(policy_cache, &b1, &a2, &c, rw, lw)?; + }; + + let dissat_probs = |w: PositiveF64| -> Vec> { + vec![ + Some((w * sat_prob).conditional_add(dissat_prob)), + Some(w * sat_prob), + dissat_prob, + None, + ] + }; + + let mut l_comp = vec![]; + let mut r_comp = vec![]; + + for dissat_prob in dissat_probs(rw).iter() { + let l = best_compilations(policy_cache, subs[0].1.as_ref(), lw * sat_prob, *dissat_prob)?; + l_comp.push(l); + } + + for dissat_prob in dissat_probs(lw).iter() { + let r = best_compilations(policy_cache, subs[1].1.as_ref(), rw * sat_prob, *dissat_prob)?; + r_comp.push(r); + } + + let mut insert_binary = |left: &BTreeMap<_, _>, + right: &BTreeMap<_, _>, + lw: PositiveF64, + rw: PositiveF64, + combinator: fn(&_, &_, _, _) -> Result<_, _>| + -> Result<(), CompilerError> { + for l in left.values() { + for r in right.values() { + if let Ok(new_ext) = combinator(l, r, lw, rw) { + insert_best_wrapped(policy_cache, policy, ret, new_ext, sat_prob, dissat_prob)?; + } + } + } + Ok(()) + }; + + insert_binary(&l_comp[0], &r_comp[0], lw, rw, AstElemExt::or_b)?; + insert_binary(&r_comp[0], &l_comp[0], rw, lw, AstElemExt::or_b)?; + insert_binary(&l_comp[0], &r_comp[2], lw, rw, AstElemExt::or_d)?; + insert_binary(&r_comp[0], &l_comp[2], rw, lw, AstElemExt::or_d)?; + insert_binary(&l_comp[1], &r_comp[3], lw, rw, AstElemExt::or_c)?; + insert_binary(&r_comp[1], &l_comp[3], rw, lw, AstElemExt::or_c)?; + insert_binary(&l_comp[2], &r_comp[3], lw, rw, AstElemExt::or_i)?; + insert_binary(&r_comp[3], &l_comp[2], rw, lw, AstElemExt::or_i)?; + insert_binary(&l_comp[3], &r_comp[2], lw, rw, AstElemExt::or_i)?; + insert_binary(&r_comp[2], &l_comp[3], rw, lw, AstElemExt::or_i)?; + + Ok(()) +} + #[cfg(feature = "std")] impl error::Error for CompilerError { fn cause(&self) -> Option<&dyn error::Error> { @@ -128,11 +250,6 @@ impl From for CompilerError { fn from(e: policy::concrete::PolicyError) -> Self { Self::PolicyError(e) } } -/// Hash required for using OrdF64 as key for hashmap -impl hash::Hash for OrdF64 { - fn hash(&self, state: &mut H) { self.0.to_bits().hash(state); } -} - /// Compilation key: This represents the state of the best possible compilation /// of a given policy(implicitly keyed). #[derive(Copy, Clone, PartialEq, Eq, Debug, PartialOrd, Ord, Hash)] @@ -150,7 +267,7 @@ struct CompilationKey { /// The probability of dissatisfaction of the compilation of the policy. Note /// that all possible compilations of a (sub)policy have the same sat-prob /// and only differ in dissat_prob. - dissat_prob: Option, + dissat_prob: Option, } impl CompilationKey { @@ -163,17 +280,13 @@ impl CompilationKey { } /// Helper to create compilation key from components - fn from_type(ty: Type, expensive_verify: bool, dissat_prob: Option) -> Self { - Self { ty, expensive_verify, dissat_prob: dissat_prob.map(OrdF64) } + fn from_type(ty: Type, expensive_verify: bool, dissat_prob: Option) -> Self { + Self { ty, expensive_verify, dissat_prob } } } #[derive(Copy, Clone, Debug)] struct CompilerExtData { - /// If this node is the direct child of a disjunction, this field must - /// have the probability of its branch being taken. Otherwise it is ignored. - /// All functions initialize it to `None`. - branch_prob: Option, /// The number of bytes needed to satisfy the fragment in segwit format /// (total length of all witness pushes, plus their own length prefixes) sat_cost: f64, @@ -184,13 +297,12 @@ struct CompilerExtData { } impl CompilerExtData { - const TRUE: Self = Self { branch_prob: None, sat_cost: 0.0, dissat_cost: None }; + const TRUE: Self = Self { sat_cost: 0.0, dissat_cost: None }; - const FALSE: Self = Self { branch_prob: None, sat_cost: f64::MAX, dissat_cost: Some(0.0) }; + const FALSE: Self = Self { sat_cost: f64::MAX, dissat_cost: Some(0.0) }; fn pk_k() -> Self { Self { - branch_prob: None, sat_cost: match Ctx::sig_type() { SigType::Ecdsa => 73.0, SigType::Schnorr => 1.0 /* */ + 64.0 /* sig */ + 1.0, /* */ @@ -201,7 +313,6 @@ impl CompilerExtData { fn pk_h() -> Self { Self { - branch_prob: None, sat_cost: match Ctx::sig_type() { SigType::Ecdsa => 73.0 + 34.0, SigType::Schnorr => 66.0 + 33.0, @@ -215,79 +326,47 @@ impl CompilerExtData { } } - fn multi(k: usize, _n: usize) -> Self { - Self { - branch_prob: None, - sat_cost: 1.0 + 73.0 * k as f64, - dissat_cost: Some(1.0 * (k + 1) as f64), - } - } - - fn sortedmulti(k: usize, _n: usize) -> Self { - Self { - branch_prob: None, - sat_cost: 1.0 + 73.0 * k as f64, - dissat_cost: Some(1.0 * (k + 1) as f64), - } + fn multi(k: usize) -> Self { + Self { sat_cost: 1.0 + 73.0 * k as f64, dissat_cost: Some(1.0 * (k + 1) as f64) } } fn multi_a(k: usize, n: usize) -> Self { Self { - branch_prob: None, sat_cost: 66.0 * k as f64 + (n - k) as f64, dissat_cost: Some(n as f64), /* ... := 0x00 ... 0x00 (n times) */ } } - fn sortedmulti_a(k: usize, n: usize) -> Self { Self::multi_a(k, n) } + fn hash() -> Self { Self { sat_cost: 33.0, dissat_cost: Some(33.0) } } - fn hash() -> Self { Self { branch_prob: None, sat_cost: 33.0, dissat_cost: Some(33.0) } } + fn time() -> Self { Self { sat_cost: 0.0, dissat_cost: None } } - fn time() -> Self { Self { branch_prob: None, sat_cost: 0.0, dissat_cost: None } } + fn cast_alt(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } } - fn cast_alt(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } - } - - fn cast_swap(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } - } + fn cast_swap(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } } - fn cast_check(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } - } + fn cast_check(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } } - fn cast_dupif(self) -> Self { - Self { branch_prob: None, sat_cost: 2.0 + self.sat_cost, dissat_cost: Some(1.0) } - } + fn cast_dupif(self) -> Self { Self { sat_cost: 2.0 + self.sat_cost, dissat_cost: Some(1.0) } } - fn cast_verify(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: None } - } + fn cast_verify(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: None } } - fn cast_nonzero(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: Some(1.0) } - } + fn cast_nonzero(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: Some(1.0) } } fn cast_zeronotequal(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } + Self { sat_cost: self.sat_cost, dissat_cost: self.dissat_cost } } - fn cast_true(self) -> Self { - Self { branch_prob: None, sat_cost: self.sat_cost, dissat_cost: None } - } + fn cast_true(self) -> Self { Self { sat_cost: self.sat_cost, dissat_cost: None } } fn cast_unlikely(self) -> Self { - Self { branch_prob: None, sat_cost: 2.0 + self.sat_cost, dissat_cost: Some(1.0) } + Self { sat_cost: 2.0 + self.sat_cost, dissat_cost: Some(1.0) } } - fn cast_likely(self) -> Self { - Self { branch_prob: None, sat_cost: 1.0 + self.sat_cost, dissat_cost: Some(2.0) } - } + fn cast_likely(self) -> Self { Self { sat_cost: 1.0 + self.sat_cost, dissat_cost: Some(2.0) } } fn and_b(left: Self, right: Self) -> Self { Self { - branch_prob: None, sat_cost: left.sat_cost + right.sat_cost, dissat_cost: match (left.dissat_cost, right.dissat_cost) { (Some(l), Some(r)) => Some(l + r), @@ -297,63 +376,41 @@ impl CompilerExtData { } fn and_v(left: Self, right: Self) -> Self { - Self { branch_prob: None, sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } + Self { sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } } - fn or_b(l: Self, r: Self) -> Self { - let lprob = l - .branch_prob - .expect("BUG: left branch prob must be set for disjunctions"); - let rprob = r - .branch_prob - .expect("BUG: right branch prob must be set for disjunctions"); + fn and_n(left: Self, right: Self) -> Self { + Self { sat_cost: left.sat_cost + right.sat_cost, dissat_cost: None } + } + + fn or_b(l: Self, r: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { Self { - branch_prob: None, - sat_cost: lprob * (l.sat_cost + r.dissat_cost.unwrap()) - + rprob * (r.sat_cost + l.dissat_cost.unwrap()), + sat_cost: f64::from(lprob) * (l.sat_cost + r.dissat_cost.unwrap()) + + f64::from(rprob) * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: Some(l.dissat_cost.unwrap() + r.dissat_cost.unwrap()), } } - fn or_d(l: Self, r: Self) -> Self { - let lprob = l - .branch_prob - .expect("BUG: left branch prob must be set for disjunctions"); - let rprob = r - .branch_prob - .expect("BUG: right branch prob must be set for disjunctions"); + fn or_d(l: Self, r: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { Self { - branch_prob: None, - sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()), + sat_cost: f64::from(lprob) * l.sat_cost + + f64::from(rprob) * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: r.dissat_cost.map(|rd| l.dissat_cost.unwrap() + rd), } } - fn or_c(l: Self, r: Self) -> Self { - let lprob = l - .branch_prob - .expect("BUG: left branch prob must be set for disjunctions"); - let rprob = r - .branch_prob - .expect("BUG: right branch prob must be set for disjunctions"); + fn or_c(l: Self, r: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { Self { - branch_prob: None, - sat_cost: lprob * l.sat_cost + rprob * (r.sat_cost + l.dissat_cost.unwrap()), + sat_cost: f64::from(lprob) * l.sat_cost + + f64::from(rprob) * (r.sat_cost + l.dissat_cost.unwrap()), dissat_cost: None, } } #[allow(clippy::manual_map)] // Complex if/let is better as is. - fn or_i(l: Self, r: Self) -> Self { - let lprob = l - .branch_prob - .expect("BUG: left branch prob must be set for disjunctions"); - let rprob = r - .branch_prob - .expect("BUG: right branch prob must be set for disjunctions"); + fn or_i(l: Self, r: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { Self { - branch_prob: None, - sat_cost: lprob * (2.0 + l.sat_cost) + rprob * (1.0 + r.sat_cost), + sat_cost: f64::from(lprob) * (2.0 + l.sat_cost) + f64::from(rprob) * (1.0 + r.sat_cost), dissat_cost: if let (Some(ldis), Some(rdis)) = (l.dissat_cost, r.dissat_cost) { if (2.0 + ldis) > (1.0 + rdis) { Some(1.0 + rdis) @@ -370,18 +427,13 @@ impl CompilerExtData { } } - fn and_or(a: Self, b: Self, c: Self) -> Self { - let aprob = a.branch_prob.expect("andor, a prob must be set"); - let bprob = b.branch_prob.expect("andor, b prob must be set"); - let cprob = c.branch_prob.expect("andor, c prob must be set"); - + fn and_or(a: Self, b: Self, c: Self, lprob: PositiveF64, rprob: PositiveF64) -> Self { let adis = a .dissat_cost .expect("BUG: and_or first arg(a) must be dissatisfiable"); - debug_assert_eq!(aprob, bprob); //A and B must have same branch prob. Self { - branch_prob: None, - sat_cost: aprob * (a.sat_cost + b.sat_cost) + cprob * (adis + c.sat_cost), + sat_cost: f64::from(lprob) * (a.sat_cost + b.sat_cost) + + f64::from(rprob) * (adis + c.sat_cost), dissat_cost: c.dissat_cost.map(|cdis| adis + cdis), } } @@ -399,110 +451,12 @@ impl CompilerExtData { dissat_cost += sub.dissat_cost.unwrap(); } Self { - branch_prob: None, sat_cost: sat_cost * k_over_n + dissat_cost * (1.0 - k_over_n), dissat_cost: Some(dissat_cost), } } } -impl CompilerExtData { - /// Compute the type of a fragment, given a function to look up - /// the types of its children. - fn type_check_with_child(fragment: &Terminal, child: C) -> Self - where - C: Fn(usize) -> Self, - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - let get_child = |_sub, n| child(n); - Self::type_check_common(fragment, get_child) - } - - /// Compute the type of a fragment. - fn type_check(fragment: &Terminal) -> Self - where - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - let check_child = |sub, _n| Self::type_check(sub); - Self::type_check_common(fragment, check_child) - } - - /// Compute the type of a fragment, given a function to look up - /// the types of its children, if available and relevant for the - /// given fragment - fn type_check_common<'a, Pk, Ctx, C>(fragment: &'a Terminal, get_child: C) -> Self - where - C: Fn(&'a Terminal, usize) -> Self, - Pk: MiniscriptKey, - Ctx: ScriptContext, - { - match *fragment { - Terminal::True => Self::TRUE, - Terminal::False => Self::FALSE, - Terminal::PkK(..) => Self::pk_k::(), - Terminal::PkH(..) | Terminal::RawPkH(..) => Self::pk_h::(), - Terminal::Multi(ref thresh) => Self::multi(thresh.k(), thresh.n()), - Terminal::SortedMulti(ref thresh) => Self::sortedmulti(thresh.k(), thresh.n()), - Terminal::MultiA(ref thresh) => Self::multi_a(thresh.k(), thresh.n()), - Terminal::SortedMultiA(ref thresh) => Self::sortedmulti_a(thresh.k(), thresh.n()), - Terminal::After(_) => Self::time(), - Terminal::Older(_) => Self::time(), - Terminal::Sha256(..) => Self::hash(), - Terminal::Hash256(..) => Self::hash(), - Terminal::Ripemd160(..) => Self::hash(), - Terminal::Hash160(..) => Self::hash(), - Terminal::Alt(ref sub) => Self::cast_alt(get_child(&sub.node, 0)), - Terminal::Swap(ref sub) => Self::cast_swap(get_child(&sub.node, 0)), - Terminal::Check(ref sub) => Self::cast_check(get_child(&sub.node, 0)), - Terminal::DupIf(ref sub) => Self::cast_dupif(get_child(&sub.node, 0)), - Terminal::Verify(ref sub) => Self::cast_verify(get_child(&sub.node, 0)), - Terminal::NonZero(ref sub) => Self::cast_nonzero(get_child(&sub.node, 0)), - Terminal::ZeroNotEqual(ref sub) => Self::cast_zeronotequal(get_child(&sub.node, 0)), - Terminal::AndB(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::and_b(ltype, rtype) - } - Terminal::AndV(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::and_v(ltype, rtype) - } - Terminal::OrB(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::or_b(ltype, rtype) - } - Terminal::OrD(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::or_d(ltype, rtype) - } - Terminal::OrC(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::or_c(ltype, rtype) - } - Terminal::OrI(ref l, ref r) => { - let ltype = get_child(&l.node, 0); - let rtype = get_child(&r.node, 1); - Self::or_i(ltype, rtype) - } - Terminal::AndOr(ref a, ref b, ref c) => { - let atype = get_child(&a.node, 0); - let btype = get_child(&b.node, 1); - let ctype = get_child(&c.node, 2); - Self::and_or(atype, btype, ctype) - } - Terminal::Thresh(ref thresh) => { - Self::threshold(thresh.k(), thresh.n(), |n| get_child(&thresh.data()[n].node, n)) - } - } - } -} - /// Miniscript AST fragment with additional data needed by the compiler #[derive(Clone, Debug)] struct AstElemExt { @@ -516,11 +470,11 @@ impl AstElemExt { /// Compute a 1-dimensional cost, given a probability of satisfaction /// and a probability of dissatisfaction; if `dissat_prob` is `None` /// then it is assumed that dissatisfaction never occurs - fn cost_1d(&self, sat_prob: f64, dissat_prob: Option) -> f64 { + fn cost_1d(&self, sat_prob: PositiveF64, dissat_prob: Option) -> f64 { self.ms.ext.pk_cost as f64 - + self.comp_ext_data.sat_cost * sat_prob + + self.comp_ext_data.sat_cost * f64::from(sat_prob) + match (dissat_prob, self.comp_ext_data.dissat_cost) { - (Some(prob), Some(cost)) => prob * cost, + (Some(prob), Some(cost)) => f64::from(prob) * cost, (Some(_), None) => f64::INFINITY, (None, Some(_)) => 0.0, (None, None) => 0.0, @@ -529,42 +483,211 @@ impl AstElemExt { } impl AstElemExt { - fn terminal(ms: Miniscript) -> Self { - Self { comp_ext_data: CompilerExtData::type_check(ms.as_inner()), ms: Arc::new(ms) } + fn unsatisfiable() -> Self { + Self { ms: Arc::new(Miniscript::FALSE), comp_ext_data: CompilerExtData::FALSE } } - fn binary(ast: Terminal, l: &Self, r: &Self) -> Result { - let lookup_ext = |n| match n { - 0 => l.comp_ext_data, - 1 => r.comp_ext_data, - _ => unreachable!(), - }; - //Types and ExtData are already cached and stored in children. So, we can - //type_check without cache. For Compiler extra data, we supply a cache. - let ty = types::Type::type_check(&ast)?; - let ext = types::ExtData::type_check(&ast); - let comp_ext_data = CompilerExtData::type_check_with_child(&ast, lookup_ext); + fn trivial() -> Self { + Self { ms: Arc::new(Miniscript::TRUE), comp_ext_data: CompilerExtData::TRUE } + } + + fn pk_h(key: Pk) -> Self { + Self { + ms: Arc::new(Miniscript::pk_h(key)), + comp_ext_data: CompilerExtData::pk_h::(), + } + } + + fn pk_k(key: Pk) -> Self { + Self { + ms: Arc::new(Miniscript::pk_k(key)), + comp_ext_data: CompilerExtData::pk_k::(), + } + } + + fn after(t: crate::AbsLockTime) -> Self { + Self { ms: Arc::new(Miniscript::after(t)), comp_ext_data: CompilerExtData::time() } + } + + fn older(t: crate::RelLockTime) -> Self { + Self { ms: Arc::new(Miniscript::older(t)), comp_ext_data: CompilerExtData::time() } + } + + fn sha256(h: Pk::Sha256) -> Self { + Self { ms: Arc::new(Miniscript::sha256(h)), comp_ext_data: CompilerExtData::hash() } + } + + fn hash256(h: Pk::Hash256) -> Self { + Self { ms: Arc::new(Miniscript::hash256(h)), comp_ext_data: CompilerExtData::hash() } + } + + fn ripemd160(h: Pk::Ripemd160) -> Self { + Self { ms: Arc::new(Miniscript::ripemd160(h)), comp_ext_data: CompilerExtData::hash() } + } + + fn hash160(h: Pk::Hash160) -> Self { + Self { ms: Arc::new(Miniscript::hash160(h)), comp_ext_data: CompilerExtData::hash() } + } + + fn multi(thresh: crate::Threshold) -> Self { + let k = thresh.k(); + Self { + ms: Arc::new(Miniscript::multi(thresh)), + comp_ext_data: CompilerExtData::multi(k), + } + } + + fn multi_a(thresh: crate::Threshold) -> Self { + let k = thresh.k(); + let n = thresh.n(); + Self { + ms: Arc::new(Miniscript::multi_a(thresh)), + comp_ext_data: CompilerExtData::multi_a(k, n), + } + } + + /// Helper functions to compose two Miniscript fragments, where we assume + /// by construction that all validation parameters are upheld. + fn compose_typeck_only( + term: Terminal, + ) -> Result>, types::Error> { + let ty = types::Type::type_check(&term)?; + let ext = types::ExtData::type_check(&term); + Ok(Arc::new(Miniscript::from_components_unchecked(term, ty, ext))) + } + + fn and_b(left: &Self, right: &Self) -> Result { Ok(Self { - ms: Arc::new(Miniscript::from_components_unchecked(ast, ty, ext)), - comp_ext_data, + ms: Self::compose_typeck_only(Terminal::AndB( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::and_b(left.comp_ext_data, right.comp_ext_data), }) } - fn ternary(ast: Terminal, a: &Self, b: &Self, c: &Self) -> Result { - let lookup_ext = |n| match n { - 0 => a.comp_ext_data, - 1 => b.comp_ext_data, - 2 => c.comp_ext_data, - _ => unreachable!(), - }; - //Types and ExtData are already cached and stored in children. So, we can - //type_check without cache. For Compiler extra data, we supply a cache. - let ty = types::Type::type_check(&ast)?; - let ext = types::ExtData::type_check(&ast); - let comp_ext_data = CompilerExtData::type_check_with_child(&ast, lookup_ext); + fn and_v(left: &Self, right: &Self) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::AndV( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::and_v(left.comp_ext_data, right.comp_ext_data), + }) + } + + /// and_n(a,b) == andor(a,b,0) is a conjunction of a and b + fn and_n(left: &Self, right: &Self) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::AndOr( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + Arc::new(Miniscript::FALSE), + ))?, + comp_ext_data: CompilerExtData::and_n(left.comp_ext_data, right.comp_ext_data), + }) + } + + fn and_or( + a: &Self, + b: &Self, + c: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::AndOr( + Arc::clone(&a.ms), + Arc::clone(&b.ms), + Arc::clone(&c.ms), + ))?, + comp_ext_data: CompilerExtData::and_or( + a.comp_ext_data, + b.comp_ext_data, + c.comp_ext_data, + l_weight, + r_weight, + ), + }) + } + + fn or_b( + left: &Self, + right: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::OrB( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::or_b( + left.comp_ext_data, + right.comp_ext_data, + l_weight, + r_weight, + ), + }) + } + + fn or_d( + left: &Self, + right: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::OrD( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::or_d( + left.comp_ext_data, + right.comp_ext_data, + l_weight, + r_weight, + ), + }) + } + + fn or_c( + left: &Self, + right: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { + Ok(Self { + ms: Self::compose_typeck_only(Terminal::OrC( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::or_c( + left.comp_ext_data, + right.comp_ext_data, + l_weight, + r_weight, + ), + }) + } + + fn or_i( + left: &Self, + right: &Self, + l_weight: PositiveF64, + r_weight: PositiveF64, + ) -> Result { Ok(Self { - ms: Arc::new(Miniscript::from_components_unchecked(ast, ty, ext)), - comp_ext_data, + ms: Self::compose_typeck_only(Terminal::OrI( + Arc::clone(&left.ms), + Arc::clone(&right.ms), + ))?, + comp_ext_data: CompilerExtData::or_i( + left.comp_ext_data, + right.comp_ext_data, + l_weight, + r_weight, + ), }) } } @@ -668,8 +791,8 @@ fn all_casts() -> [Cast; 10] { fn insert_elem( map: &mut BTreeMap>, elem: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> bool { // We check before compiling that non-malleable satisfactions exist, and it appears that // there are no cases when malleable satisfactions beat non-malleable ones (and if there @@ -719,8 +842,8 @@ fn insert_elem( fn insert_elem_closure( map: &mut BTreeMap>, astelem_ext: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) { let mut cast_stack: VecDeque> = VecDeque::new(); if insert_elem(map, astelem_ext.clone(), sat_prob, dissat_prob) { @@ -755,8 +878,8 @@ fn insert_best_wrapped( policy: &Concrete, map: &mut BTreeMap>, data: AstElemExt, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result<(), CompilerError> { insert_elem_closure(map, data, sat_prob, dissat_prob); @@ -779,17 +902,15 @@ fn insert_best_wrapped( fn best_compilations( policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result>, CompilerError> where Pk: MiniscriptKey, Ctx: ScriptContext, { //Check the cache for hits - let ord_sat_prob = OrdF64(sat_prob); - let ord_dissat_prob = dissat_prob.map(OrdF64); - if let Some(ret) = policy_cache.get(&(policy.clone(), ord_sat_prob, ord_dissat_prob)) { + if let Some(ret) = policy_cache.get(&(policy.clone(), sat_prob, dissat_prob)) { return Ok(ret.clone()); } @@ -801,179 +922,70 @@ where insert_best_wrapped(policy_cache, policy, &mut ret, $x, sat_prob, dissat_prob)? }; } - macro_rules! compile_binary { - ($l:expr, $r:expr, $w: expr, $f: expr) => { - compile_binary(policy_cache, policy, &mut ret, $l, $r, $w, sat_prob, dissat_prob, $f)? - }; - } - macro_rules! compile_tern { - ($a:expr, $b:expr, $c: expr, $w: expr) => { - compile_tern(policy_cache, policy, &mut ret, $a, $b, $c, $w, sat_prob, dissat_prob)? - }; - } match *policy { Concrete::Unsatisfiable => { - insert_wrap!(AstElemExt::terminal(Miniscript::FALSE)); + insert_wrap!(AstElemExt::unsatisfiable()); } Concrete::Trivial => { - insert_wrap!(AstElemExt::terminal(Miniscript::TRUE)); + insert_wrap!(AstElemExt::trivial()); } Concrete::Key(ref pk) => { - insert_wrap!(AstElemExt::terminal(Miniscript::pk_h(pk.clone()))); - insert_wrap!(AstElemExt::terminal(Miniscript::pk_k(pk.clone()))); - } - Concrete::After(n) => insert_wrap!(AstElemExt::terminal(Miniscript::after(n))), - Concrete::Older(n) => insert_wrap!(AstElemExt::terminal(Miniscript::older(n))), - Concrete::Sha256(ref hash) => { - insert_wrap!(AstElemExt::terminal(Miniscript::sha256(hash.clone()))) + insert_wrap!(AstElemExt::pk_h(pk.clone())); + insert_wrap!(AstElemExt::pk_k(pk.clone())); } + Concrete::After(n) => insert_wrap!(AstElemExt::after(n)), + Concrete::Older(n) => insert_wrap!(AstElemExt::older(n)), + Concrete::Sha256(ref hash) => insert_wrap!(AstElemExt::sha256(hash.clone())), // Satisfaction-cost + script-cost - Concrete::Hash256(ref hash) => { - insert_wrap!(AstElemExt::terminal(Miniscript::hash256(hash.clone()))) - } - Concrete::Ripemd160(ref hash) => { - insert_wrap!(AstElemExt::terminal(Miniscript::ripemd160(hash.clone()))) - } - Concrete::Hash160(ref hash) => { - insert_wrap!(AstElemExt::terminal(Miniscript::hash160(hash.clone()))) - } + Concrete::Hash256(ref hash) => insert_wrap!(AstElemExt::hash256(hash.clone())), + Concrete::Ripemd160(ref hash) => insert_wrap!(AstElemExt::ripemd160(hash.clone())), + Concrete::Hash160(ref hash) => insert_wrap!(AstElemExt::hash160(hash.clone())), Concrete::And(ref subs) => { assert_eq!(subs.len(), 2, "and takes 2 args"); - let mut left = - best_compilations(policy_cache, subs[0].as_ref(), sat_prob, dissat_prob)?; - let mut right = - best_compilations(policy_cache, subs[1].as_ref(), sat_prob, dissat_prob)?; - let mut q_zero_right = - best_compilations(policy_cache, subs[1].as_ref(), sat_prob, None)?; - let mut q_zero_left = - best_compilations(policy_cache, subs[0].as_ref(), sat_prob, None)?; - - compile_binary!(&mut left, &mut right, [1.0, 1.0], Terminal::AndB); - compile_binary!(&mut right, &mut left, [1.0, 1.0], Terminal::AndB); - compile_binary!(&mut left, &mut right, [1.0, 1.0], Terminal::AndV); - compile_binary!(&mut right, &mut left, [1.0, 1.0], Terminal::AndV); - let mut zero_comp = BTreeMap::new(); - zero_comp.insert( - CompilationKey::from_type(Type::FALSE, ExtData::FALSE.has_free_verify, dissat_prob), - AstElemExt::terminal(Miniscript::FALSE), - ); - compile_tern!(&mut left, &mut q_zero_right, &mut zero_comp, [1.0, 0.0]); - compile_tern!(&mut right, &mut q_zero_left, &mut zero_comp, [1.0, 0.0]); + let left = best_compilations(policy_cache, subs[0].as_ref(), sat_prob, dissat_prob)?; + let right = best_compilations(policy_cache, subs[1].as_ref(), sat_prob, dissat_prob)?; + let q_zero_right = best_compilations(policy_cache, subs[1].as_ref(), sat_prob, None)?; + let q_zero_left = best_compilations(policy_cache, subs[0].as_ref(), sat_prob, None)?; + + let mut insert_binary = |left: &BTreeMap<_, _>, + right: &BTreeMap<_, _>, + combinator: fn(&_, &_) -> Result<_, _>| + -> Result<(), CompilerError> { + for l in left.values() { + for r in right.values() { + if let Ok(new_ext) = combinator(l, r) { + insert_best_wrapped( + policy_cache, + policy, + &mut ret, + new_ext, + sat_prob, + dissat_prob, + )?; + } + } + } + Ok(()) + }; + insert_binary(&left, &right, AstElemExt::and_b)?; + // Do a separate loop with 'l' and 'r' swapped; we could combine the loops, + // but this would sometimes result in compiling e.g. and(pk(A),pk(B)) into + // an and with A and B swapped, which is surprising to the user since the + // cost is the same with or without the swap. + insert_binary(&right, &left, AstElemExt::and_b)?; + insert_binary(&left, &right, AstElemExt::and_v)?; + insert_binary(&right, &left, AstElemExt::and_v)?; + insert_binary(&left, &q_zero_right, AstElemExt::and_n)?; + insert_binary(&right, &q_zero_left, AstElemExt::and_n)?; } Concrete::Or(ref subs) => { - let total = (subs[0].0 + subs[1].0) as f64; - let lw = subs[0].0 as f64 / total; - let rw = subs[1].0 as f64 / total; - - //and-or - if let (Concrete::And(x), _) = (subs[0].1.as_ref(), subs[1].1.as_ref()) { - let mut a1 = best_compilations( - policy_cache, - x[0].as_ref(), - lw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), - )?; - let mut a2 = best_compilations(policy_cache, x[0].as_ref(), lw * sat_prob, None)?; - - let mut b1 = best_compilations( - policy_cache, - x[1].as_ref(), - lw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + rw * sat_prob), - )?; - let mut b2 = best_compilations(policy_cache, x[1].as_ref(), lw * sat_prob, None)?; - - let mut c = best_compilations( - policy_cache, - subs[1].1.as_ref(), - rw * sat_prob, - dissat_prob, - )?; - - compile_tern!(&mut a1, &mut b2, &mut c, [lw, rw]); - compile_tern!(&mut b1, &mut a2, &mut c, [lw, rw]); - }; - if let (_, Concrete::And(x)) = (&subs[0].1.as_ref(), subs[1].1.as_ref()) { - let mut a1 = best_compilations( - policy_cache, - x[0].as_ref(), - rw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), - )?; - let mut a2 = best_compilations(policy_cache, x[0].as_ref(), rw * sat_prob, None)?; - - let mut b1 = best_compilations( - policy_cache, - x[1].as_ref(), - rw * sat_prob, - Some(dissat_prob.unwrap_or(0 as f64) + lw * sat_prob), - )?; - let mut b2 = best_compilations(policy_cache, x[1].as_ref(), rw * sat_prob, None)?; - - let mut c = best_compilations( - policy_cache, - subs[0].1.as_ref(), - lw * sat_prob, - dissat_prob, - )?; - - compile_tern!(&mut a1, &mut b2, &mut c, [rw, lw]); - compile_tern!(&mut b1, &mut a2, &mut c, [rw, lw]); - }; - - let dissat_probs = |w: f64| -> Vec> { - vec![ - Some(dissat_prob.unwrap_or(0 as f64) + w * sat_prob), - Some(w * sat_prob), - dissat_prob, - None, - ] - }; - - let mut l_comp = vec![]; - let mut r_comp = vec![]; - - for dissat_prob in dissat_probs(rw).iter() { - let l = best_compilations( - policy_cache, - subs[0].1.as_ref(), - lw * sat_prob, - *dissat_prob, - )?; - l_comp.push(l); - } - - for dissat_prob in dissat_probs(lw).iter() { - let r = best_compilations( - policy_cache, - subs[1].1.as_ref(), - rw * sat_prob, - *dissat_prob, - )?; - r_comp.push(r); - } - - // or(sha256, pk) - compile_binary!(&mut l_comp[0], &mut r_comp[0], [lw, rw], Terminal::OrB); - compile_binary!(&mut r_comp[0], &mut l_comp[0], [rw, lw], Terminal::OrB); - - compile_binary!(&mut l_comp[0], &mut r_comp[2], [lw, rw], Terminal::OrD); - compile_binary!(&mut r_comp[0], &mut l_comp[2], [rw, lw], Terminal::OrD); - - compile_binary!(&mut l_comp[1], &mut r_comp[3], [lw, rw], Terminal::OrC); - compile_binary!(&mut r_comp[1], &mut l_comp[3], [rw, lw], Terminal::OrC); - - compile_binary!(&mut l_comp[2], &mut r_comp[3], [lw, rw], Terminal::OrI); - compile_binary!(&mut r_comp[2], &mut l_comp[3], [rw, lw], Terminal::OrI); - - compile_binary!(&mut l_comp[3], &mut r_comp[2], [lw, rw], Terminal::OrI); - compile_binary!(&mut r_comp[3], &mut l_comp[2], [rw, lw], Terminal::OrI); + best_compilations_or(&mut ret, policy_cache, policy, subs, sat_prob, dissat_prob)?; } Concrete::Thresh(ref thresh) => { let k = thresh.k(); let n = thresh.n(); - let k_over_n = k as f64 / n as f64; + let k_over_n = PositiveF64::k_over_n(thresh); let mut sub_ext_data = Vec::with_capacity(n); @@ -981,10 +993,20 @@ where let mut best_ws = Vec::with_capacity(n); let mut min_value = (0, f64::INFINITY); + + let total_sat_prob = sat_prob * k_over_n; + // This match can be written in terms of nested conditional_adds() but seems less clear that way. + let total_dissat_prob = match (dissat_prob, PositiveF64::one_minus_k_over_n(thresh)) { + (Some(dp), Some(kn)) => Some(dp + kn * sat_prob), + (Some(dp), None) => Some(dp), + (None, Some(kn)) => Some(kn * sat_prob), + (None, None) => None, + }; + for (i, ast) in thresh.iter().enumerate() { - let sp = sat_prob * k_over_n; - //Expressions must be dissatisfiable - let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob); + let sp = total_sat_prob; + let dp = total_dissat_prob; + let be = best(types::Base::B, policy_cache, ast.as_ref(), sp, dp)?; let bw = best(types::Base::W, policy_cache, ast.as_ref(), sp, dp)?; @@ -1043,12 +1065,12 @@ where match Ctx::sig_type() { SigType::Schnorr => { if let Ok(pk_thresh) = pk_thresh.set_maximum() { - insert_wrap!(AstElemExt::terminal(Miniscript::multi_a(pk_thresh))) + insert_wrap!(AstElemExt::multi_a(pk_thresh)) } } SigType::Ecdsa => { if let Ok(pk_thresh) = pk_thresh.set_maximum() { - insert_wrap!(AstElemExt::terminal(Miniscript::multi(pk_thresh))) + insert_wrap!(AstElemExt::multi(pk_thresh)) } } } @@ -1065,7 +1087,7 @@ where } } for k in ret.keys() { - debug_assert_eq!(k.dissat_prob, ord_dissat_prob); + debug_assert_eq!(k.dissat_prob, dissat_prob); } if ret.is_empty() { // The only reason we are discarding elements out of compiler is because @@ -1076,86 +1098,17 @@ where // before calling this compile function Err(CompilerError::LimitsExceeded) } else { - policy_cache.insert((policy.clone(), ord_sat_prob, ord_dissat_prob), ret.clone()); + policy_cache.insert((policy.clone(), sat_prob, dissat_prob), ret.clone()); Ok(ret) } } -/// Helper function to compile different types of binary fragments. -/// `sat_prob` and `dissat_prob` represent the sat and dissat probabilities of -/// root or. `weights` represent the odds for taking each sub branch -#[allow(clippy::too_many_arguments)] -fn compile_binary( - policy_cache: &mut PolicyCache, - policy: &Concrete, - ret: &mut BTreeMap>, - left_comp: &mut BTreeMap>, - right_comp: &mut BTreeMap>, - weights: [f64; 2], - sat_prob: f64, - dissat_prob: Option, - bin_func: F, -) -> Result<(), CompilerError> -where - Pk: MiniscriptKey, - Ctx: ScriptContext, - F: Fn(Arc>, Arc>) -> Terminal, -{ - for l in left_comp.values_mut() { - let lref = Arc::clone(&l.ms); - for r in right_comp.values_mut() { - let rref = Arc::clone(&r.ms); - let ast = bin_func(Arc::clone(&lref), Arc::clone(&rref)); - l.comp_ext_data.branch_prob = Some(weights[0]); - r.comp_ext_data.branch_prob = Some(weights[1]); - if let Ok(new_ext) = AstElemExt::binary(ast, l, r) { - insert_best_wrapped(policy_cache, policy, ret, new_ext, sat_prob, dissat_prob)?; - } - } - } - Ok(()) -} - -/// Helper function to compile different order of and_or fragments. -/// `sat_prob` and `dissat_prob` represent the sat and dissat probabilities of -/// root and_or node. `weights` represent the odds for taking each sub branch -#[allow(clippy::too_many_arguments)] -fn compile_tern( - policy_cache: &mut PolicyCache, - policy: &Concrete, - ret: &mut BTreeMap>, - a_comp: &mut BTreeMap>, - b_comp: &mut BTreeMap>, - c_comp: &mut BTreeMap>, - weights: [f64; 2], - sat_prob: f64, - dissat_prob: Option, -) -> Result<(), CompilerError> { - for a in a_comp.values_mut() { - let aref = Arc::clone(&a.ms); - for b in b_comp.values_mut() { - let bref = Arc::clone(&b.ms); - for c in c_comp.values_mut() { - let cref = Arc::clone(&c.ms); - let ast = Terminal::AndOr(Arc::clone(&aref), Arc::clone(&bref), Arc::clone(&cref)); - a.comp_ext_data.branch_prob = Some(weights[0]); - b.comp_ext_data.branch_prob = Some(weights[0]); - c.comp_ext_data.branch_prob = Some(weights[1]); - if let Ok(new_ext) = AstElemExt::ternary(ast, a, b, c) { - insert_best_wrapped(policy_cache, policy, ret, new_ext, sat_prob, dissat_prob)?; - } - } - } - } - Ok(()) -} - /// Obtain the best compilation of for p=1.0 and q=0 pub fn best_compilation( policy: &Concrete, ) -> Result, CompilerError> { let mut policy_cache = PolicyCache::::new(); - let x = &*best_t(&mut policy_cache, policy, 1.0, None)?.ms; + let x = &*best_t(&mut policy_cache, policy, PositiveF64::ONE, None)?.ms; if !x.ty.mall.signed { Err(CompilerError::TopLevelSigless) } else if !x.ty.mall.non_malleable { @@ -1169,8 +1122,8 @@ pub fn best_compilation( fn best_t( policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result, CompilerError> where Pk: MiniscriptKey, @@ -1178,11 +1131,9 @@ where { best_compilations(policy_cache, policy, sat_prob, dissat_prob)? .into_iter() - .filter(|&(key, _)| { - key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob.map(OrdF64) - }) + .filter(|&(key, _)| key.ty.corr.base == types::Base::B && key.dissat_prob == dissat_prob) .map(|(_, val)| val) - .min_by_key(|ext| OrdF64(ext.cost_1d(sat_prob, dissat_prob))) + .min_by_key(|ext| PositiveF64::new(ext.cost_1d(sat_prob, dissat_prob))) .ok_or(CompilerError::LimitsExceeded) } @@ -1191,8 +1142,8 @@ fn best( basic_type: types::Base, policy_cache: &mut PolicyCache, policy: &Concrete, - sat_prob: f64, - dissat_prob: Option, + sat_prob: PositiveF64, + dissat_prob: Option, ) -> Result, CompilerError> where Pk: MiniscriptKey, @@ -1204,15 +1155,16 @@ where key.ty.corr.base == basic_type && key.ty.corr.unit && val.ms.ty.mall.dissat == types::Dissat::Unique - && key.dissat_prob == dissat_prob.map(OrdF64) + && key.dissat_prob == dissat_prob }) .map(|(_, val)| val) - .min_by_key(|ext| OrdF64(ext.cost_1d(sat_prob, dissat_prob))) + .min_by_key(|ext| PositiveF64::new(ext.cost_1d(sat_prob, dissat_prob))) .ok_or(CompilerError::LimitsExceeded) } #[cfg(test)] mod tests { + use core::num::NonZeroU32; use core::str::FromStr; use bitcoin::blockdata::{opcodes, script}; @@ -1228,6 +1180,9 @@ mod tests { type TapAstElemExt = policy::compiler::AstElemExt; type SegwitMiniScript = Miniscript; + #[allow(unsafe_code)] + const ONE: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(1) }; // can be NonZeroU32::MIN in 1.70 + fn pubkeys_and_a_sig(n: usize) -> (Vec, secp256k1::ecdsa::Signature) { let mut ret = Vec::with_capacity(n); let secp = secp256k1::Secp256k1::new(); @@ -1309,18 +1264,20 @@ mod tests { #[test] fn compile_q() { let policy = SPolicy::from_str("or(1@and(pk(A),pk(B)),127@pk(C))").expect("parsing"); - let compilation: TapAstElemExt = best_t(&mut BTreeMap::new(), &policy, 1.0, None).unwrap(); + let compilation: TapAstElemExt = + best_t(&mut BTreeMap::new(), &policy, PositiveF64::ONE, None).unwrap(); - assert_eq!(compilation.cost_1d(1.0, None), 87.0 + 67.0390625); + assert_eq!(compilation.cost_1d(PositiveF64::ONE, None), 87.0 + 67.0390625); assert_eq!(policy.lift().unwrap().sorted(), compilation.ms.lift().unwrap().sorted()); // compile into taproot context to avoid limit errors let policy = SPolicy::from_str( "and(and(and(or(127@thresh(2,pk(A),pk(B),thresh(2,or(127@pk(A),1@pk(B)),after(100),or(and(pk(C),after(200)),and(pk(D),sha256(66687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925))),pk(E))),1@pk(F)),sha256(66687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925)),or(127@pk(G),1@after(300))),or(127@after(400),pk(H)))" ).expect("parsing"); - let compilation: TapAstElemExt = best_t(&mut BTreeMap::new(), &policy, 1.0, None).unwrap(); + let compilation: TapAstElemExt = + best_t(&mut BTreeMap::new(), &policy, PositiveF64::ONE, None).unwrap(); - assert_eq!(compilation.cost_1d(1.0, None), 433.0 + 275.7909749348958); + assert_eq!(compilation.cost_1d(PositiveF64::ONE, None), 433.0 + 275.7909749348958); assert_eq!(policy.lift().unwrap().sorted(), compilation.ms.lift().unwrap().sorted()); } @@ -1365,14 +1322,14 @@ mod tests { // Liquid policy let policy: BPolicy = Concrete::Or(vec![ ( - 127, + NonZeroU32::new(127).unwrap(), Arc::new(Concrete::Thresh( Threshold::from_iter(3, key_pol[0..5].iter().map(|p| (p.clone()).into())) .unwrap(), )), ), ( - 1, + NonZeroU32::new(1).unwrap(), Arc::new(Concrete::And(vec![ Arc::new(Concrete::Older(RelLockTime::from_height(10000).unwrap())), Arc::new(Concrete::Thresh( @@ -1541,8 +1498,8 @@ mod tests { .collect(); let thresh_res: Result = Concrete::Or(vec![ - (1, Arc::new(Concrete::Thresh(Threshold::and_n(keys_a)))), - (1, Arc::new(Concrete::Thresh(Threshold::and_n(keys_b)))), + (ONE, Arc::new(Concrete::Thresh(Threshold::and_n(keys_a)))), + (ONE, Arc::new(Concrete::Thresh(Threshold::and_n(keys_b)))), ]) .compile(); let script_size = thresh_res.clone().map(|m| m.script_size()); @@ -1612,14 +1569,14 @@ mod tests { // Test that we refuse to compile policies with duplicated keys let (keys, _) = pubkeys_and_a_sig(1); let key = Arc::new(Concrete::Key(keys[0])); - let res = - Concrete::Or(vec![(1, Arc::clone(&key)), (1, Arc::clone(&key))]).compile::(); + let res = Concrete::Or(vec![(ONE, Arc::clone(&key)), (ONE, Arc::clone(&key))]) + .compile::(); assert_eq!( res, Err(CompilerError::PolicyError(policy::concrete::PolicyError::DuplicatePubKeys)) ); // Same for legacy - let res = Concrete::Or(vec![(1, key.clone()), (1, key)]).compile::(); + let res = Concrete::Or(vec![(ONE, key.clone()), (ONE, key)]).compile::(); assert_eq!( res, Err(CompilerError::PolicyError(policy::concrete::PolicyError::DuplicatePubKeys)) diff --git a/src/policy/concrete.rs b/src/policy/concrete.rs index 3e2220df5..dffab644f 100644 --- a/src/policy/concrete.rs +++ b/src/policy/concrete.rs @@ -3,6 +3,7 @@ //! Concrete Policies //! +use core::num::NonZeroU32; use core::{cmp, fmt, str}; #[cfg(feature = "std")] use std::error; @@ -12,7 +13,7 @@ use bitcoin::absolute; use { crate::descriptor::TapTree, crate::miniscript::ScriptContext, - crate::policy::compiler::{self, CompilerError, OrdF64}, + crate::policy::compiler::{self, CompilerError}, crate::Descriptor, crate::Miniscript, crate::Tap, @@ -27,6 +28,8 @@ use crate::prelude::*; use crate::sync::Arc; #[cfg(all(doc, not(feature = "compiler")))] use crate::Descriptor; +#[cfg(feature = "compiler")] +use crate::PositiveF64; use crate::{ AbsLockTime, Error, ForEachKey, FromStrKey, MiniscriptKey, RelLockTime, Threshold, Translator, }; @@ -65,7 +68,7 @@ pub enum Policy { And(Vec>), /// A list of sub-policies, one of which must be satisfied, along with /// relative probabilities for each one. - Or(Vec<(usize, Arc)>), + Or(Vec<(NonZeroU32, Arc)>), /// A set of descriptors, satisfactions must be provided for `k` of them. Thresh(Threshold, 0>), } @@ -165,12 +168,12 @@ impl error::Error for PolicyError { #[cfg(feature = "compiler")] struct TapleafProbabilityIter<'p, Pk: MiniscriptKey> { - stack: Vec<(f64, &'p Policy)>, + stack: Vec<(PositiveF64, &'p Policy)>, } #[cfg(feature = "compiler")] impl<'p, Pk: MiniscriptKey> Iterator for TapleafProbabilityIter<'p, Pk> { - type Item = (f64, &'p Policy); + type Item = (PositiveF64, &'p Policy); fn next(&mut self) -> Option { loop { @@ -178,14 +181,14 @@ impl<'p, Pk: MiniscriptKey> Iterator for TapleafProbabilityIter<'p, Pk> { match top { Policy::Or(ref subs) => { - let total_sub_prob = subs.iter().map(|prob_sub| prob_sub.0).sum::(); - for (sub_prob, sub) in subs.iter().rev() { - let ratio = *sub_prob as f64 / total_sub_prob as f64; + let normalized_iter = + PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); + for (ratio, (_, sub)) in normalized_iter.zip(subs.iter()).rev() { self.stack.push((top_prob * ratio, sub)); } } Policy::Thresh(ref thresh) if thresh.is_or() => { - let n64 = thresh.n() as f64; + let n64 = PositiveF64::n(thresh); for sub in thresh.iter().rev() { self.stack.push((top_prob / n64, sub)); } @@ -237,7 +240,7 @@ impl Policy { /// leaf-nodes to [`MAX_COMPILATION_LEAVES`]. #[cfg(feature = "compiler")] fn tapleaf_probability_iter(&self) -> TapleafProbabilityIter<'_, Pk> { - TapleafProbabilityIter { stack: vec![(1.0, self)] } + TapleafProbabilityIter { stack: vec![(PositiveF64::ONE, self)] } } /// Extracts the internal_key from this policy tree. @@ -246,7 +249,7 @@ impl Policy { let internal_key = self .tapleaf_probability_iter() .filter_map(|(prob, ref pol)| match pol { - Self::Key(pk) => Some((OrdF64(prob), pk)), + Self::Key(pk) => Some((prob, pk)), _ => None, }) .max_by_key(|(prob, _)| *prob) @@ -291,14 +294,15 @@ impl Policy { match policy { Self::Trivial => None, policy => { - let mut leaf_compilations: Vec<(OrdF64, Miniscript)> = vec![]; + let mut leaf_compilations: Vec<(PositiveF64, Miniscript)> = + vec![]; for (prob, pol) in policy.tapleaf_probability_iter() { // policy corresponding to the key (replaced by unsatisfiable) is skipped if *pol == Self::Unsatisfiable { continue; } let compilation = compiler::best_compilation::(pol)?; - leaf_compilations.push((OrdF64(prob), compilation)); + leaf_compilations.push((prob, compilation)); } if !leaf_compilations.is_empty() { let tap_tree = with_huffman_tree::(leaf_compilations); @@ -344,13 +348,16 @@ impl Policy { let tap_tree = match policy { Self::Trivial => None, policy => { - let leaves = - policy.enumerate_leaves(1.0, max_leaves, Self::enumerate_pol_native); + let leaves = policy.enumerate_leaves( + PositiveF64::ONE, + max_leaves, + Self::enumerate_pol_native, + ); let n = leaves.len(); if n > max_leaves { return Err(CompilerError::TooManyTapleaves { n, max: max_leaves }); } - let mut leaf_compilations: Vec<(OrdF64, Miniscript)> = vec![]; + let mut leaf_compilations: Vec<(PositiveF64, Miniscript)> = vec![]; for (leaf_idx, (prob, pol)) in leaves.iter().enumerate() { if **pol == Self::Unsatisfiable { continue; @@ -361,7 +368,7 @@ impl Policy { leaf_index: leaf_idx, }); } - leaf_compilations.push((OrdF64(*prob), compilation)); + leaf_compilations.push((*prob, compilation)); } if !leaf_compilations.is_empty() { Some(with_huffman_tree::(leaf_compilations)) @@ -413,14 +420,11 @@ impl Policy { Self::Trivial => None, policy => { let leaf_compilations: Vec<_> = policy - .enumerate_policy_tree(1.0) + .enumerate_policy_tree(PositiveF64::ONE) .into_iter() .filter(|x| x.1 != Arc::new(Self::Unsatisfiable)) .map(|(prob, pol)| { - ( - OrdF64(prob), - compiler::best_compilation(pol.as_ref()).unwrap(), - ) + (prob, compiler::best_compilation(pol.as_ref()).unwrap()) }) .collect(); @@ -498,19 +502,20 @@ impl Policy { /// disjunction over sub-policies output by it. The probability calculations are similar /// to [`Policy::tapleaf_probability_iter`]. #[cfg(feature = "compiler")] - fn enumerate_pol(&self, prob: f64) -> Vec<(f64, Arc)> { + fn enumerate_pol(&self, prob: PositiveF64) -> Vec<(PositiveF64, Arc)> { match self { Self::Or(subs) => { - let total_odds = subs.iter().fold(0, |acc, x| acc + x.0); - subs.iter() - .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) + let normalized_iter = PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); + normalized_iter + .zip(subs.iter()) + .map(|(odds, (_, pol))| (prob * odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { - let total_odds = thresh.n(); + let total_odds = PositiveF64::n(thresh); thresh .iter() - .map(|pol| (prob / total_odds as f64, pol.clone())) + .map(|pol| (prob / total_odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if !thresh.is_and() => generate_combination(thresh, prob), @@ -523,25 +528,26 @@ impl Policy { /// This ensures nested `Or` branches inside `And` nodes are decomposed into /// separate sub-policies, producing IF-free leaves for Taptree-native compilation. #[cfg(feature = "compiler")] - fn enumerate_pol_native(&self, prob: f64) -> Vec<(f64, Arc)> { + fn enumerate_pol_native(&self, prob: PositiveF64) -> Vec<(PositiveF64, Arc)> { match self { Self::Or(subs) => { - let total_odds = subs.iter().fold(0, |acc, x| acc + x.0); - subs.iter() - .map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone())) + let normalized_iter = PositiveF64::normalized_iter(subs.iter().map(|x| x.0.into())); + normalized_iter + .zip(subs.iter()) + .map(|(odds, (_, pol))| (prob * odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if thresh.is_or() => { - let total_odds = thresh.n(); + let total_odds = PositiveF64::n(thresh); thresh .iter() - .map(|pol| (prob / total_odds as f64, pol.clone())) + .map(|pol| (prob / total_odds, pol.clone())) .collect::>() } Self::Thresh(ref thresh) if !thresh.is_and() => generate_combination(thresh, prob), Self::And(subs) => { for (i, sub) in subs.iter().enumerate() { - let child_expanded = sub.enumerate_pol_native(1.0); + let child_expanded = sub.enumerate_pol_native(PositiveF64::ONE); if child_expanded.len() > 1 { let other: Vec<_> = subs .iter() @@ -572,7 +578,7 @@ impl Policy { /// set](`BTreeSet`) of `(prob, policy)` (ordered by probability) to maintain the list of /// enumerated sub-policies whose disjunction is isomorphic to initial policy (*invariant*). #[cfg(feature = "compiler")] - fn enumerate_policy_tree(self, prob: f64) -> Vec<(f64, Arc)> { + fn enumerate_policy_tree(self, prob: PositiveF64) -> Vec<(PositiveF64, Arc)> { self.enumerate_leaves(prob, MAX_COMPILATION_LEAVES, Self::enumerate_pol) } @@ -584,20 +590,20 @@ impl Policy { #[allow(clippy::type_complexity)] fn enumerate_leaves( self, - prob: f64, + prob: PositiveF64, max_leaves: usize, - expand_fn: fn(&Self, f64) -> Vec<(f64, Arc)>, - ) -> Vec<(f64, Arc)> { - let mut tapleaf_prob_vec = BTreeSet::<(Reverse, Arc)>::new(); + expand_fn: fn(&Self, PositiveF64) -> Vec<(PositiveF64, Arc)>, + ) -> Vec<(PositiveF64, Arc)> { + let mut tapleaf_prob_vec = BTreeSet::<(Reverse, Arc)>::new(); // Store probability corresponding to policy in the enumerated tree. This is required since // owing to the current [policy element enumeration algorithm][`Policy::enumerate_pol`], // two passes of the algorithm might result in same sub-policy showing up. Currently, we // merge the nodes by adding up the corresponding probabilities for the same policy. - let mut pol_prob_map = BTreeMap::, OrdF64>::new(); + let mut pol_prob_map = BTreeMap::, PositiveF64>::new(); let arc_self = Arc::new(self); - tapleaf_prob_vec.insert((Reverse(OrdF64(prob)), Arc::clone(&arc_self))); - pol_prob_map.insert(Arc::clone(&arc_self), OrdF64(prob)); + tapleaf_prob_vec.insert((Reverse(prob), Arc::clone(&arc_self))); + pol_prob_map.insert(Arc::clone(&arc_self), prob); // Since we know that policy enumeration *must* result in increase in total number of nodes, // we can maintain the length of the ordered set to check if the @@ -607,21 +613,21 @@ impl Policy { // store the variables let mut enum_len = tapleaf_prob_vec.len(); - let mut ret: Vec<(f64, Arc)> = vec![]; + let mut ret: Vec<(PositiveF64, Arc)> = vec![]; // Stopping condition: When NONE of the inputs can be further enumerated. 'outer: loop { //--- FIND a plausible node --- - let mut prob: Reverse = Reverse(OrdF64(0.0)); + let mut prob: Reverse = Reverse(PositiveF64::EPSILON); let mut curr_policy: Arc = Arc::new(Self::Unsatisfiable); - let mut curr_pol_replace_vec: Vec<(f64, Arc)> = vec![]; + let mut curr_pol_replace_vec: Vec<(PositiveF64, Arc)> = vec![]; let mut no_more_enum = false; // The nodes which can't be enumerated further are directly appended to ret and removed // from the ordered set. - let mut to_del: Vec<(f64, Arc)> = vec![]; + let mut to_del: Vec<(PositiveF64, Arc)> = vec![]; 'inner: for (i, (p, pol)) in tapleaf_prob_vec.iter().enumerate() { - curr_pol_replace_vec = expand_fn(pol, p.0 .0); + curr_pol_replace_vec = expand_fn(pol, p.0); enum_len += curr_pol_replace_vec.len() - 1; // A disjunctive node should have separated this into more nodes assert!(prev_len <= enum_len); @@ -638,14 +644,14 @@ impl Policy { // Either node is enumerable, or we have // Mark all non-enumerable nodes to remove, // if not returning value in the current iteration. - to_del.push((p.0 .0, Arc::clone(pol))); + to_del.push((p.0, Arc::clone(pol))); } } // --- Sanity Checks --- if enum_len > max_leaves || no_more_enum { for (p, pol) in tapleaf_prob_vec.into_iter() { - ret.push((p.0 .0, pol)); + ret.push((p.0, pol)); } break 'outer; } @@ -659,7 +665,7 @@ impl Policy { // OPTIMIZATION - Move marked nodes into final vector for (p, pol) in to_del { - assert!(tapleaf_prob_vec.remove(&(Reverse(OrdF64(p)), pol.clone()))); + assert!(tapleaf_prob_vec.remove(&(Reverse(p), pol.clone()))); pol_prob_map.remove(&pol); ret.push((p, pol.clone())); } @@ -669,12 +675,12 @@ impl Policy { match pol_prob_map.get(&policy) { Some(prev_prob) => { assert!(tapleaf_prob_vec.remove(&(Reverse(*prev_prob), policy.clone()))); - tapleaf_prob_vec.insert((Reverse(OrdF64(prev_prob.0 + p)), policy.clone())); - pol_prob_map.insert(policy.clone(), OrdF64(prev_prob.0 + p)); + tapleaf_prob_vec.insert((Reverse(prev_prob + p), policy.clone())); + pol_prob_map.insert(policy.clone(), prev_prob + p); } None => { - tapleaf_prob_vec.insert((Reverse(OrdF64(p)), policy.clone())); - pol_prob_map.insert(policy.clone(), OrdF64(p)); + tapleaf_prob_vec.insert((Reverse(p), policy.clone())); + pol_prob_map.insert(policy.clone(), p); } } } @@ -1013,7 +1019,7 @@ impl expression::FromTree for Policy { .map_err(From::from) .map_err(Error::Parse)?; - let mut stack = Vec::<(usize, _)>::with_capacity(128); + let mut stack = Vec::<(NonZeroU32, _)>::with_capacity(128); for node in root.pre_order_iter().rev() { let allow_prob; // Before doing anything else, check if this is the inner value of a terminal. @@ -1046,10 +1052,10 @@ impl expression::FromTree for Policy { }; let frag_prob = match frag_prob { - None => 1, + None => NonZeroU32::new(1).unwrap(), // NonZeroU32::MIN available in Rust 1.70 Some(s) => expression::parse_num_nonzero(s, "fragment probability") .map_err(From::from) - .map_err(Error::Parse)? as usize, + .map_err(Error::Parse)?, }; let new = @@ -1137,8 +1143,10 @@ fn has_if_fragment(ms: &Miniscript) -> bool { /// Creates a Huffman Tree from compiled [`Miniscript`] nodes. #[cfg(feature = "compiler")] -fn with_huffman_tree(ms: Vec<(OrdF64, Miniscript)>) -> TapTree { - let mut node_weights = BinaryHeap::<(Reverse, TapTree)>::new(); +fn with_huffman_tree( + ms: Vec<(PositiveF64, Miniscript)>, +) -> TapTree { + let mut node_weights = BinaryHeap::<(Reverse, TapTree)>::new(); for (prob, script) in ms { node_weights.push((Reverse(prob), TapTree::leaf(script))); } @@ -1147,9 +1155,9 @@ fn with_huffman_tree(ms: Vec<(OrdF64, Miniscript)>) let (p1, s1) = node_weights.pop().expect("len must at least be two"); let (p2, s2) = node_weights.pop().expect("len must at least be two"); - let p = (p1.0).0 + (p2.0).0; + let p = p1.0 + p2.0; node_weights.push(( - Reverse(OrdF64(p)), + Reverse(p), TapTree::combine(s1, s2) .expect("huffman tree cannot produce depth > 128 given sane weights"), )); @@ -1172,12 +1180,12 @@ fn with_huffman_tree(ms: Vec<(OrdF64, Miniscript)>) #[cfg(feature = "compiler")] fn generate_combination( thresh: &Threshold>, 0>, - prob: f64, -) -> Vec<(f64, Arc>)> { + prob: PositiveF64, +) -> Vec<(PositiveF64, Arc>)> { debug_assert!(thresh.k() < thresh.n()); - let prob_over_n = prob / thresh.n() as f64; - let mut ret: Vec<(f64, Arc>)> = vec![]; + let prob_over_n = prob / PositiveF64::n(thresh); + let mut ret: Vec<(PositiveF64, Arc>)> = vec![]; for i in 0..thresh.n() { let thresh_less_1 = Threshold::from_iter( thresh.k(), @@ -1202,7 +1210,7 @@ pub enum TreeChildren<'a, Pk: MiniscriptKey> { /// A conjunction or threshold node's children. And(&'a [Arc>]), /// A disjunction node's children. - Or(&'a [(usize, Arc>)]), + Or(&'a [(NonZeroU32, Arc>)]), } impl<'a, Pk: MiniscriptKey> TreeLike for &'a Policy { @@ -1279,7 +1287,7 @@ mod compiler_tests { .collect(); let thresh = Threshold::new(2, policies).unwrap(); - let combinations = generate_combination(&thresh, 1.0); + let combinations = generate_combination(&thresh, PositiveF64::ONE); let comb_a: Vec> = vec![ policy_str!("pk(B)"), @@ -1306,7 +1314,7 @@ mod compiler_tests { .map(|sub_pol| { let expected_thresh = Threshold::from_iter(2, sub_pol.into_iter().map(Arc::new)).unwrap(); - (0.25, Arc::new(Policy::Thresh(expected_thresh))) + (PositiveF64::ONE_QUARTER, Arc::new(Policy::Thresh(expected_thresh))) }) .collect::>(); assert_eq!(combinations, expected_comb); diff --git a/src/primitives/mod.rs b/src/primitives/mod.rs index 919e6c8ba..d7afefdcb 100644 --- a/src/primitives/mod.rs +++ b/src/primitives/mod.rs @@ -12,5 +12,6 @@ //! should be re-exported at the crate root. pub mod absolute_locktime; +pub mod positive_f64; pub mod relative_locktime; pub mod threshold; diff --git a/src/primitives/positive_f64.rs b/src/primitives/positive_f64.rs new file mode 100644 index 000000000..0638f902d --- /dev/null +++ b/src/primitives/positive_f64.rs @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! Positive floats ("branch probabilities" for policies) +use core::iter::FusedIterator; +use core::num::NonZeroU32; +use core::{cmp, hash, ops}; + +use crate::Threshold; + +/// Ordered f64 for comparison. +#[derive(Copy, Clone, PartialEq, Debug)] +pub struct PositiveF64(f64); + +impl PositiveF64 { + /// The smallest representable value of a [`PositiveF64`]. + pub const EPSILON: Self = Self(f64::EPSILON); + + /// The constant one. + pub const ONE: Self = Self(1.0); + + /// Constant used in unit tsets + #[cfg(test)] + pub const ONE_QUARTER: Self = Self(0.25); + + /// Attempts to create a [`PositiveF64`] from an ordinary `f64`. + pub fn new(f: f64) -> Option { + // Can likely make this function const in Rust 1.83 + (f > 0.0).then_some(Self(f)) + } + + /// Given an [`Option`], if it is `Some` then add it to the value. + /// Otherwise return the unmodified value. + /// + /// Returns the sum (or original value). Does not modify in-place. + #[must_use] + pub fn conditional_add(self, other: Option) -> Self { other.map_or(self, |i| i + self) } + + /// Takes an iterator over [`PositiveF64`] and produces a new iterator where + /// each item is divided so that they all total to 1. + /// + /// On an empty iterator, returns a new empty iterator. + /// + /// Internally clones the iterator and runs it twice, so best to only use + /// this with reference-based iterators obtained with e.g. `slice.iter()` + /// rather than "owning" iterators like you'd get from `vec.into_iter()`. + pub fn normalized_iter(iter: I) -> NormalizedIterator + where + I: Iterator + Clone, + { + // Compute the sum of all the items in the iterator. Because all items in + // the iterator are positive, this will be 0 iff the iterator is empty. + let sum = iter.clone().map(|x| x.0).sum::(); + NormalizedIterator { iter, sum } + } + + /// The 'n' value of a threshold, as a [`PositiveF64`] + pub fn n(t: &Threshold) -> Self { + Self(t.n() as f64) // cast okay, worst case wil lose precision + } + + /// The ratio `k`/`n` of a threshold, as a [`PositiveF64`]. Guaranteed to be + /// in the half-open range `(0, 1]`. + pub fn k_over_n(t: &Threshold) -> Self { + Self(t.k() as f64 / t.n() as f64) // casts okay, worst case wil lose precision + } + + /// One minus the ratio `k` / `n` of a threshold, as a [`PositiveF64`]. Guaranteed + /// to be in the half-open range `[0, 1)`. + /// + /// Returns `None` if the return value would be 0, which is impermissible for the + /// [`PositiveF64`] type. + pub fn one_minus_k_over_n(t: &Threshold) -> Option { + if t.is_and() { + None + } else { + Some(Self(1.0 - t.k() as f64 / t.n() as f64)) // casts okay, worst case wil lose precision + } + } +} + +impl Eq for PositiveF64 {} + +// We could derive PartialOrd, but we can't derive Ord, and clippy wants us +// to derive both or neither. Better to be explicit. +impl PartialOrd for PositiveF64 { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } +} + +impl Ord for PositiveF64 { + fn cmp(&self, other: &Self) -> cmp::Ordering { + // will panic if given NaN + self.0.partial_cmp(&other.0).unwrap() + } +} + +/// Hash required for using OrdF64 as key for hashmap +impl hash::Hash for PositiveF64 { + fn hash(&self, state: &mut H) { self.0.to_bits().hash(state); } +} + +impl From for f64 { + fn from(value: PositiveF64) -> Self { value.0 } +} + +impl From for PositiveF64 { + fn from(value: NonZeroU32) -> Self { Self(f64::from(u32::from(value))) } +} + +macro_rules! impl_op { + ($trait:ident, $op:ident, $expr:expr) => { + impl ops::$trait for PositiveF64 { + type Output = Self; + fn $op(self, rhs: Self) -> Self::Output { Self($expr(self.0, rhs.0)) } + } + + impl ops::$trait for &PositiveF64 { + type Output = PositiveF64; + fn $op(self, rhs: Self) -> Self::Output { PositiveF64($expr(self.0, rhs.0)) } + } + + impl ops::$trait<&PositiveF64> for PositiveF64 { + type Output = Self; + fn $op(self, rhs: &PositiveF64) -> Self::Output { Self($expr(self.0, rhs.0)) } + } + + impl ops::$trait for &PositiveF64 { + type Output = PositiveF64; + fn $op(self, rhs: PositiveF64) -> Self::Output { PositiveF64($expr(self.0, rhs.0)) } + } + }; +} + +impl_op!(Add, add, f64::add); +impl_op!(Mul, mul, f64::mul); +impl_op!(Div, div, f64::div); + +pub struct NormalizedIterator { + iter: I, + /// Sum must be nonnegative, and may only be zero if `iter` is empty. + sum: f64, +} + +impl Iterator for NormalizedIterator +where + I: Iterator, +{ + type Item = I::Item; + fn next(&mut self) -> Option { + self.iter.next().map(|x| PositiveF64(x.0 / self.sum)) + } + + fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() } +} + +impl DoubleEndedIterator for NormalizedIterator +where + I: Iterator + DoubleEndedIterator, +{ + fn next_back(&mut self) -> Option { + self.iter.next_back().map(|x| PositiveF64(x.0 / self.sum)) + } +} + +impl ExactSizeIterator for NormalizedIterator where + I: Iterator + ExactSizeIterator +{ +} + +impl FusedIterator for NormalizedIterator where + I: Iterator + FusedIterator +{ +}